From 5f5e4aa404ee710bb5acca74c273fdf4de11a31d Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Tue, 17 Oct 2023 16:19:55 +0200 Subject: [PATCH 01/10] refactor(dslx): Runtime is a renamed, abstract ConnPool We want to use the DSL inside the oohelperd. We don't care about collecting observations in the oohelperd. So, the plan is that of abstracting the ConnPool, renaming it Runtime, and giving it the power to create abstract Traces. The oohelperd will use a MinimalRuntime (sketched out by this commit) that will not collect any observation. Measuring code, instead, will use a MeasurexRuntime that will collect observations. This commit is just the first step. We rename and introduce the MinimalRuntime. No significant functional changes so far. --- internal/dslx/connpool.go | 42 ------------------ internal/dslx/integration_test.go | 6 +-- internal/dslx/quic.go | 12 +++--- internal/dslx/quic_test.go | 18 ++++---- internal/dslx/runtimecore.go | 15 +++++++ internal/dslx/runtimeminimal.go | 43 +++++++++++++++++++ ...onnpool_test.go => runtimeminimal_test.go} | 31 ++++++------- internal/dslx/tcp.go | 8 ++-- internal/dslx/tcp_test.go | 10 ++--- internal/dslx/tls.go | 12 +++--- internal/dslx/tls_test.go | 18 ++++---- internal/tutorial/dslx/chapter02/README.md | 16 +++---- internal/tutorial/dslx/chapter02/main.go | 16 +++---- 13 files changed, 132 insertions(+), 115 deletions(-) delete mode 100644 internal/dslx/connpool.go create mode 100644 internal/dslx/runtimecore.go create mode 100644 internal/dslx/runtimeminimal.go rename internal/dslx/{connpool_test.go => runtimeminimal_test.go} (76%) diff --git a/internal/dslx/connpool.go b/internal/dslx/connpool.go deleted file mode 100644 index 0636147b02..0000000000 --- a/internal/dslx/connpool.go +++ /dev/null @@ -1,42 +0,0 @@ -package dslx - -// -// Connection pooling to streamline closing connections. -// - -import ( - "io" - "sync" -) - -// ConnPool tracks established connections. The zero value -// of this struct is ready to use. -type ConnPool struct { - mu sync.Mutex - v []io.Closer -} - -// MaybeTrack tracks the given connection if not nil. This -// method is safe for use by multiple goroutines. -func (p *ConnPool) MaybeTrack(c io.Closer) { - if c != nil { - defer p.mu.Unlock() - p.mu.Lock() - p.v = append(p.v, c) - } -} - -// Close closes all the tracked connections in reverse order. This -// method is safe for use by multiple goroutines. -func (p *ConnPool) Close() error { - // Implementation note: reverse order is such that we close TLS - // connections before we close the TCP connections they use. Hence - // we'll _gracefully_ close TLS connections. - defer p.mu.Unlock() - p.mu.Lock() - for idx := len(p.v) - 1; idx >= 0; idx-- { - _ = p.v[idx].Close() - } - p.v = nil // reset - return nil -} diff --git a/internal/dslx/integration_test.go b/internal/dslx/integration_test.go index df182e3af7..e37e0a40d0 100644 --- a/internal/dslx/integration_test.go +++ b/internal/dslx/integration_test.go @@ -32,12 +32,12 @@ func TestMakeSureWeCollectSpeedSamples(t *testing.T) { defer server.Close() // instantiate a connection pool - pool := &ConnPool{} - defer pool.Close() + rt := NewMinimalRuntime() + defer rt.Close() // create a measuring function f0 := Compose3( - TCPConnect(pool), + TCPConnect(rt), HTTPTransportTCP(), HTTPRequest(), ) diff --git a/internal/dslx/quic.go b/internal/dslx/quic.go index d4e88573d7..196bd0e449 100644 --- a/internal/dslx/quic.go +++ b/internal/dslx/quic.go @@ -45,15 +45,15 @@ func QUICHandshakeOptionServerName(value string) QUICHandshakeOption { } // QUICHandshake returns a function performing QUIC handshakes. -func QUICHandshake(pool *ConnPool, options ...QUICHandshakeOption) Func[ +func QUICHandshake(rt Runtime, options ...QUICHandshakeOption) Func[ *Endpoint, *Maybe[*QUICConnection]] { // See https://github.com/ooni/probe/issues/2413 to understand // why we're using nil to force netxlite to use the cached // default Mozilla cert pool. f := &quicHandshakeFunc{ InsecureSkipVerify: false, - Pool: pool, RootCAs: nil, + Rt: rt, ServerName: "", } for _, option := range options { @@ -67,12 +67,12 @@ type quicHandshakeFunc struct { // InsecureSkipVerify allows to skip TLS verification. InsecureSkipVerify bool - // Pool is the ConnPool that owns us. - Pool *ConnPool - // RootCAs contains the Root CAs to use. RootCAs *x509.CertPool + // Rt is the Runtime that owns us. + Rt Runtime + // ServerName is the ServerName to handshake for. ServerName string @@ -124,7 +124,7 @@ func (f *quicHandshakeFunc) Apply( } // possibly track established conn for late close - f.Pool.MaybeTrack(closerConn) + f.Rt.MaybeTrackConn(closerConn) // stop the operation logger ol.Stop(err) diff --git a/internal/dslx/quic_test.go b/internal/dslx/quic_test.go index 7e0a30cf9d..2396909bfb 100644 --- a/internal/dslx/quic_test.go +++ b/internal/dslx/quic_test.go @@ -29,7 +29,7 @@ func TestQUICHandshake(t *testing.T) { certpool.AddCert(&x509.Certificate{}) f := QUICHandshake( - &ConnPool{}, + NewMinimalRuntime(), QUICHandshakeOptionInsecureSkipVerify(true), QUICHandshakeOptionServerName("sni"), QUICHandshakeOptionRootCAs(certpool), @@ -99,9 +99,9 @@ func TestQUICHandshake(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - pool := &ConnPool{} + rt := NewMinimalRuntime() quicHandshake := &quicHandshakeFunc{ - Pool: pool, + Rt: rt, dialer: tt.dialer, ServerName: tt.sni, } @@ -120,7 +120,7 @@ func TestQUICHandshake(t *testing.T) { if res.State == nil || res.State.QUICConn != tt.expectConn { t.Fatal("unexpected conn") } - pool.Close() + rt.Close() if wasClosed != tt.closed { t.Fatalf("unexpected connection closed state: %v", wasClosed) } @@ -137,7 +137,7 @@ func TestQUICHandshake(t *testing.T) { } t.Run("with nil dialer", func(t *testing.T) { - quicHandshake := &quicHandshakeFunc{Pool: &ConnPool{}, dialer: nil} + quicHandshake := &quicHandshakeFunc{Rt: NewMinimalRuntime(), dialer: nil} endpoint := &Endpoint{ Address: "1.2.3.4:567", Network: "udp", @@ -173,7 +173,7 @@ func TestServerNameQUIC(t *testing.T) { Address: "example.com:123", Logger: model.DiscardLogger, } - f := &quicHandshakeFunc{Pool: &ConnPool{}, ServerName: sni} + f := &quicHandshakeFunc{Rt: NewMinimalRuntime(), ServerName: sni} serverName := f.serverName(endpoint) if serverName != sni { t.Fatalf("unexpected server name: %s", serverName) @@ -187,7 +187,7 @@ func TestServerNameQUIC(t *testing.T) { Domain: domain, Logger: model.DiscardLogger, } - f := &quicHandshakeFunc{Pool: &ConnPool{}} + f := &quicHandshakeFunc{Rt: NewMinimalRuntime()} serverName := f.serverName(endpoint) if serverName != domain { t.Fatalf("unexpected server name: %s", serverName) @@ -200,7 +200,7 @@ func TestServerNameQUIC(t *testing.T) { Address: hostaddr + ":123", Logger: model.DiscardLogger, } - f := &quicHandshakeFunc{Pool: &ConnPool{}} + f := &quicHandshakeFunc{Rt: NewMinimalRuntime()} serverName := f.serverName(endpoint) if serverName != hostaddr { t.Fatalf("unexpected server name: %s", serverName) @@ -213,7 +213,7 @@ func TestServerNameQUIC(t *testing.T) { Address: ip, Logger: model.DiscardLogger, } - f := &quicHandshakeFunc{Pool: &ConnPool{}} + f := &quicHandshakeFunc{Rt: NewMinimalRuntime()} serverName := f.serverName(endpoint) if serverName != "" { t.Fatalf("unexpected server name: %s", serverName) diff --git a/internal/dslx/runtimecore.go b/internal/dslx/runtimecore.go new file mode 100644 index 0000000000..caf7a680d6 --- /dev/null +++ b/internal/dslx/runtimecore.go @@ -0,0 +1,15 @@ +package dslx + +import ( + "io" +) + +// Runtime is the runtime in which we execute the DSL. +type Runtime interface { + // Close closes all the connection tracked using MaybeTrackConn. + Close() error + + // MaybeTrackConn tracks a connection such that it is closed + // when you call the Runtime's Close method. + MaybeTrackConn(conn io.Closer) +} diff --git a/internal/dslx/runtimeminimal.go b/internal/dslx/runtimeminimal.go new file mode 100644 index 0000000000..3559168ab2 --- /dev/null +++ b/internal/dslx/runtimeminimal.go @@ -0,0 +1,43 @@ +package dslx + +import ( + "io" + "sync" +) + +// NewMinimalRuntime creates a minimal [Runtime] implementation. +func NewMinimalRuntime() *MinimalRuntime { + return &MinimalRuntime{ + mu: sync.Mutex{}, + v: []io.Closer{}, + } +} + +// MinimalRuntime is a minimal [Runtime] implementation. +type MinimalRuntime struct { + mu sync.Mutex + v []io.Closer +} + +// MaybeTrackConn implements Runtime. +func (p *MinimalRuntime) MaybeTrackConn(conn io.Closer) { + if conn != nil { + defer p.mu.Unlock() + p.mu.Lock() + p.v = append(p.v, conn) + } +} + +// Close implements Runtime. +func (p *MinimalRuntime) Close() error { + // Implementation note: reverse order is such that we close TLS + // connections before we close the TCP connections they use. Hence + // we'll _gracefully_ close TLS connections. + defer p.mu.Unlock() + p.mu.Lock() + for idx := len(p.v) - 1; idx >= 0; idx-- { + _ = p.v[idx].Close() + } + p.v = nil // reset + return nil +} diff --git a/internal/dslx/connpool_test.go b/internal/dslx/runtimeminimal_test.go similarity index 76% rename from internal/dslx/connpool_test.go rename to internal/dslx/runtimeminimal_test.go index daba7799f5..162cbee0bc 100644 --- a/internal/dslx/connpool_test.go +++ b/internal/dslx/runtimeminimal_test.go @@ -16,7 +16,7 @@ Test cases: - with connection - with quic connection -- Close ConnPool: +- Close MinimalRuntime: - all Close() calls succeed - one Close() call fails */ @@ -39,36 +39,37 @@ func closeableQUICConnWithErr(err error) io.Closer { } } -func TestConnPool(t *testing.T) { - type connpoolTest struct { +func TestMinimalRuntime(t *testing.T) { + // testcase is a test case implemented by this function + type testcase struct { mockConn io.Closer - want int // len of connpool.v + want int // len of (*minimalRuntime).v } t.Run("Maybe track connections", func(t *testing.T) { - tests := map[string]connpoolTest{ + tests := map[string]testcase{ "with nil": {mockConn: nil, want: 0}, "with connection": {mockConn: closeableConnWithErr(nil), want: 1}, "with quic connection": {mockConn: closeableQUICConnWithErr(nil), want: 1}, } for name, tt := range tests { t.Run(name, func(t *testing.T) { - connpool := &ConnPool{} - connpool.MaybeTrack(tt.mockConn) - if len(connpool.v) != tt.want { - t.Fatalf("expected %d tracked connections, got: %d", tt.want, len(connpool.v)) + rt := NewMinimalRuntime() + rt.MaybeTrackConn(tt.mockConn) + if len(rt.v) != tt.want { + t.Fatalf("expected %d tracked connections, got: %d", tt.want, len(rt.v)) } }) } }) - t.Run("Close ConnPool", func(t *testing.T) { + t.Run("Close MinimalRuntime", func(t *testing.T) { mockErr := errors.New("mocked") tests := map[string]struct { - pool *ConnPool + rt *MinimalRuntime }{ "all Close() calls succeed": { - pool: &ConnPool{ + rt: &MinimalRuntime{ v: []io.Closer{ closeableConnWithErr(nil), closeableQUICConnWithErr(nil), @@ -76,7 +77,7 @@ func TestConnPool(t *testing.T) { }, }, "one Close() call fails": { - pool: &ConnPool{ + rt: &MinimalRuntime{ v: []io.Closer{ closeableConnWithErr(nil), closeableConnWithErr(mockErr), @@ -87,11 +88,11 @@ func TestConnPool(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - err := tt.pool.Close() + err := tt.rt.Close() if err != nil { // Close() should always return nil t.Fatalf("unexpected error %s", err) } - if tt.pool.v != nil { + if tt.rt.v != nil { t.Fatalf("v should be reset but is not") } }) diff --git a/internal/dslx/tcp.go b/internal/dslx/tcp.go index fe5d769000..c38c4c0cfa 100644 --- a/internal/dslx/tcp.go +++ b/internal/dslx/tcp.go @@ -17,15 +17,15 @@ import ( ) // TCPConnect returns a function that establishes TCP connections. -func TCPConnect(pool *ConnPool) Func[*Endpoint, *Maybe[*TCPConnection]] { - f := &tcpConnectFunc{pool, nil} +func TCPConnect(rt Runtime) Func[*Endpoint, *Maybe[*TCPConnection]] { + f := &tcpConnectFunc{nil, rt} return f } // tcpConnectFunc is a function that establishes TCP connections. type tcpConnectFunc struct { - p *ConnPool dialer model.Dialer // for testing + rt Runtime } // Apply applies the function to its arguments. @@ -55,7 +55,7 @@ func (f *tcpConnectFunc) Apply( conn, err := dialer.DialContext(ctx, "tcp", input.Address) // possibly register established conn for late close - f.p.MaybeTrack(conn) + f.rt.MaybeTrackConn(conn) // stop the operation logger ol.Stop(err) diff --git a/internal/dslx/tcp_test.go b/internal/dslx/tcp_test.go index 6a94ea35c5..b7982bbdb2 100644 --- a/internal/dslx/tcp_test.go +++ b/internal/dslx/tcp_test.go @@ -17,7 +17,7 @@ import ( func TestTCPConnect(t *testing.T) { t.Run("Get tcpConnectFunc", func(t *testing.T) { f := TCPConnect( - &ConnPool{}, + NewMinimalRuntime(), ) if _, ok := f.(*tcpConnectFunc); !ok { t.Fatal("unexpected type. Expected: tcpConnectFunc") @@ -69,8 +69,8 @@ func TestTCPConnect(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - pool := &ConnPool{} - tcpConnect := &tcpConnectFunc{pool, tt.dialer} + rt := NewMinimalRuntime() + tcpConnect := &tcpConnectFunc{tt.dialer, rt} endpoint := &Endpoint{ Address: "1.2.3.4:567", Network: "tcp", @@ -86,7 +86,7 @@ func TestTCPConnect(t *testing.T) { if res.State == nil || res.State.Conn != tt.expectConn { t.Fatal("unexpected conn") } - pool.Close() + rt.Close() if wasClosed != tt.closed { t.Fatalf("unexpected connection closed state: %v", wasClosed) } @@ -107,7 +107,7 @@ func TestTCPConnect(t *testing.T) { // Make sure we get a valid dialer if no mocked dialer is configured func TestDialerOrDefault(t *testing.T) { f := &tcpConnectFunc{ - p: &ConnPool{}, + rt: NewMinimalRuntime(), dialer: nil, } dialer := f.dialerOrDefault(measurexlite.NewTrace(0, time.Now()), model.DiscardLogger) diff --git a/internal/dslx/tls.go b/internal/dslx/tls.go index eac86bc923..c9717b57c8 100644 --- a/internal/dslx/tls.go +++ b/internal/dslx/tls.go @@ -50,7 +50,7 @@ func TLSHandshakeOptionServerName(value string) TLSHandshakeOption { } // TLSHandshake returns a function performing TSL handshakes. -func TLSHandshake(pool *ConnPool, options ...TLSHandshakeOption) Func[ +func TLSHandshake(rt Runtime, options ...TLSHandshakeOption) Func[ *TCPConnection, *Maybe[*TLSConnection]] { // See https://github.com/ooni/probe/issues/2413 to understand // why we're using nil to force netxlite to use the cached @@ -58,8 +58,8 @@ func TLSHandshake(pool *ConnPool, options ...TLSHandshakeOption) Func[ f := &tlsHandshakeFunc{ InsecureSkipVerify: false, NextProto: []string{}, - Pool: pool, RootCAs: nil, + Rt: rt, ServerName: "", } for _, option := range options { @@ -76,12 +76,12 @@ type tlsHandshakeFunc struct { // NextProto contains the ALPNs to negotiate. NextProto []string - // Pool is the Pool that owns us. - Pool *ConnPool - // RootCAs contains the Root CAs to use. RootCAs *x509.CertPool + // Pool is the Pool that owns us. + Rt Runtime + // ServerName is the ServerName to handshake for. ServerName string @@ -127,7 +127,7 @@ func (f *tlsHandshakeFunc) Apply( conn, err := handshaker.Handshake(ctx, input.Conn, config) // possibly register established conn for late close - f.Pool.MaybeTrack(conn) + f.Rt.MaybeTrackConn(conn) // stop the operation logger ol.Stop(err) diff --git a/internal/dslx/tls_test.go b/internal/dslx/tls_test.go index 3cba8f81d8..858834a0dc 100644 --- a/internal/dslx/tls_test.go +++ b/internal/dslx/tls_test.go @@ -31,7 +31,7 @@ func TestTLSHandshake(t *testing.T) { certpool.AddCert(&x509.Certificate{}) f := TLSHandshake( - &ConnPool{}, + NewMinimalRuntime(), TLSHandshakeOptionInsecureSkipVerify(true), TLSHandshakeOptionNextProto([]string{"h2"}), TLSHandshakeOptionServerName("sni"), @@ -133,10 +133,10 @@ func TestTLSHandshake(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - pool := &ConnPool{} + rt := NewMinimalRuntime() tlsHandshake := &tlsHandshakeFunc{ NextProto: tt.config.nextProtos, - Pool: pool, + Rt: rt, ServerName: tt.config.sni, handshaker: tt.handshaker, } @@ -163,7 +163,7 @@ func TestTLSHandshake(t *testing.T) { if res.State.Conn != tt.expectConn { t.Fatalf("unexpected conn %v", res.State.Conn) } - pool.Close() + rt.Close() if wasClosed != tt.closed { t.Fatalf("unexpected connection closed state %v", wasClosed) } @@ -188,7 +188,7 @@ func TestServerNameTLS(t *testing.T) { Logger: model.DiscardLogger, } f := &tlsHandshakeFunc{ - Pool: &ConnPool{}, + Rt: NewMinimalRuntime(), ServerName: sni, } serverName := f.serverName(&tcpConn) @@ -204,7 +204,7 @@ func TestServerNameTLS(t *testing.T) { Logger: model.DiscardLogger, } f := &tlsHandshakeFunc{ - Pool: &ConnPool{}, + Rt: NewMinimalRuntime(), } serverName := f.serverName(&tcpConn) if serverName != domain { @@ -218,7 +218,7 @@ func TestServerNameTLS(t *testing.T) { Logger: model.DiscardLogger, } f := &tlsHandshakeFunc{ - Pool: &ConnPool{}, + Rt: NewMinimalRuntime(), } serverName := f.serverName(&tcpConn) if serverName != hostaddr { @@ -232,7 +232,7 @@ func TestServerNameTLS(t *testing.T) { Logger: model.DiscardLogger, } f := &tlsHandshakeFunc{ - Pool: &ConnPool{}, + Rt: NewMinimalRuntime(), } serverName := f.serverName(&tcpConn) if serverName != "" { @@ -246,7 +246,7 @@ func TestHandshakerOrDefault(t *testing.T) { f := &tlsHandshakeFunc{ InsecureSkipVerify: false, NextProto: []string{}, - Pool: &ConnPool{}, + Rt: NewMinimalRuntime(), RootCAs: &x509.CertPool{}, ServerName: "", handshaker: nil, diff --git a/internal/tutorial/dslx/chapter02/README.md b/internal/tutorial/dslx/chapter02/README.md index 5722218143..a0e4f96b51 100644 --- a/internal/tutorial/dslx/chapter02/README.md +++ b/internal/tutorial/dslx/chapter02/README.md @@ -331,12 +331,12 @@ the protocol, address, and port three-tuple.) ``` -Next, we create a connection pool. This data structure helps us to manage -open connections and close them when `connpool.Close` is invoked. +Next, we create a minimal runtime. This data structure helps us to manage +open connections and close them when `rt.Close` is invoked. ```Go - connpool := &dslx.ConnPool{} - defer connpool.Close() + rt := dslx.NewMinimalRuntime() + defer rt.Close() ``` @@ -350,9 +350,9 @@ target SNI to be used within the TLS Client Hello. ```Go pipelineTarget := dslx.Compose2( - dslx.TCPConnect(connpool), + dslx.TCPConnect(rt), dslx.TLSHandshake( - connpool, + rt, dslx.TLSHandshakeOptionServerName(targetSNI), ), ) @@ -364,9 +364,9 @@ specify the *control* SNI to be used within the TLS Client Hello. ```Go pipelineControl := dslx.Compose2( - dslx.TCPConnect(connpool), + dslx.TCPConnect(rt), dslx.TLSHandshake( - connpool, + rt, dslx.TLSHandshakeOptionServerName(m.config.ControlSNI), ), ) diff --git a/internal/tutorial/dslx/chapter02/main.go b/internal/tutorial/dslx/chapter02/main.go index e956c0cacf..347a17d21e 100644 --- a/internal/tutorial/dslx/chapter02/main.go +++ b/internal/tutorial/dslx/chapter02/main.go @@ -332,12 +332,12 @@ func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { // ``` // - // Next, we create a connection pool. This data structure helps us to manage - // open connections and close them when `connpool.Close` is invoked. + // Next, we create a minimal runtime. This data structure helps us to manage + // open connections and close them when `rt.Close` is invoked. // // ```Go - connpool := &dslx.ConnPool{} - defer connpool.Close() + rt := dslx.NewMinimalRuntime() + defer rt.Close() // ``` // @@ -351,9 +351,9 @@ func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { // // ```Go pipelineTarget := dslx.Compose2( - dslx.TCPConnect(connpool), + dslx.TCPConnect(rt), dslx.TLSHandshake( - connpool, + rt, dslx.TLSHandshakeOptionServerName(targetSNI), ), ) @@ -365,9 +365,9 @@ func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { // // ```Go pipelineControl := dslx.Compose2( - dslx.TCPConnect(connpool), + dslx.TCPConnect(rt), dslx.TLSHandshake( - connpool, + rt, dslx.TLSHandshakeOptionServerName(m.config.ControlSNI), ), ) From 9921ff37af5bb953262c5105525ce138257b32d1 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 18 Oct 2023 07:31:52 +0200 Subject: [PATCH 02/10] refactor(dslx): use an abstract trace This diff builds on the previous diff and uses an abstract trace inside of the dslx package. By using an abstract trace, we can choose between using: - a runtime that collects observations, based on measurexlite; and - a minimal runtime that does not collect observations. To make the trace abstract, we need to modify measurexlite's trace such that it can be used as an interface. In turn, this means we need to update the throttling package such that it uses an abstract trace definition. Strictly speaking, we could have avoided introducing this abstraction, but it seems better to also use an abstract trace there, as it allows for improving the decoupling with measurexlite. --- internal/dslx/dns.go | 25 ++-- internal/dslx/dns_test.go | 15 ++- internal/dslx/httpcore.go | 16 +-- internal/dslx/integration_test.go | 4 +- internal/dslx/observations.go | 3 +- internal/dslx/quic.go | 7 +- internal/dslx/quic_test.go | 2 +- internal/dslx/runtimecore.go | 7 ++ internal/dslx/runtimemeasurex.go | 26 ++++ internal/dslx/runtimeminimal.go | 94 ++++++++++++++ internal/dslx/runtimeminimal_test.go | 115 ++++++++++++++++++ internal/dslx/tcp.go | 9 +- internal/dslx/tcp_test.go | 2 +- internal/dslx/tls.go | 7 +- internal/dslx/trace.go | 71 +++++++++++ .../webconnectivitylte/cleartextflow.go | 10 +- .../webconnectivitylte/secureflow.go | 10 +- internal/measurexlite/conn.go | 24 ++-- internal/measurexlite/dialer.go | 12 +- internal/measurexlite/dns.go | 16 +-- internal/measurexlite/dns_test.go | 8 +- internal/measurexlite/quic.go | 12 +- internal/measurexlite/tls.go | 12 +- internal/measurexlite/trace.go | 26 ++-- internal/measurexlite/trace_test.go | 4 +- internal/throttling/throttling.go | 33 +++-- internal/tutorial/dslx/chapter02/main.go | 20 +-- 27 files changed, 467 insertions(+), 123 deletions(-) create mode 100644 internal/dslx/runtimemeasurex.go create mode 100644 internal/dslx/trace.go diff --git a/internal/dslx/dns.go b/internal/dslx/dns.go index 63759c544f..66a9d45586 100644 --- a/internal/dslx/dns.go +++ b/internal/dslx/dns.go @@ -10,7 +10,6 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -121,7 +120,7 @@ type ResolvedAddresses struct { // Trace is the trace we're currently using. This struct is // created by the various Apply functions using values inside // the DomainToResolve to initialize the Trace. - Trace *measurexlite.Trace + Trace Trace // ZeroTime is the zero time of the measurement. We inherit this field // from the value inside the DomainToResolve. @@ -130,13 +129,14 @@ type ResolvedAddresses struct { // DNSLookupGetaddrinfo returns a function that resolves a domain name to // IP addresses using libc's getaddrinfo function. -func DNSLookupGetaddrinfo() Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { - return &dnsLookupGetaddrinfoFunc{} +func DNSLookupGetaddrinfo(rt Runtime) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { + return &dnsLookupGetaddrinfoFunc{nil, rt} } // dnsLookupGetaddrinfoFunc is the function returned by DNSLookupGetaddrinfo. type dnsLookupGetaddrinfoFunc struct { resolver model.Resolver // for testing + rt Runtime } // Apply implements Func. @@ -144,13 +144,13 @@ func (f *dnsLookupGetaddrinfoFunc) Apply( ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] { // create trace - trace := measurexlite.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) + trace := f.rt.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) // start the operation logger ol := logx.NewOperationLogger( input.Logger, "[#%d] DNSLookup[getaddrinfo] %s", - trace.Index, + trace.Index(), input.Domain, ) @@ -189,9 +189,11 @@ func (f *dnsLookupGetaddrinfoFunc) Apply( // DNSLookupUDP returns a function that resolves a domain name to // IP addresses using the given DNS-over-UDP resolver. -func DNSLookupUDP(resolver string) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { +func DNSLookupUDP(rt Runtime, resolver string) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { return &dnsLookupUDPFunc{ - Resolver: resolver, + Resolver: resolver, + mockResolver: nil, + rt: rt, } } @@ -200,6 +202,7 @@ type dnsLookupUDPFunc struct { // Resolver is the MANDATORY endpointed of the resolver to use. Resolver string mockResolver model.Resolver // for testing + rt Runtime } // Apply implements Func. @@ -207,13 +210,13 @@ func (f *dnsLookupUDPFunc) Apply( ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] { // create trace - trace := measurexlite.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) + trace := f.rt.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) // start the operation logger ol := logx.NewOperationLogger( input.Logger, "[#%d] DNSLookup[%s/udp] %s", - trace.Index, + trace.Index(), f.Resolver, input.Domain, ) @@ -227,7 +230,7 @@ func (f *dnsLookupUDPFunc) Apply( if resolver == nil { resolver = trace.NewParallelUDPResolver( input.Logger, - netxlite.NewDialerWithoutResolver(input.Logger), + trace.NewDialerWithoutResolver(input.Logger), f.Resolver, ) } diff --git a/internal/dslx/dns_test.go b/internal/dslx/dns_test.go index 2f3e292804..2868891291 100644 --- a/internal/dslx/dns_test.go +++ b/internal/dslx/dns_test.go @@ -67,7 +67,7 @@ Test cases: */ func TestGetaddrinfo(t *testing.T) { t.Run("Get dnsLookupGetaddrinfoFunc", func(t *testing.T) { - f := DNSLookupGetaddrinfo() + f := DNSLookupGetaddrinfo(NewMinimalRuntime()) if _, ok := f.(*dnsLookupGetaddrinfoFunc); !ok { t.Fatal("unexpected type, want dnsLookupGetaddrinfoFunc") } @@ -83,7 +83,9 @@ func TestGetaddrinfo(t *testing.T) { } t.Run("with nil resolver", func(t *testing.T) { - f := dnsLookupGetaddrinfoFunc{} + f := dnsLookupGetaddrinfoFunc{ + rt: NewMinimalRuntime(), + } ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately cancel the lookup res := f.Apply(ctx, domain) @@ -101,6 +103,7 @@ func TestGetaddrinfo(t *testing.T) { resolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return nil, mockedErr }}, + rt: NewMinimalRuntime(), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -122,6 +125,7 @@ func TestGetaddrinfo(t *testing.T) { resolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return []string{"93.184.216.34"}, nil }}, + rt: NewRuntimeMeasurexLite(), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -153,7 +157,8 @@ Test cases: */ func TestLookupUDP(t *testing.T) { t.Run("Get dnsLookupUDPFunc", func(t *testing.T) { - f := DNSLookupUDP("1.1.1.1:53") + rt := NewMinimalRuntime() + f := DNSLookupUDP(rt, "1.1.1.1:53") if _, ok := f.(*dnsLookupUDPFunc); !ok { t.Fatal("unexpected type, want dnsLookupUDPFunc") } @@ -169,7 +174,7 @@ func TestLookupUDP(t *testing.T) { } t.Run("with nil resolver", func(t *testing.T) { - f := dnsLookupUDPFunc{Resolver: "1.1.1.1:53"} + f := dnsLookupUDPFunc{Resolver: "1.1.1.1:53", rt: NewMinimalRuntime()} ctx, cancel := context.WithCancel(context.Background()) cancel() res := f.Apply(ctx, domain) @@ -188,6 +193,7 @@ func TestLookupUDP(t *testing.T) { mockResolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return nil, mockedErr }}, + rt: NewMinimalRuntime(), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -210,6 +216,7 @@ func TestLookupUDP(t *testing.T) { mockResolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return []string{"93.184.216.34"}, nil }}, + rt: NewRuntimeMeasurexLite(), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { diff --git a/internal/dslx/httpcore.go b/internal/dslx/httpcore.go index b824f8f848..2ffb4281ab 100644 --- a/internal/dslx/httpcore.go +++ b/internal/dslx/httpcore.go @@ -48,7 +48,7 @@ type HTTPTransport struct { TLSNegotiatedProtocol string // Trace is the MANDATORY trace we're using. - Trace *measurexlite.Trace + Trace Trace // Transport is the MANDATORY HTTP transport we're using. Transport model.HTTPTransport @@ -164,7 +164,7 @@ func (f *httpRequestFunc) Apply( ol := logx.NewOperationLogger( input.Logger, "[#%d] HTTPRequest %s with %s/%s host=%s", - input.Trace.Index, + input.Trace.Index(), req.URL.String(), input.Address, input.Network, @@ -288,7 +288,7 @@ func (f *httpRequestFunc) do( req *http.Request, ) (*http.Response, []byte, []*Observations, error) { const maxbody = 1 << 19 // TODO(bassosimone): allow to configure this value? - started := input.Trace.TimeSince(input.Trace.ZeroTime) + started := input.Trace.TimeSince(input.Trace.ZeroTime()) // manually create a single 1-length observations structure because // the trace cannot automatically capture HTTP events @@ -298,7 +298,7 @@ func (f *httpRequestFunc) do( observations[0].NetworkEvents = append(observations[0].NetworkEvents, measurexlite.NewAnnotationArchivalNetworkEvent( - input.Trace.Index, + input.Trace.Index(), started, "http_transaction_start", input.Trace.Tags()..., @@ -321,11 +321,11 @@ func (f *httpRequestFunc) do( samples := sampler.ExtractSamples() observations[0].NetworkEvents = append(observations[0].NetworkEvents, samples...) } - finished := input.Trace.TimeSince(input.Trace.ZeroTime) + finished := input.Trace.TimeSince(input.Trace.ZeroTime()) observations[0].NetworkEvents = append(observations[0].NetworkEvents, measurexlite.NewAnnotationArchivalNetworkEvent( - input.Trace.Index, + input.Trace.Index(), finished, "http_transaction_done", input.Trace.Tags()..., @@ -333,7 +333,7 @@ func (f *httpRequestFunc) do( observations[0].Requests = append(observations[0].Requests, measurexlite.NewArchivalHTTPRequestResult( - input.Trace.Index, + input.Trace.Index(), started, input.Network, input.Address, @@ -380,7 +380,7 @@ type HTTPResponse struct { // Trace is the MANDATORY trace we're using. The trace is drained // when you call the Observations method. - Trace *measurexlite.Trace + Trace Trace // ZeroTime is the MANDATORY zero time of the measurement. ZeroTime time.Time diff --git a/internal/dslx/integration_test.go b/internal/dslx/integration_test.go index e37e0a40d0..8f2e0e5b01 100644 --- a/internal/dslx/integration_test.go +++ b/internal/dslx/integration_test.go @@ -31,8 +31,8 @@ func TestMakeSureWeCollectSpeedSamples(t *testing.T) { })) defer server.Close() - // instantiate a connection pool - rt := NewMinimalRuntime() + // instantiate a runtime + rt := NewRuntimeMeasurexLite() defer rt.Close() // create a measuring function diff --git a/internal/dslx/observations.go b/internal/dslx/observations.go index e1be1b611d..269e999463 100644 --- a/internal/dslx/observations.go +++ b/internal/dslx/observations.go @@ -5,7 +5,6 @@ package dslx // import ( - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" ) @@ -54,7 +53,7 @@ func ExtractObservations[T any](rs ...*Maybe[T]) (out []*Observations) { // maybeTraceToObservations returns the observations inside the // trace taking into account the case where trace is nil. -func maybeTraceToObservations(trace *measurexlite.Trace) (out []*Observations) { +func maybeTraceToObservations(trace Trace) (out []*Observations) { if trace != nil { out = append(out, &Observations{ NetworkEvents: trace.NetworkEvents(), diff --git a/internal/dslx/quic.go b/internal/dslx/quic.go index 196bd0e449..bc25b54240 100644 --- a/internal/dslx/quic.go +++ b/internal/dslx/quic.go @@ -14,7 +14,6 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/quic-go/quic-go" @@ -83,7 +82,7 @@ type quicHandshakeFunc struct { func (f *quicHandshakeFunc) Apply( ctx context.Context, input *Endpoint) *Maybe[*QUICConnection] { // create trace - trace := measurexlite.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) + trace := f.Rt.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) // use defaults or user-configured overrides serverName := f.serverName(input) @@ -92,7 +91,7 @@ func (f *quicHandshakeFunc) Apply( ol := logx.NewOperationLogger( input.Logger, "[#%d] QUICHandshake with %s SNI=%s", - trace.Index, + trace.Index(), input.Address, serverName, ) @@ -197,7 +196,7 @@ type QUICConnection struct { TLSState tls.ConnectionState // Trace is the MANDATORY trace we're using. - Trace *measurexlite.Trace + Trace Trace // ZeroTime is the MANDATORY zero time of the measurement. ZeroTime time.Time diff --git a/internal/dslx/quic_test.go b/internal/dslx/quic_test.go index 2396909bfb..f1b8346ca0 100644 --- a/internal/dslx/quic_test.go +++ b/internal/dslx/quic_test.go @@ -99,7 +99,7 @@ func TestQUICHandshake(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - rt := NewMinimalRuntime() + rt := NewRuntimeMeasurexLite() quicHandshake := &quicHandshakeFunc{ Rt: rt, dialer: tt.dialer, diff --git a/internal/dslx/runtimecore.go b/internal/dslx/runtimecore.go index caf7a680d6..f86e5a91e5 100644 --- a/internal/dslx/runtimecore.go +++ b/internal/dslx/runtimecore.go @@ -2,6 +2,7 @@ package dslx import ( "io" + "time" ) // Runtime is the runtime in which we execute the DSL. @@ -12,4 +13,10 @@ type Runtime interface { // MaybeTrackConn tracks a connection such that it is closed // when you call the Runtime's Close method. MaybeTrackConn(conn io.Closer) + + // NewTrace creates a [Trace] instance. Note that each [Runtime] + // creates its own [Trace] type. A [Trace] is not guaranteed to collect + // [*Observations]. For example, [NewMinimalRuntime] creates a [Runtime] + // that does not collect any [*Observations]. + NewTrace(index int64, zeroTime time.Time, tags ...string) Trace } diff --git a/internal/dslx/runtimemeasurex.go b/internal/dslx/runtimemeasurex.go new file mode 100644 index 0000000000..2ee833f3b9 --- /dev/null +++ b/internal/dslx/runtimemeasurex.go @@ -0,0 +1,26 @@ +package dslx + +import ( + "time" + + "github.com/ooni/probe-cli/v3/internal/measurexlite" +) + +// NewRuntimeMeasurexLite creates a [Runtime] using [measurexlite] to collect [*Observations]. +func NewRuntimeMeasurexLite() *RuntimeMeasurexLite { + return &RuntimeMeasurexLite{ + MinimalRuntime: NewMinimalRuntime(), + } +} + +// RuntimeMeasurexLite uses [measurexlite] to collect [*Observations.] +type RuntimeMeasurexLite struct { + *MinimalRuntime +} + +// NewTrace implements Runtime. +func (p *RuntimeMeasurexLite) NewTrace(index int64, zeroTime time.Time, tags ...string) Trace { + return measurexlite.NewTrace(index, zeroTime, tags...) +} + +var _ Runtime = &RuntimeMeasurexLite{} diff --git a/internal/dslx/runtimeminimal.go b/internal/dslx/runtimeminimal.go index 3559168ab2..7b459a6512 100644 --- a/internal/dslx/runtimeminimal.go +++ b/internal/dslx/runtimeminimal.go @@ -3,9 +3,15 @@ package dslx import ( "io" "sync" + "time" + + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/netxlite" ) // NewMinimalRuntime creates a minimal [Runtime] implementation. +// +// This [Runtime] implementation does not collect any [*Observations]. func NewMinimalRuntime() *MinimalRuntime { return &MinimalRuntime{ mu: sync.Mutex{}, @@ -13,6 +19,8 @@ func NewMinimalRuntime() *MinimalRuntime { } } +var _ Runtime = &MinimalRuntime{} + // MinimalRuntime is a minimal [Runtime] implementation. type MinimalRuntime struct { mu sync.Mutex @@ -41,3 +49,89 @@ func (p *MinimalRuntime) Close() error { p.v = nil // reset return nil } + +// NewTrace implements Runtime. +func (p *MinimalRuntime) NewTrace(index int64, zeroTime time.Time, tags ...string) Trace { + return &minimalTrace{idx: index, tags: tags, zt: zeroTime} +} + +type minimalTrace struct { + idx int64 + tags []string + zt time.Time +} + +// CloneBytesReceivedMap implements Trace. +func (tx *minimalTrace) CloneBytesReceivedMap() (out map[string]int64) { + return make(map[string]int64) +} + +// DNSLookupsFromRoundTrip implements Trace. +func (tx *minimalTrace) DNSLookupsFromRoundTrip() (out []*model.ArchivalDNSLookupResult) { + return []*model.ArchivalDNSLookupResult{} +} + +// Index implements Trace. +func (tx *minimalTrace) Index() int64 { + return tx.idx +} + +// NetworkEvents implements Trace. +func (tx *minimalTrace) NetworkEvents() (out []*model.ArchivalNetworkEvent) { + return []*model.ArchivalNetworkEvent{} +} + +// NewDialerWithoutResolver implements Trace. +func (tx *minimalTrace) NewDialerWithoutResolver(dl model.DebugLogger, wrappers ...model.DialerWrapper) model.Dialer { + return netxlite.NewDialerWithoutResolver(dl, wrappers...) +} + +// NewParallelUDPResolver implements Trace. +func (tx *minimalTrace) NewParallelUDPResolver(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver { + return netxlite.NewParallelUDPResolver(logger, dialer, address) +} + +// NewQUICDialerWithoutResolver implements Trace. +func (tx *minimalTrace) NewQUICDialerWithoutResolver(listener model.UDPListener, dl model.DebugLogger, wrappers ...model.QUICDialerWrapper) model.QUICDialer { + return netxlite.NewQUICDialerWithoutResolver(listener, dl, wrappers...) +} + +// NewStdlibResolver implements Trace. +func (tx *minimalTrace) NewStdlibResolver(logger model.DebugLogger) model.Resolver { + return netxlite.NewStdlibResolver(logger) +} + +// NewTLSHandshakerStdlib implements Trace. +func (tx *minimalTrace) NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker { + return netxlite.NewTLSHandshakerStdlib(dl) +} + +// QUICHandshakes implements Trace. +func (tx *minimalTrace) QUICHandshakes() (out []*model.ArchivalTLSOrQUICHandshakeResult) { + return []*model.ArchivalTLSOrQUICHandshakeResult{} +} + +// TCPConnects implements Trace. +func (tx *minimalTrace) TCPConnects() (out []*model.ArchivalTCPConnectResult) { + return []*model.ArchivalTCPConnectResult{} +} + +// TLSHandshakes implements Trace. +func (tx *minimalTrace) TLSHandshakes() (out []*model.ArchivalTLSOrQUICHandshakeResult) { + return []*model.ArchivalTLSOrQUICHandshakeResult{} +} + +// Tags implements Trace. +func (tx *minimalTrace) Tags() []string { + return tx.tags +} + +// TimeSince implements Trace. +func (tx *minimalTrace) TimeSince(t0 time.Time) time.Duration { + return time.Since(t0) +} + +// ZeroTime implements Trace. +func (tx *minimalTrace) ZeroTime() time.Time { + return tx.zt +} diff --git a/internal/dslx/runtimeminimal_test.go b/internal/dslx/runtimeminimal_test.go index 162cbee0bc..a75a4115b2 100644 --- a/internal/dslx/runtimeminimal_test.go +++ b/internal/dslx/runtimeminimal_test.go @@ -4,8 +4,11 @@ import ( "errors" "io" "testing" + "time" + "github.com/google/go-cmp/cmp" "github.com/ooni/probe-cli/v3/internal/mocks" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/quic-go/quic-go" ) @@ -98,4 +101,116 @@ func TestMinimalRuntime(t *testing.T) { }) } }) + + t.Run("Trace", func(t *testing.T) { + tags := []string{"antani", "mascetti", "melandri"} + rt := NewMinimalRuntime() + now := time.Now() + trace := rt.NewTrace(10, now, tags...) + + t.Run("CloneBytesReceivedMap", func(t *testing.T) { + out := trace.CloneBytesReceivedMap() + if out == nil || len(out) != 0 { + t.Fatal("expected zero-length map") + } + }) + + t.Run("DNSLookupsFromRoundTrip", func(t *testing.T) { + out := trace.DNSLookupsFromRoundTrip() + if out == nil || len(out) != 0 { + t.Fatal("expected zero-length slice") + } + }) + + t.Run("Index", func(t *testing.T) { + out := trace.Index() + if out != 10 { + t.Fatal("expected 10, got", out) + } + }) + + t.Run("NetworkEvents", func(t *testing.T) { + out := trace.NetworkEvents() + if out == nil || len(out) != 0 { + t.Fatal("expected zero-length slice") + } + }) + + t.Run("NewDialerWithoutResolver", func(t *testing.T) { + out := trace.NewDialerWithoutResolver(model.DiscardLogger) + if out == nil { + t.Fatal("expected non-nil pointer") + } + }) + + t.Run("NewParallelUDPResolver", func(t *testing.T) { + out := trace.NewParallelUDPResolver(model.DiscardLogger, &mocks.Dialer{}, "8.8.8.8:53") + if out == nil { + t.Fatal("expected non-nil pointer") + } + }) + + t.Run("NewQUICDialerWithoutResolver", func(t *testing.T) { + out := trace.NewQUICDialerWithoutResolver(&mocks.UDPListener{}, model.DiscardLogger) + if out == nil { + t.Fatal("expected non-nil pointer") + } + }) + + t.Run("NewStdlibResolver", func(t *testing.T) { + out := trace.NewStdlibResolver(model.DiscardLogger) + if out == nil { + t.Fatal("expected non-nil pointer") + } + }) + + t.Run("NewTLSHandshakerStdlib", func(t *testing.T) { + out := trace.NewTLSHandshakerStdlib(model.DiscardLogger) + if out == nil { + t.Fatal("expected non-nil pointer") + } + }) + + t.Run("QUICHandshakes", func(t *testing.T) { + out := trace.QUICHandshakes() + if out == nil || len(out) != 0 { + t.Fatal("expected zero-length slice") + } + }) + + t.Run("TCPConnects", func(t *testing.T) { + out := trace.TCPConnects() + if out == nil || len(out) != 0 { + t.Fatal("expected zero-length slice") + } + }) + + t.Run("TLSHandshakes", func(t *testing.T) { + out := trace.TLSHandshakes() + if out == nil || len(out) != 0 { + t.Fatal("expected zero-length slice") + } + }) + + t.Run("Tags", func(t *testing.T) { + out := trace.Tags() + if diff := cmp.Diff(tags, out); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("TimeSince", func(t *testing.T) { + out := trace.TimeSince(now.Add(-10 * time.Second)) + if out == 0 { + t.Fatal("expected non-zero time") + } + }) + + t.Run("ZeroTime", func(t *testing.T) { + out := trace.ZeroTime() + if out.IsZero() { + t.Fatal("expected non-zero time") + } + }) + }) } diff --git a/internal/dslx/tcp.go b/internal/dslx/tcp.go index c38c4c0cfa..fff120c09f 100644 --- a/internal/dslx/tcp.go +++ b/internal/dslx/tcp.go @@ -11,7 +11,6 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -33,13 +32,13 @@ func (f *tcpConnectFunc) Apply( ctx context.Context, input *Endpoint) *Maybe[*TCPConnection] { // create trace - trace := measurexlite.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) + trace := f.rt.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) // start the operation logger ol := logx.NewOperationLogger( input.Logger, "[#%d] TCPConnect %s", - trace.Index, + trace.Index(), input.Address, ) @@ -80,7 +79,7 @@ func (f *tcpConnectFunc) Apply( } // dialerOrDefault is the function used to obtain a dialer -func (f *tcpConnectFunc) dialerOrDefault(trace *measurexlite.Trace, logger model.Logger) model.Dialer { +func (f *tcpConnectFunc) dialerOrDefault(trace Trace, logger model.Logger) model.Dialer { dialer := f.dialer if dialer == nil { dialer = trace.NewDialerWithoutResolver(logger) @@ -110,7 +109,7 @@ type TCPConnection struct { Network string // Trace is the MANDATORY trace we're using. - Trace *measurexlite.Trace + Trace Trace // ZeroTime is the MANDATORY zero time of the measurement. ZeroTime time.Time diff --git a/internal/dslx/tcp_test.go b/internal/dslx/tcp_test.go index b7982bbdb2..d975df3421 100644 --- a/internal/dslx/tcp_test.go +++ b/internal/dslx/tcp_test.go @@ -69,7 +69,7 @@ func TestTCPConnect(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - rt := NewMinimalRuntime() + rt := NewRuntimeMeasurexLite() tcpConnect := &tcpConnectFunc{tt.dialer, rt} endpoint := &Endpoint{ Address: "1.2.3.4:567", diff --git a/internal/dslx/tls.go b/internal/dslx/tls.go index c9717b57c8..0b310d4dfc 100644 --- a/internal/dslx/tls.go +++ b/internal/dslx/tls.go @@ -13,7 +13,6 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -103,7 +102,7 @@ func (f *tlsHandshakeFunc) Apply( ol := logx.NewOperationLogger( input.Logger, "[#%d] TLSHandshake with %s SNI=%s ALPN=%v", - trace.Index, + trace.Index(), input.Address, serverName, nextProto, @@ -153,7 +152,7 @@ func (f *tlsHandshakeFunc) Apply( } // handshakerOrDefault is the function used to obtain an handshaker -func (f *tlsHandshakeFunc) handshakerOrDefault(trace *measurexlite.Trace, logger model.Logger) model.TLSHandshaker { +func (f *tlsHandshakeFunc) handshakerOrDefault(trace Trace, logger model.Logger) model.TLSHandshaker { handshaker := f.handshaker if handshaker == nil { handshaker = trace.NewTLSHandshakerStdlib(logger) @@ -211,7 +210,7 @@ type TLSConnection struct { TLSState tls.ConnectionState // Trace is the MANDATORY trace we're using. - Trace *measurexlite.Trace + Trace Trace // ZeroTime is the MANDATORY zero time of the measurement. ZeroTime time.Time diff --git a/internal/dslx/trace.go b/internal/dslx/trace.go new file mode 100644 index 0000000000..09094712ac --- /dev/null +++ b/internal/dslx/trace.go @@ -0,0 +1,71 @@ +package dslx + +import ( + "time" + + "github.com/ooni/probe-cli/v3/internal/model" +) + +// Trace collects [*Observations] using tracing. Specific implementations +// of this interface may be engineered to collect no [*Observations] for +// efficiency (i.e., when you don't care about collecting [*Observations] +// but you still want to use this package). +type Trace interface { + // CloneBytesReceivedMap returns a clone of the internal bytes received map. The key of the + // map is a string following the "EPNT_ADDRESS PROTO" pattern where the "EPNT_ADDRESS" contains + // the endpoint address and "PROTO" is "tcp" or "udp". + CloneBytesReceivedMap() (out map[string]int64) + + // DNSLookupsFromRoundTrip returns all the DNS lookup results collected so far. + DNSLookupsFromRoundTrip() (out []*model.ArchivalDNSLookupResult) + + // Index returns the unique index used by this trace. + Index() int64 + + // NewDialerWithoutResolver is equivalent to netxlite.NewDialerWithoutResolver + // except that it returns a model.Dialer that uses this trace. + // + // Caveat: the dialer wrappers are there to implement the + // model.MeasuringNetwork interface, but they're not used by this function. + NewDialerWithoutResolver(dl model.DebugLogger, wrappers ...model.DialerWrapper) model.Dialer + + // NewParallelUDPResolver returns a possibly-trace-ware parallel UDP resolver + NewParallelUDPResolver(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver + + // NewQUICDialerWithoutResolver is equivalent to + // netxlite.NewQUICDialerWithoutResolver except that it returns a + // model.QUICDialer that uses this trace. + // + // Caveat: the dialer wrappers are there to implement the + // model.MeasuringNetwork interface, but they're not used by this function. + NewQUICDialerWithoutResolver(listener model.UDPListener, + dl model.DebugLogger, wrappers ...model.QUICDialerWrapper) model.QUICDialer + + // NewTLSHandshakerStdlib is equivalent to netxlite.NewTLSHandshakerStdlib + // except that it returns a model.TLSHandshaker that uses this trace. + NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker + + // NetworkEvents returns all the network events collected so far. + NetworkEvents() (out []*model.ArchivalNetworkEvent) + + // NewStdlibResolver returns a possibly-trace-ware system resolver. + NewStdlibResolver(logger model.DebugLogger) model.Resolver + + // QUICHandshakes collects all the QUIC handshake results collected so far. + QUICHandshakes() (out []*model.ArchivalTLSOrQUICHandshakeResult) + + // TCPConnects collects all the TCP connect results collected so far. + TCPConnects() (out []*model.ArchivalTCPConnectResult) + + // TLSHandshakes collects all the TLS handshake results collected so far. + TLSHandshakes() (out []*model.ArchivalTLSOrQUICHandshakeResult) + + // Tags returns the trace tags. + Tags() []string + + // TimeSince is equivalent to Trace.TimeNow().Sub(t0). + TimeSince(t0 time.Time) time.Duration + + // ZeroTime returns the "zero" time of this trace. + ZeroTime() time.Time +} diff --git a/internal/experiment/webconnectivitylte/cleartextflow.go b/internal/experiment/webconnectivitylte/cleartextflow.go index fb51486bf4..8ec7cc663c 100644 --- a/internal/experiment/webconnectivitylte/cleartextflow.go +++ b/internal/experiment/webconnectivitylte/cleartextflow.go @@ -242,9 +242,9 @@ func (t *CleartextFlow) newHTTPRequest(ctx context.Context) (*http.Request, erro func (t *CleartextFlow) httpTransaction(ctx context.Context, network, address, alpn string, txp model.HTTPTransport, req *http.Request, trace *measurexlite.Trace) (*http.Response, []byte, error) { const maxbody = 1 << 19 - started := trace.TimeSince(trace.ZeroTime) + started := trace.TimeSince(trace.ZeroTime()) t.TestKeys.AppendNetworkEvents(measurexlite.NewAnnotationArchivalNetworkEvent( - trace.Index, started, "http_transaction_start", + trace.Index(), started, "http_transaction_start", )) resp, err := txp.RoundTrip(req) var body []byte @@ -256,12 +256,12 @@ func (t *CleartextFlow) httpTransaction(ctx context.Context, network, address, a reader := io.LimitReader(resp.Body, maxbody) body, err = StreamAllContext(ctx, reader) } - finished := trace.TimeSince(trace.ZeroTime) + finished := trace.TimeSince(trace.ZeroTime()) t.TestKeys.AppendNetworkEvents(measurexlite.NewAnnotationArchivalNetworkEvent( - trace.Index, finished, "http_transaction_done", + trace.Index(), finished, "http_transaction_done", )) ev := measurexlite.NewArchivalHTTPRequestResult( - trace.Index, + trace.Index(), started, network, address, diff --git a/internal/experiment/webconnectivitylte/secureflow.go b/internal/experiment/webconnectivitylte/secureflow.go index 1f63c17434..6284eca649 100644 --- a/internal/experiment/webconnectivitylte/secureflow.go +++ b/internal/experiment/webconnectivitylte/secureflow.go @@ -297,9 +297,9 @@ func (t *SecureFlow) newHTTPRequest(ctx context.Context) (*http.Request, error) func (t *SecureFlow) httpTransaction(ctx context.Context, network, address, alpn string, txp model.HTTPTransport, req *http.Request, trace *measurexlite.Trace) (*http.Response, []byte, error) { const maxbody = 1 << 19 - started := trace.TimeSince(trace.ZeroTime) + started := trace.TimeSince(trace.ZeroTime()) t.TestKeys.AppendNetworkEvents(measurexlite.NewAnnotationArchivalNetworkEvent( - trace.Index, started, "http_transaction_start", + trace.Index(), started, "http_transaction_start", )) resp, err := txp.RoundTrip(req) var body []byte @@ -311,12 +311,12 @@ func (t *SecureFlow) httpTransaction(ctx context.Context, network, address, alpn reader := io.LimitReader(resp.Body, maxbody) body, err = StreamAllContext(ctx, reader) } - finished := trace.TimeSince(trace.ZeroTime) + finished := trace.TimeSince(trace.ZeroTime()) t.TestKeys.AppendNetworkEvents(measurexlite.NewAnnotationArchivalNetworkEvent( - trace.Index, finished, "http_transaction_done", + trace.Index(), finished, "http_transaction_done", )) ev := measurexlite.NewArchivalHTTPRequestResult( - trace.Index, + trace.Index(), started, network, address, diff --git a/internal/measurexlite/conn.go b/internal/measurexlite/conn.go index deb14293cf..54beade596 100644 --- a/internal/measurexlite/conn.go +++ b/internal/measurexlite/conn.go @@ -44,16 +44,16 @@ func (c *connTrace) Read(b []byte) (int, error) { // collect preliminary stats when the connection is surely active network := c.RemoteAddr().Network() addr := c.RemoteAddr().String() - started := c.tx.TimeSince(c.tx.ZeroTime) + started := c.tx.TimeSince(c.tx.ZeroTime()) // perform the underlying network operation count, err := c.Conn.Read(b) // emit the network event - finished := c.tx.TimeSince(c.tx.ZeroTime) + finished := c.tx.TimeSince(c.tx.ZeroTime()) select { case c.tx.networkEvent <- NewArchivalNetworkEvent( - c.tx.Index, started, netxlite.ReadOperation, network, addr, count, + c.tx.Index(), started, netxlite.ReadOperation, network, addr, count, err, finished, c.tx.tags...): default: // buffer is full } @@ -101,14 +101,14 @@ func (tx *Trace) CloneBytesReceivedMap() (out map[string]int64) { func (c *connTrace) Write(b []byte) (int, error) { network := c.RemoteAddr().Network() addr := c.RemoteAddr().String() - started := c.tx.TimeSince(c.tx.ZeroTime) + started := c.tx.TimeSince(c.tx.ZeroTime()) count, err := c.Conn.Write(b) - finished := c.tx.TimeSince(c.tx.ZeroTime) + finished := c.tx.TimeSince(c.tx.ZeroTime()) select { case c.tx.networkEvent <- NewArchivalNetworkEvent( - c.tx.Index, started, netxlite.WriteOperation, network, addr, count, + c.tx.Index(), started, netxlite.WriteOperation, network, addr, count, err, finished, c.tx.tags...): default: // buffer is full } @@ -143,17 +143,17 @@ type udpLikeConnTrace struct { // Read implements model.UDPLikeConn.ReadFrom and saves network events. func (c *udpLikeConnTrace) ReadFrom(b []byte) (int, net.Addr, error) { // record when we started measuring - started := c.tx.TimeSince(c.tx.ZeroTime) + started := c.tx.TimeSince(c.tx.ZeroTime()) // perform the network operation count, addr, err := c.UDPLikeConn.ReadFrom(b) // emit the network event - finished := c.tx.TimeSince(c.tx.ZeroTime) + finished := c.tx.TimeSince(c.tx.ZeroTime()) address := addrStringIfNotNil(addr) select { case c.tx.networkEvent <- NewArchivalNetworkEvent( - c.tx.Index, started, netxlite.ReadFromOperation, "udp", address, count, + c.tx.Index(), started, netxlite.ReadFromOperation, "udp", address, count, err, finished, c.tx.tags...): default: // buffer is full } @@ -176,15 +176,15 @@ func (tx *Trace) maybeUpdateBytesReceivedMapUDPLikeConn(addr net.Addr, count int // Write implements model.UDPLikeConn.WriteTo and saves network events. func (c *udpLikeConnTrace) WriteTo(b []byte, addr net.Addr) (int, error) { - started := c.tx.TimeSince(c.tx.ZeroTime) + started := c.tx.TimeSince(c.tx.ZeroTime()) address := addr.String() count, err := c.UDPLikeConn.WriteTo(b, addr) - finished := c.tx.TimeSince(c.tx.ZeroTime) + finished := c.tx.TimeSince(c.tx.ZeroTime()) select { case c.tx.networkEvent <- NewArchivalNetworkEvent( - c.tx.Index, started, netxlite.WriteToOperation, "udp", address, count, + c.tx.Index(), started, netxlite.WriteToOperation, "udp", address, count, err, finished, c.tx.tags...): default: // buffer is full } diff --git a/internal/measurexlite/dialer.go b/internal/measurexlite/dialer.go index e0837a6b66..8284ec2558 100644 --- a/internal/measurexlite/dialer.go +++ b/internal/measurexlite/dialer.go @@ -55,11 +55,11 @@ func (tx *Trace) OnConnectDone( // insert into the tcpConnect buffer select { case tx.tcpConnect <- NewArchivalTCPConnectResult( - tx.Index, - started.Sub(tx.ZeroTime), + tx.Index(), + started.Sub(tx.ZeroTime()), remoteAddr, err, - finished.Sub(tx.ZeroTime), + finished.Sub(tx.ZeroTime()), tx.tags..., ): default: // buffer is full @@ -69,14 +69,14 @@ func (tx *Trace) OnConnectDone( // see https://github.com/ooni/probe/issues/2254 select { case tx.networkEvent <- NewArchivalNetworkEvent( - tx.Index, - started.Sub(tx.ZeroTime), + tx.Index(), + started.Sub(tx.ZeroTime()), netxlite.ConnectOperation, "tcp", remoteAddr, 0, err, - finished.Sub(tx.ZeroTime), + finished.Sub(tx.ZeroTime()), tx.tags..., ): default: // buffer is full diff --git a/internal/measurexlite/dns.go b/internal/measurexlite/dns.go index 8c72544426..f450f33673 100644 --- a/internal/measurexlite/dns.go +++ b/internal/measurexlite/dns.go @@ -52,7 +52,7 @@ func (r *resolverTrace) CloseIdleConnections() { func (r *resolverTrace) emitResolveStart() { select { case r.tx.networkEvent <- NewAnnotationArchivalNetworkEvent( - r.tx.Index, r.tx.TimeSince(r.tx.ZeroTime), "resolve_start", + r.tx.Index(), r.tx.TimeSince(r.tx.ZeroTime()), "resolve_start", r.tx.tags..., ): default: // buffer is full @@ -63,7 +63,7 @@ func (r *resolverTrace) emitResolveStart() { func (r *resolverTrace) emiteResolveDone() { select { case r.tx.networkEvent <- NewAnnotationArchivalNetworkEvent( - r.tx.Index, r.tx.TimeSince(r.tx.ZeroTime), "resolve_done", + r.tx.Index(), r.tx.TimeSince(r.tx.ZeroTime()), "resolve_done", r.tx.tags..., ): default: // buffer is full @@ -109,12 +109,12 @@ func (tx *Trace) NewParallelDNSOverHTTPSResolver(logger model.DebugLogger, URL s // OnDNSRoundTripForLookupHost implements model.Trace.OnDNSRoundTripForLookupHost func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolver, query model.DNSQuery, response model.DNSResponse, addrs []string, err error, finished time.Time) { - t := finished.Sub(tx.ZeroTime) + t := finished.Sub(tx.ZeroTime()) select { case tx.dnsLookup <- NewArchivalDNSLookupResultFromRoundTrip( - tx.Index, - started.Sub(tx.ZeroTime), + tx.Index(), + started.Sub(tx.ZeroTime()), reso, query, response, @@ -274,12 +274,12 @@ var ErrDelayedDNSResponseBufferFull = errors.New( // OnDelayedDNSResponse implements model.Trace.OnDelayedDNSResponse func (tx *Trace) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport, query model.DNSQuery, response model.DNSResponse, addrs []string, err error, finished time.Time) error { - t := finished.Sub(tx.ZeroTime) + t := finished.Sub(tx.ZeroTime()) select { case tx.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip( - tx.Index, - started.Sub(tx.ZeroTime), + tx.Index(), + started.Sub(tx.ZeroTime()), txp, query, response, diff --git a/internal/measurexlite/dns_test.go b/internal/measurexlite/dns_test.go index aba71d1489..7b2cd48886 100644 --- a/internal/measurexlite/dns_test.go +++ b/internal/measurexlite/dns_test.go @@ -519,8 +519,8 @@ func TestDelayedDNSResponseWithTimeout(t *testing.T) { } for i := 0; i < events; i++ { // fill the trace - trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index, started.Sub(trace.ZeroTime), - txp, query, dnsResponse, addrs, nil, finished.Sub(trace.ZeroTime)) + trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index(), started.Sub(trace.ZeroTime()), + txp, query, dnsResponse, addrs, nil, finished.Sub(trace.ZeroTime())) } ctx, cancel := context.WithCancel(context.Background()) cancel() // we ensure that the context cancels before draining all the events @@ -566,8 +566,8 @@ func TestDelayedDNSResponseWithTimeout(t *testing.T) { return []byte{} }, } - trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index, started.Sub(trace.ZeroTime), - txp, query, dnsResponse, addrs, nil, finished.Sub(trace.ZeroTime)) + trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index(), started.Sub(trace.ZeroTime()), + txp, query, dnsResponse, addrs, nil, finished.Sub(trace.ZeroTime())) got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second) if len(got) != 1 { t.Fatal("unexpected output from trace") diff --git a/internal/measurexlite/quic.go b/internal/measurexlite/quic.go index 949068f505..6df2149c57 100644 --- a/internal/measurexlite/quic.go +++ b/internal/measurexlite/quic.go @@ -49,10 +49,10 @@ func (qdx *quicDialerTrace) CloseIdleConnections() { // OnQUICHandshakeStart implements model.Trace.OnQUICHandshakeStart func (tx *Trace) OnQUICHandshakeStart(now time.Time, remoteAddr string, config *quic.Config) { - t := now.Sub(tx.ZeroTime) + t := now.Sub(tx.ZeroTime()) select { case tx.networkEvent <- NewAnnotationArchivalNetworkEvent( - tx.Index, t, "quic_handshake_start", tx.tags...): + tx.Index(), t, "quic_handshake_start", tx.tags...): default: } } @@ -60,7 +60,7 @@ func (tx *Trace) OnQUICHandshakeStart(now time.Time, remoteAddr string, config * // OnQUICHandshakeDone implements model.Trace.OnQUICHandshakeDone func (tx *Trace) OnQUICHandshakeDone(started time.Time, remoteAddr string, qconn quic.EarlyConnection, config *tls.Config, err error, finished time.Time) { - t := finished.Sub(tx.ZeroTime) + t := finished.Sub(tx.ZeroTime()) state := tls.ConnectionState{} if qconn != nil { @@ -69,8 +69,8 @@ func (tx *Trace) OnQUICHandshakeDone(started time.Time, remoteAddr string, qconn select { case tx.quicHandshake <- NewArchivalTLSOrQUICHandshakeResult( - tx.Index, - started.Sub(tx.ZeroTime), + tx.Index(), + started.Sub(tx.ZeroTime()), "udp", remoteAddr, config, @@ -84,7 +84,7 @@ func (tx *Trace) OnQUICHandshakeDone(started time.Time, remoteAddr string, qconn select { case tx.networkEvent <- NewAnnotationArchivalNetworkEvent( - tx.Index, t, "quic_handshake_done", tx.tags...): + tx.Index(), t, "quic_handshake_done", tx.tags...): default: // buffer is full } } diff --git a/internal/measurexlite/tls.go b/internal/measurexlite/tls.go index af0850c2e3..636bdae3aa 100644 --- a/internal/measurexlite/tls.go +++ b/internal/measurexlite/tls.go @@ -41,10 +41,10 @@ func (thx *tlsHandshakerTrace) Handshake( // OnTLSHandshakeStart implements model.Trace.OnTLSHandshakeStart. func (tx *Trace) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) { - t := now.Sub(tx.ZeroTime) + t := now.Sub(tx.ZeroTime()) select { case tx.networkEvent <- NewAnnotationArchivalNetworkEvent( - tx.Index, t, "tls_handshake_start", tx.tags...): + tx.Index(), t, "tls_handshake_start", tx.tags...): default: // buffer is full } } @@ -52,12 +52,12 @@ func (tx *Trace) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *t // OnTLSHandshakeDone implements model.Trace.OnTLSHandshakeDone. func (tx *Trace) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) { - t := finished.Sub(tx.ZeroTime) + t := finished.Sub(tx.ZeroTime()) select { case tx.tlsHandshake <- NewArchivalTLSOrQUICHandshakeResult( - tx.Index, - started.Sub(tx.ZeroTime), + tx.Index(), + started.Sub(tx.ZeroTime()), "tcp", remoteAddr, config, @@ -71,7 +71,7 @@ func (tx *Trace) OnTLSHandshakeDone(started time.Time, remoteAddr string, config select { case tx.networkEvent <- NewAnnotationArchivalNetworkEvent( - tx.Index, t, "tls_handshake_done", tx.tags...): + tx.Index(), t, "tls_handshake_done", tx.tags...): default: // buffer is full } } diff --git a/internal/measurexlite/trace.go b/internal/measurexlite/trace.go index 7a7e79a6f8..91187fc89c 100644 --- a/internal/measurexlite/trace.go +++ b/internal/measurexlite/trace.go @@ -25,10 +25,8 @@ import ( // // [step-by-step measurements]: https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md type Trace struct { - // Index is the unique index of this trace within the - // current measurement. Note that this field MUST be read-only. Writing it - // once you have constructed a trace MAY lead to data races. - Index int64 + // index is the unique index of this trace within the current measurement. + index int64 // Netx is the network to use for measuring. The constructor inits this // field using a [*netxlite.Netx]. You MAY override this field for testing. Make @@ -69,10 +67,8 @@ type Trace struct { // to produce deterministic timing when testing. timeNowFn func() time.Time - // ZeroTime is the time when we started the current measurement. This field - // MUST be read-only. Writing it once you have constructed the trace will - // likely read to data races. - ZeroTime time.Time + // zeroTime is the time when we started the current measurement. + zeroTime time.Time } var _ model.MeasuringNetwork = &Trace{} @@ -111,7 +107,7 @@ const QUICHandshakeBufferSize = 8 // to identify that some traces belong to some submeasurements). func NewTrace(index int64, zeroTime time.Time, tags ...string) *Trace { return &Trace{ - Index: index, + index: index, Netx: &netxlite.Netx{Underlying: nil}, // use the host network bytesReceivedMap: make(map[string]int64), bytesReceivedMu: &sync.Mutex{}, @@ -141,10 +137,20 @@ func NewTrace(index int64, zeroTime time.Time, tags ...string) *Trace { ), tags: tags, timeNowFn: nil, // use default - ZeroTime: zeroTime, + zeroTime: zeroTime, } } +// Index returns the trace index. +func (tx *Trace) Index() int64 { + return tx.index +} + +// ZeroTime returns trace's zero time. +func (tx *Trace) ZeroTime() time.Time { + return tx.zeroTime +} + // TimeNow implements model.Trace.TimeNow. func (tx *Trace) TimeNow() time.Time { if tx.timeNowFn != nil { diff --git a/internal/measurexlite/trace_test.go b/internal/measurexlite/trace_test.go index 111a5274ea..686d6540d7 100644 --- a/internal/measurexlite/trace_test.go +++ b/internal/measurexlite/trace_test.go @@ -25,7 +25,7 @@ func TestNewTrace(t *testing.T) { trace := NewTrace(index, zeroTime) t.Run("Index", func(t *testing.T) { - if trace.Index != index { + if trace.Index() != index { t.Fatal("invalid index") } }) @@ -164,7 +164,7 @@ func TestNewTrace(t *testing.T) { }) t.Run("ZeroTime", func(t *testing.T) { - if !trace.ZeroTime.Equal(zeroTime) { + if !trace.ZeroTime().Equal(zeroTime) { t.Fatal("invalid zero time") } }) diff --git a/internal/throttling/throttling.go b/internal/throttling/throttling.go index 46252ed33f..a9de7e3800 100644 --- a/internal/throttling/throttling.go +++ b/internal/throttling/throttling.go @@ -7,13 +7,32 @@ import ( "sync" "time" - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/memoryless" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/runtimex" ) -// Sampler periodically samples the bytes sent and received by a [*measurexlite.Trace]. The zero +// Trace is the [*measurexlite.Trace] abstraction used by this package. +type Trace interface { + // CloneBytesReceivedMap returns a clone of the internal bytes received map. The key of the + // map is a string following the "EPNT_ADDRESS PROTO" pattern where the "EPNT_ADDRESS" contains + // the endpoint address and "PROTO" is "tcp" or "udp". + CloneBytesReceivedMap() (out map[string]int64) + + // Index returns the unique index used by this trace. + Index() int64 + + // Tags returns the trace tags. + Tags() []string + + // TimeSince is equivalent to Trace.TimeNow().Sub(t0). + TimeSince(t0 time.Time) time.Duration + + // ZeroTime returns the "zero" time of this trace. + ZeroTime() time.Time +} + +// Sampler periodically samples the bytes sent and received by a [Trace]. The zero // value of this structure is invalid; please, construct using [NewSampler]. type Sampler struct { // cancel tells the background goroutine to stop @@ -29,16 +48,16 @@ type Sampler struct { q []*model.ArchivalNetworkEvent // tx is the trace we are sampling from - tx *measurexlite.Trace + tx Trace // wg is the waitgroup to wait for the sampler to join wg *sync.WaitGroup } -// NewSampler attaches a [*Sampler] to a [*measurexlite.Trace], starts sampling in the +// NewSampler attaches a [*Sampler] to a [Trace], starts sampling in the // background and returns the [*Sampler]. Remember to call [*Sampler.Close] to stop // the background goroutine that performs the sampling. -func NewSampler(tx *measurexlite.Trace) *Sampler { +func NewSampler(tx Trace) *Sampler { ctx, cancel := context.WithCancel(context.Background()) smpl := &Sampler{ cancel: cancel, @@ -95,7 +114,7 @@ const BytesReceivedCumulativeOperation = "bytes_received_cumulative" func (smpl *Sampler) collectSnapshot(stats map[string]int64) { // compute just once the events sampling time - now := smpl.tx.TimeSince(smpl.tx.ZeroTime).Seconds() + now := smpl.tx.TimeSince(smpl.tx.ZeroTime()).Seconds() // process each entry for key, count := range stats { @@ -116,7 +135,7 @@ func (smpl *Sampler) collectSnapshot(stats map[string]int64) { Proto: network, T0: now, T: now, - TransactionID: smpl.tx.Index, + TransactionID: smpl.tx.Index(), Tags: smpl.tx.Tags(), } diff --git a/internal/tutorial/dslx/chapter02/main.go b/internal/tutorial/dslx/chapter02/main.go index 347a17d21e..b188155c47 100644 --- a/internal/tutorial/dslx/chapter02/main.go +++ b/internal/tutorial/dslx/chapter02/main.go @@ -255,13 +255,22 @@ func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { dslx.DNSLookupOptionZeroTime(measurement.MeasurementStartTimeSaved), ) + // ``` + // + // Next, we create a minimal runtime. This data structure helps us to manage + // open connections and close them when `rt.Close` is invoked. + // + // ```Go + rt := dslx.NewMinimalRuntime() + defer rt.Close() + // ``` // // We construct the resolver dslx function which can be - like in this case - the // system resolver, or a custom UDP resolver. // // ```Go - lookupFn := dslx.DNSLookupGetaddrinfo() + lookupFn := dslx.DNSLookupGetaddrinfo(rt) // ``` // @@ -330,15 +339,6 @@ func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { runtimex.Assert(len(endpoints) >= 1, "expected at least one endpoint here") endpoint := endpoints[0] - // ``` - // - // Next, we create a minimal runtime. This data structure helps us to manage - // open connections and close them when `rt.Close` is invoked. - // - // ```Go - rt := dslx.NewMinimalRuntime() - defer rt.Close() - // ``` // // In the following we compose step-by-step measurement "pipelines", From b7755724e5b2ce7c89dec661319cd7f457b35052 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 18 Oct 2023 08:12:48 +0200 Subject: [PATCH 03/10] refactor(dslx): Runtime includes logger, ID generator, and zero time Currently, we pass these fields to each DSL function. However, if we want to load the functions from JSON, which is something we have experimented with in the richer input context, we can't do this. These fields should instead belong to the runtime. A subsequent diff will modify the DSL functions to take them from the runtime. --- internal/dslx/dns_test.go | 16 +++++------ internal/dslx/integration_test.go | 2 +- internal/dslx/quic_test.go | 14 +++++----- internal/dslx/runtimecore.go | 14 ++++++++++ internal/dslx/runtimemeasurex.go | 5 ++-- internal/dslx/runtimeminimal.go | 32 ++++++++++++++++++---- internal/dslx/runtimeminimal_test.go | 28 +++++++++++++++++-- internal/dslx/tcp_test.go | 6 ++-- internal/dslx/tls_test.go | 14 +++++----- internal/tutorial/dslx/chapter02/README.md | 22 ++++++++------- internal/tutorial/dslx/chapter02/main.go | 4 ++- 11 files changed, 111 insertions(+), 46 deletions(-) diff --git a/internal/dslx/dns_test.go b/internal/dslx/dns_test.go index 2868891291..079bc816b0 100644 --- a/internal/dslx/dns_test.go +++ b/internal/dslx/dns_test.go @@ -67,7 +67,7 @@ Test cases: */ func TestGetaddrinfo(t *testing.T) { t.Run("Get dnsLookupGetaddrinfoFunc", func(t *testing.T) { - f := DNSLookupGetaddrinfo(NewMinimalRuntime()) + f := DNSLookupGetaddrinfo(NewMinimalRuntime(model.DiscardLogger, time.Now())) if _, ok := f.(*dnsLookupGetaddrinfoFunc); !ok { t.Fatal("unexpected type, want dnsLookupGetaddrinfoFunc") } @@ -84,7 +84,7 @@ func TestGetaddrinfo(t *testing.T) { t.Run("with nil resolver", func(t *testing.T) { f := dnsLookupGetaddrinfoFunc{ - rt: NewMinimalRuntime(), + rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately cancel the lookup @@ -103,7 +103,7 @@ func TestGetaddrinfo(t *testing.T) { resolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return nil, mockedErr }}, - rt: NewMinimalRuntime(), + rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -125,7 +125,7 @@ func TestGetaddrinfo(t *testing.T) { resolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return []string{"93.184.216.34"}, nil }}, - rt: NewRuntimeMeasurexLite(), + rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -157,7 +157,7 @@ Test cases: */ func TestLookupUDP(t *testing.T) { t.Run("Get dnsLookupUDPFunc", func(t *testing.T) { - rt := NewMinimalRuntime() + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) f := DNSLookupUDP(rt, "1.1.1.1:53") if _, ok := f.(*dnsLookupUDPFunc); !ok { t.Fatal("unexpected type, want dnsLookupUDPFunc") @@ -174,7 +174,7 @@ func TestLookupUDP(t *testing.T) { } t.Run("with nil resolver", func(t *testing.T) { - f := dnsLookupUDPFunc{Resolver: "1.1.1.1:53", rt: NewMinimalRuntime()} + f := dnsLookupUDPFunc{Resolver: "1.1.1.1:53", rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} ctx, cancel := context.WithCancel(context.Background()) cancel() res := f.Apply(ctx, domain) @@ -193,7 +193,7 @@ func TestLookupUDP(t *testing.T) { mockResolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return nil, mockedErr }}, - rt: NewMinimalRuntime(), + rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -216,7 +216,7 @@ func TestLookupUDP(t *testing.T) { mockResolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return []string{"93.184.216.34"}, nil }}, - rt: NewRuntimeMeasurexLite(), + rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { diff --git a/internal/dslx/integration_test.go b/internal/dslx/integration_test.go index 8f2e0e5b01..e199a59cea 100644 --- a/internal/dslx/integration_test.go +++ b/internal/dslx/integration_test.go @@ -32,7 +32,7 @@ func TestMakeSureWeCollectSpeedSamples(t *testing.T) { defer server.Close() // instantiate a runtime - rt := NewRuntimeMeasurexLite() + rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()) defer rt.Close() // create a measuring function diff --git a/internal/dslx/quic_test.go b/internal/dslx/quic_test.go index f1b8346ca0..13b7c24560 100644 --- a/internal/dslx/quic_test.go +++ b/internal/dslx/quic_test.go @@ -29,7 +29,7 @@ func TestQUICHandshake(t *testing.T) { certpool.AddCert(&x509.Certificate{}) f := QUICHandshake( - NewMinimalRuntime(), + NewMinimalRuntime(model.DiscardLogger, time.Now()), QUICHandshakeOptionInsecureSkipVerify(true), QUICHandshakeOptionServerName("sni"), QUICHandshakeOptionRootCAs(certpool), @@ -99,7 +99,7 @@ func TestQUICHandshake(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - rt := NewRuntimeMeasurexLite() + rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()) quicHandshake := &quicHandshakeFunc{ Rt: rt, dialer: tt.dialer, @@ -137,7 +137,7 @@ func TestQUICHandshake(t *testing.T) { } t.Run("with nil dialer", func(t *testing.T) { - quicHandshake := &quicHandshakeFunc{Rt: NewMinimalRuntime(), dialer: nil} + quicHandshake := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), dialer: nil} endpoint := &Endpoint{ Address: "1.2.3.4:567", Network: "udp", @@ -173,7 +173,7 @@ func TestServerNameQUIC(t *testing.T) { Address: "example.com:123", Logger: model.DiscardLogger, } - f := &quicHandshakeFunc{Rt: NewMinimalRuntime(), ServerName: sni} + f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), ServerName: sni} serverName := f.serverName(endpoint) if serverName != sni { t.Fatalf("unexpected server name: %s", serverName) @@ -187,7 +187,7 @@ func TestServerNameQUIC(t *testing.T) { Domain: domain, Logger: model.DiscardLogger, } - f := &quicHandshakeFunc{Rt: NewMinimalRuntime()} + f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} serverName := f.serverName(endpoint) if serverName != domain { t.Fatalf("unexpected server name: %s", serverName) @@ -200,7 +200,7 @@ func TestServerNameQUIC(t *testing.T) { Address: hostaddr + ":123", Logger: model.DiscardLogger, } - f := &quicHandshakeFunc{Rt: NewMinimalRuntime()} + f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} serverName := f.serverName(endpoint) if serverName != hostaddr { t.Fatalf("unexpected server name: %s", serverName) @@ -213,7 +213,7 @@ func TestServerNameQUIC(t *testing.T) { Address: ip, Logger: model.DiscardLogger, } - f := &quicHandshakeFunc{Rt: NewMinimalRuntime()} + f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} serverName := f.serverName(endpoint) if serverName != "" { t.Fatalf("unexpected server name: %s", serverName) diff --git a/internal/dslx/runtimecore.go b/internal/dslx/runtimecore.go index f86e5a91e5..4215431e54 100644 --- a/internal/dslx/runtimecore.go +++ b/internal/dslx/runtimecore.go @@ -2,7 +2,10 @@ package dslx import ( "io" + "sync/atomic" "time" + + "github.com/ooni/probe-cli/v3/internal/model" ) // Runtime is the runtime in which we execute the DSL. @@ -10,6 +13,13 @@ type Runtime interface { // Close closes all the connection tracked using MaybeTrackConn. Close() error + // IDGenerator returns an atomic counter used to generate + // separate unique IDs for each trace. + IDGenerator() *atomic.Int64 + + // Logger returns the base logger to use. + Logger() model.Logger + // MaybeTrackConn tracks a connection such that it is closed // when you call the Runtime's Close method. MaybeTrackConn(conn io.Closer) @@ -19,4 +29,8 @@ type Runtime interface { // [*Observations]. For example, [NewMinimalRuntime] creates a [Runtime] // that does not collect any [*Observations]. NewTrace(index int64, zeroTime time.Time, tags ...string) Trace + + // ZeroTime returns the runtime's "zero" time, which is used as the + // starting point to generate observation's delta times. + ZeroTime() time.Time } diff --git a/internal/dslx/runtimemeasurex.go b/internal/dslx/runtimemeasurex.go index 2ee833f3b9..a075d085b5 100644 --- a/internal/dslx/runtimemeasurex.go +++ b/internal/dslx/runtimemeasurex.go @@ -4,12 +4,13 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/measurexlite" + "github.com/ooni/probe-cli/v3/internal/model" ) // NewRuntimeMeasurexLite creates a [Runtime] using [measurexlite] to collect [*Observations]. -func NewRuntimeMeasurexLite() *RuntimeMeasurexLite { +func NewRuntimeMeasurexLite(logger model.Logger, zeroTime time.Time) *RuntimeMeasurexLite { return &RuntimeMeasurexLite{ - MinimalRuntime: NewMinimalRuntime(), + MinimalRuntime: NewMinimalRuntime(logger, zeroTime), } } diff --git a/internal/dslx/runtimeminimal.go b/internal/dslx/runtimeminimal.go index 7b459a6512..7003c2ceec 100644 --- a/internal/dslx/runtimeminimal.go +++ b/internal/dslx/runtimeminimal.go @@ -3,6 +3,7 @@ package dslx import ( "io" "sync" + "sync/atomic" "time" "github.com/ooni/probe-cli/v3/internal/model" @@ -12,10 +13,13 @@ import ( // NewMinimalRuntime creates a minimal [Runtime] implementation. // // This [Runtime] implementation does not collect any [*Observations]. -func NewMinimalRuntime() *MinimalRuntime { +func NewMinimalRuntime(logger model.Logger, zeroTime time.Time) *MinimalRuntime { return &MinimalRuntime{ - mu: sync.Mutex{}, - v: []io.Closer{}, + idg: &atomic.Int64{}, + logger: logger, + mu: sync.Mutex{}, + v: []io.Closer{}, + zeroT: zeroTime, } } @@ -23,8 +27,26 @@ var _ Runtime = &MinimalRuntime{} // MinimalRuntime is a minimal [Runtime] implementation. type MinimalRuntime struct { - mu sync.Mutex - v []io.Closer + idg *atomic.Int64 + logger model.Logger + mu sync.Mutex + v []io.Closer + zeroT time.Time +} + +// IDGenerator implements Runtime. +func (p *MinimalRuntime) IDGenerator() *atomic.Int64 { + return p.idg +} + +// Logger implements Runtime. +func (p *MinimalRuntime) Logger() model.Logger { + return p.logger +} + +// ZeroTime implements Runtime. +func (p *MinimalRuntime) ZeroTime() time.Time { + return p.zeroT } // MaybeTrackConn implements Runtime. diff --git a/internal/dslx/runtimeminimal_test.go b/internal/dslx/runtimeminimal_test.go index a75a4115b2..4699787fb9 100644 --- a/internal/dslx/runtimeminimal_test.go +++ b/internal/dslx/runtimeminimal_test.go @@ -57,7 +57,7 @@ func TestMinimalRuntime(t *testing.T) { } for name, tt := range tests { t.Run(name, func(t *testing.T) { - rt := NewMinimalRuntime() + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) rt.MaybeTrackConn(tt.mockConn) if len(rt.v) != tt.want { t.Fatalf("expected %d tracked connections, got: %d", tt.want, len(rt.v)) @@ -102,9 +102,33 @@ func TestMinimalRuntime(t *testing.T) { } }) + t.Run("IDGenerator", func(t *testing.T) { + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + out := rt.IDGenerator() + if out == nil { + t.Fatal("expected non-nil pointer") + } + }) + + t.Run("Logger", func(t *testing.T) { + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + out := rt.Logger() + if out == nil { + t.Fatal("expected non-nil pointer") + } + }) + + t.Run("ZeroTime", func(t *testing.T) { + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + out := rt.ZeroTime() + if out.IsZero() { + t.Fatal("expected non-zero time") + } + }) + t.Run("Trace", func(t *testing.T) { tags := []string{"antani", "mascetti", "melandri"} - rt := NewMinimalRuntime() + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) now := time.Now() trace := rt.NewTrace(10, now, tags...) diff --git a/internal/dslx/tcp_test.go b/internal/dslx/tcp_test.go index d975df3421..6a0bf8569b 100644 --- a/internal/dslx/tcp_test.go +++ b/internal/dslx/tcp_test.go @@ -17,7 +17,7 @@ import ( func TestTCPConnect(t *testing.T) { t.Run("Get tcpConnectFunc", func(t *testing.T) { f := TCPConnect( - NewMinimalRuntime(), + NewMinimalRuntime(model.DiscardLogger, time.Now()), ) if _, ok := f.(*tcpConnectFunc); !ok { t.Fatal("unexpected type. Expected: tcpConnectFunc") @@ -69,7 +69,7 @@ func TestTCPConnect(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - rt := NewRuntimeMeasurexLite() + rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()) tcpConnect := &tcpConnectFunc{tt.dialer, rt} endpoint := &Endpoint{ Address: "1.2.3.4:567", @@ -107,7 +107,7 @@ func TestTCPConnect(t *testing.T) { // Make sure we get a valid dialer if no mocked dialer is configured func TestDialerOrDefault(t *testing.T) { f := &tcpConnectFunc{ - rt: NewMinimalRuntime(), + rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), dialer: nil, } dialer := f.dialerOrDefault(measurexlite.NewTrace(0, time.Now()), model.DiscardLogger) diff --git a/internal/dslx/tls_test.go b/internal/dslx/tls_test.go index 858834a0dc..40030f9187 100644 --- a/internal/dslx/tls_test.go +++ b/internal/dslx/tls_test.go @@ -31,7 +31,7 @@ func TestTLSHandshake(t *testing.T) { certpool.AddCert(&x509.Certificate{}) f := TLSHandshake( - NewMinimalRuntime(), + NewMinimalRuntime(model.DiscardLogger, time.Now()), TLSHandshakeOptionInsecureSkipVerify(true), TLSHandshakeOptionNextProto([]string{"h2"}), TLSHandshakeOptionServerName("sni"), @@ -133,7 +133,7 @@ func TestTLSHandshake(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - rt := NewMinimalRuntime() + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) tlsHandshake := &tlsHandshakeFunc{ NextProto: tt.config.nextProtos, Rt: rt, @@ -188,7 +188,7 @@ func TestServerNameTLS(t *testing.T) { Logger: model.DiscardLogger, } f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(), + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), ServerName: sni, } serverName := f.serverName(&tcpConn) @@ -204,7 +204,7 @@ func TestServerNameTLS(t *testing.T) { Logger: model.DiscardLogger, } f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(), + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } serverName := f.serverName(&tcpConn) if serverName != domain { @@ -218,7 +218,7 @@ func TestServerNameTLS(t *testing.T) { Logger: model.DiscardLogger, } f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(), + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } serverName := f.serverName(&tcpConn) if serverName != hostaddr { @@ -232,7 +232,7 @@ func TestServerNameTLS(t *testing.T) { Logger: model.DiscardLogger, } f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(), + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } serverName := f.serverName(&tcpConn) if serverName != "" { @@ -246,7 +246,7 @@ func TestHandshakerOrDefault(t *testing.T) { f := &tlsHandshakeFunc{ InsecureSkipVerify: false, NextProto: []string{}, - Rt: NewMinimalRuntime(), + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), RootCAs: &x509.CertPool{}, ServerName: "", handshaker: nil, diff --git a/internal/tutorial/dslx/chapter02/README.md b/internal/tutorial/dslx/chapter02/README.md index a0e4f96b51..e9be23d453 100644 --- a/internal/tutorial/dslx/chapter02/README.md +++ b/internal/tutorial/dslx/chapter02/README.md @@ -44,7 +44,9 @@ import ( "errors" "net" "sync/atomic" + "time" + "github.com/apex/log" "github.com/ooni/probe-cli/v3/internal/dslx" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/runtimex" @@ -256,11 +258,20 @@ experiment's start time. ``` +Next, we create a minimal runtime. This data structure helps us to manage +open connections and close them when `rt.Close` is invoked. + +```Go + rt := dslx.NewMinimalRuntime(log.Log, time.Now()) + defer rt.Close() + +``` + We construct the resolver dslx function which can be - like in this case - the system resolver, or a custom UDP resolver. ```Go - lookupFn := dslx.DNSLookupGetaddrinfo() + lookupFn := dslx.DNSLookupGetaddrinfo(rt) ``` @@ -331,15 +342,6 @@ the protocol, address, and port three-tuple.) ``` -Next, we create a minimal runtime. This data structure helps us to manage -open connections and close them when `rt.Close` is invoked. - -```Go - rt := dslx.NewMinimalRuntime() - defer rt.Close() - -``` - In the following we compose step-by-step measurement "pipelines", represented by `dslx` functions. diff --git a/internal/tutorial/dslx/chapter02/main.go b/internal/tutorial/dslx/chapter02/main.go index b188155c47..50991a4012 100644 --- a/internal/tutorial/dslx/chapter02/main.go +++ b/internal/tutorial/dslx/chapter02/main.go @@ -45,7 +45,9 @@ import ( "errors" "net" "sync/atomic" + "time" + "github.com/apex/log" "github.com/ooni/probe-cli/v3/internal/dslx" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/runtimex" @@ -261,7 +263,7 @@ func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { // open connections and close them when `rt.Close` is invoked. // // ```Go - rt := dslx.NewMinimalRuntime() + rt := dslx.NewMinimalRuntime(log.Log, time.Now()) defer rt.Close() // ``` From 7a555f4cfafc0992f9f97fd162a9241ad1b43f39 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 18 Oct 2023 09:11:48 +0200 Subject: [PATCH 04/10] refactor(dslx): use Runtime for logger, ID generator, and zero time This diff builds upon the previous diff to use the Runtime to get the logger, ID generator, and zero time. By doing this, we make most structures that DSL functions takes as input or emit in output serializable and deserializable. --- internal/dslx/dns.go | 93 ++---------- internal/dslx/dns_test.go | 27 +--- internal/dslx/endpoint.go | 48 +----- internal/dslx/endpoint_test.go | 15 -- internal/dslx/http_test.go | 169 ++++++++++----------- internal/dslx/httpcore.go | 33 +--- internal/dslx/httpquic.go | 17 +-- internal/dslx/httptcp.go | 17 +-- internal/dslx/httptls.go | 17 +-- internal/dslx/integration_test.go | 16 +- internal/dslx/quic.go | 35 ++--- internal/dslx/quic_test.go | 21 +-- internal/dslx/tcp.go | 29 +--- internal/dslx/tcp_test.go | 10 +- internal/dslx/tls.go | 31 ++-- internal/dslx/tls_test.go | 15 +- internal/tutorial/dslx/chapter02/README.md | 21 +-- internal/tutorial/dslx/chapter02/main.go | 21 +-- 18 files changed, 184 insertions(+), 451 deletions(-) diff --git a/internal/dslx/dns.go b/internal/dslx/dns.go index 66a9d45586..9da9c9ed0f 100644 --- a/internal/dslx/dns.go +++ b/internal/dslx/dns.go @@ -6,7 +6,6 @@ package dslx import ( "context" - "sync/atomic" "time" "github.com/ooni/probe-cli/v3/internal/logx" @@ -20,22 +19,6 @@ type DomainName string // DNSLookupOption is an option you can pass to NewDomainToResolve. type DNSLookupOption func(*DomainToResolve) -// DNSLookupOptionIDGenerator configures a specific ID generator. -// See DomainToResolve docs for more information. -func DNSLookupOptionIDGenerator(value *atomic.Int64) DNSLookupOption { - return func(dis *DomainToResolve) { - dis.IDGenerator = value - } -} - -// DNSLookupOptionLogger configures a specific logger. -// See DomainToResolve docs for more information. -func DNSLookupOptionLogger(value model.Logger) DNSLookupOption { - return func(dis *DomainToResolve) { - dis.Logger = value - } -} - // DNSLookupOptionTags allows to set tags to tag observations. func DNSLookupOptionTags(value ...string) DNSLookupOption { return func(dis *DomainToResolve) { @@ -43,24 +26,13 @@ func DNSLookupOptionTags(value ...string) DNSLookupOption { } } -// DNSLookupOptionZeroTime configures the measurement's zero time. -// See DomainToResolve docs for more information. -func DNSLookupOptionZeroTime(value time.Time) DNSLookupOption { - return func(dis *DomainToResolve) { - dis.ZeroTime = value - } -} - // NewDomainToResolve creates input for performing DNS lookups. The only mandatory // argument is the domain name to resolve. You can also supply optional // values by passing options to this function. func NewDomainToResolve(domain DomainName, options ...DNSLookupOption) *DomainToResolve { state := &DomainToResolve{ - Domain: string(domain), - IDGenerator: &atomic.Int64{}, - Logger: model.DiscardLogger, - Tags: []string{}, - ZeroTime: time.Now(), + Domain: string(domain), + Tags: []string{}, } for _, option := range options { option(state) @@ -78,25 +50,8 @@ type DomainToResolve struct { // Domain is the MANDATORY domain name to lookup. Domain string - // IDGenerator is the MANDATORY ID generator. We will use this field - // to assign unique IDs to distinct sub-measurements. The default - // construction implemented by NewDomainToResolve creates a new generator - // that starts counting from zero, leading to the first trace having - // one as its index. - IDGenerator *atomic.Int64 - - // Logger is the MANDATORY logger to use. The default construction - // implemented by NewDomainToResolve uses model.DiscardLogger. - Logger model.Logger - // Tags contains OPTIONAL tags to tag observations. Tags []string - - // ZeroTime is the MANDATORY zero time of the measurement. We will - // use this field as the zero value to compute relative elapsed times - // when generating measurements. The default construction by - // NewDomainToResolve initializes this field with the current time. - ZeroTime time.Time } // ResolvedAddresses contains the results of DNS lookups. To initialize @@ -109,22 +64,10 @@ type ResolvedAddresses struct { // from the value inside the DomainToResolve. Domain string - // IDGenerator is the ID generator. We inherit this field - // from the value inside the DomainToResolve. - IDGenerator *atomic.Int64 - - // Logger is the logger to use. We inherit this field - // from the value inside the DomainToResolve. - Logger model.Logger - // Trace is the trace we're currently using. This struct is // created by the various Apply functions using values inside // the DomainToResolve to initialize the Trace. Trace Trace - - // ZeroTime is the zero time of the measurement. We inherit this field - // from the value inside the DomainToResolve. - ZeroTime time.Time } // DNSLookupGetaddrinfo returns a function that resolves a domain name to @@ -144,11 +87,11 @@ func (f *dnsLookupGetaddrinfoFunc) Apply( ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] { // create trace - trace := f.rt.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) + trace := f.rt.NewTrace(f.rt.IDGenerator().Add(1), f.rt.ZeroTime(), input.Tags...) // start the operation logger ol := logx.NewOperationLogger( - input.Logger, + f.rt.Logger(), "[#%d] DNSLookup[getaddrinfo] %s", trace.Index(), input.Domain, @@ -161,7 +104,7 @@ func (f *dnsLookupGetaddrinfoFunc) Apply( resolver := f.resolver if resolver == nil { - resolver = trace.NewStdlibResolver(input.Logger) + resolver = trace.NewStdlibResolver(f.rt.Logger()) } // lookup @@ -171,12 +114,9 @@ func (f *dnsLookupGetaddrinfoFunc) Apply( ol.Stop(err) state := &ResolvedAddresses{ - Addresses: addrs, // maybe empty - Domain: input.Domain, - IDGenerator: input.IDGenerator, - Logger: input.Logger, - Trace: trace, - ZeroTime: input.ZeroTime, + Addresses: addrs, // maybe empty + Domain: input.Domain, + Trace: trace, } return &Maybe[*ResolvedAddresses]{ @@ -210,11 +150,11 @@ func (f *dnsLookupUDPFunc) Apply( ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] { // create trace - trace := f.rt.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) + trace := f.rt.NewTrace(f.rt.IDGenerator().Add(1), f.rt.ZeroTime(), input.Tags...) // start the operation logger ol := logx.NewOperationLogger( - input.Logger, + f.rt.Logger(), "[#%d] DNSLookup[%s/udp] %s", trace.Index(), f.Resolver, @@ -229,8 +169,8 @@ func (f *dnsLookupUDPFunc) Apply( resolver := f.mockResolver if resolver == nil { resolver = trace.NewParallelUDPResolver( - input.Logger, - trace.NewDialerWithoutResolver(input.Logger), + f.rt.Logger(), + trace.NewDialerWithoutResolver(f.rt.Logger()), f.Resolver, ) } @@ -242,12 +182,9 @@ func (f *dnsLookupUDPFunc) Apply( ol.Stop(err) state := &ResolvedAddresses{ - Addresses: addrs, // maybe empty - Domain: input.Domain, - IDGenerator: input.IDGenerator, - Logger: input.Logger, - Trace: trace, - ZeroTime: input.ZeroTime, + Addresses: addrs, // maybe empty + Domain: input.Domain, + Trace: trace, } return &Maybe[*ResolvedAddresses]{ diff --git a/internal/dslx/dns_test.go b/internal/dslx/dns_test.go index 079bc816b0..15f08e155e 100644 --- a/internal/dslx/dns_test.go +++ b/internal/dslx/dns_test.go @@ -30,26 +30,13 @@ func TestNewDomainToResolve(t *testing.T) { t.Run("with options", func(t *testing.T) { idGen := &atomic.Int64{} idGen.Add(42) - zt := time.Now() domainToResolve := NewDomainToResolve( DomainName("www.example.com"), - DNSLookupOptionIDGenerator(idGen), - DNSLookupOptionLogger(model.DiscardLogger), - DNSLookupOptionZeroTime(zt), DNSLookupOptionTags("antani"), ) if domainToResolve.Domain != "www.example.com" { t.Fatalf("unexpected domain") } - if domainToResolve.IDGenerator != idGen { - t.Fatalf("unexpected id generator") - } - if domainToResolve.Logger != model.DiscardLogger { - t.Fatalf("unexpected logger") - } - if domainToResolve.ZeroTime != zt { - t.Fatalf("unexpected zerotime") - } if diff := cmp.Diff([]string{"antani"}, domainToResolve.Tags); diff != "" { t.Fatal(diff) } @@ -75,11 +62,8 @@ func TestGetaddrinfo(t *testing.T) { t.Run("Apply dnsLookupGetaddrinfoFunc", func(t *testing.T) { domain := &DomainToResolve{ - Domain: "example.com", - Logger: model.DiscardLogger, - IDGenerator: &atomic.Int64{}, - Tags: []string{"antani"}, - ZeroTime: time.Time{}, + Domain: "example.com", + Tags: []string{"antani"}, } t.Run("with nil resolver", func(t *testing.T) { @@ -166,11 +150,8 @@ func TestLookupUDP(t *testing.T) { t.Run("Apply dnsLookupGetaddrinfoFunc", func(t *testing.T) { domain := &DomainToResolve{ - Domain: "example.com", - Logger: model.DiscardLogger, - IDGenerator: &atomic.Int64{}, - Tags: []string{"antani"}, - ZeroTime: time.Time{}, + Domain: "example.com", + Tags: []string{"antani"}, } t.Run("with nil resolver", func(t *testing.T) { diff --git a/internal/dslx/endpoint.go b/internal/dslx/endpoint.go index cd17d68c88..ea725f5d58 100644 --- a/internal/dslx/endpoint.go +++ b/internal/dslx/endpoint.go @@ -4,13 +4,6 @@ package dslx // Manipulate endpoints // -import ( - "sync/atomic" - "time" - - "github.com/ooni/probe-cli/v3/internal/model" -) - type ( // EndpointNetwork is the network of the endpoint EndpointNetwork string @@ -29,20 +22,11 @@ type Endpoint struct { // Domain is the OPTIONAL domain used to resolve the endpoints' IP address. Domain string - // IDGenerator is MANDATORY the ID generator to use. - IDGenerator *atomic.Int64 - - // Logger is the MANDATORY logger to use. - Logger model.Logger - // Network is the MANDATORY endpoint network. Network string // Tags contains OPTIONAL tags for tagging observations. Tags []string - - // ZeroTime is the MANDATORY zero time of the measurement. - ZeroTime time.Time } // EndpointOption is an option you can use to construct EndpointState. @@ -55,20 +39,6 @@ func EndpointOptionDomain(value string) EndpointOption { } } -// EndpointOptionIDGenerator allows to set the ID generator. -func EndpointOptionIDGenerator(value *atomic.Int64) EndpointOption { - return func(es *Endpoint) { - es.IDGenerator = value - } -} - -// EndpointOptionLogger allows to set the logger. -func EndpointOptionLogger(value model.Logger) EndpointOption { - return func(es *Endpoint) { - es.Logger = value - } -} - // EndpointOptionTags allows to set tags to tag observations. func EndpointOptionTags(value ...string) EndpointOption { return func(es *Endpoint) { @@ -76,13 +46,6 @@ func EndpointOptionTags(value ...string) EndpointOption { } } -// EndpointOptionZeroTime allows to set the zero time. -func EndpointOptionZeroTime(value time.Time) EndpointOption { - return func(es *Endpoint) { - es.ZeroTime = value - } -} - // NewEndpoint creates a new network endpoint (i.e., a three tuple composed // of a network protocol, an IP address, and a port). // @@ -97,13 +60,10 @@ func EndpointOptionZeroTime(value time.Time) EndpointOption { func NewEndpoint( network EndpointNetwork, address EndpointAddress, options ...EndpointOption) *Endpoint { epnt := &Endpoint{ - Address: string(address), - Domain: "", - IDGenerator: &atomic.Int64{}, - Logger: model.DiscardLogger, - Network: string(network), - Tags: []string{}, - ZeroTime: time.Now(), + Address: string(address), + Domain: "", + Network: string(network), + Tags: []string{}, } for _, option := range options { option(epnt) diff --git a/internal/dslx/endpoint_test.go b/internal/dslx/endpoint_test.go index 61170f1fc0..cc62ecedea 100644 --- a/internal/dslx/endpoint_test.go +++ b/internal/dslx/endpoint_test.go @@ -3,25 +3,19 @@ package dslx import ( "sync/atomic" "testing" - "time" "github.com/google/go-cmp/cmp" - "github.com/ooni/probe-cli/v3/internal/model" ) func TestEndpoint(t *testing.T) { idGen := &atomic.Int64{} idGen.Add(42) - zt := time.Now() t.Run("Create new endpoint", func(t *testing.T) { testEndpoint := NewEndpoint( "network", "10.9.8.76", EndpointOptionDomain("www.example.com"), - EndpointOptionIDGenerator(idGen), - EndpointOptionLogger(model.DiscardLogger), - EndpointOptionZeroTime(zt), EndpointOptionTags("antani"), ) if testEndpoint.Network != "network" { @@ -33,15 +27,6 @@ func TestEndpoint(t *testing.T) { if testEndpoint.Domain != "www.example.com" { t.Fatalf("unexpected domain") } - if testEndpoint.IDGenerator != idGen { - t.Fatalf("unexpected IDGenerator") - } - if testEndpoint.Logger != model.DiscardLogger { - t.Fatalf("unexpected logger") - } - if testEndpoint.ZeroTime != zt { - t.Fatalf("unexpected zero time") - } if diff := cmp.Diff([]string{"antani"}, testEndpoint.Tags); diff != "" { t.Fatal(diff) } diff --git a/internal/dslx/http_test.go b/internal/dslx/http_test.go index 17d2419de9..e95ca14e0f 100644 --- a/internal/dslx/http_test.go +++ b/internal/dslx/http_test.go @@ -30,7 +30,8 @@ Test cases: */ func TestHTTPRequest(t *testing.T) { t.Run("Get httpRequestFunc with options", func(t *testing.T) { - f := HTTPRequest( + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + f := HTTPRequest(rt, HTTPRequestOptionAccept("text/html"), HTTPRequestOptionAcceptLanguage("de"), HTTPRequestOptionHost("host"), @@ -96,16 +97,15 @@ func TestHTTPRequest(t *testing.T) { t.Run("with EOF", func(t *testing.T) { httpTransport := HTTPTransport{ - Address: "1.2.3.4:567", - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Scheme: "https", - Trace: trace, - Transport: eofTransport, - ZeroTime: zeroTime, + Address: "1.2.3.4:567", + Network: "tcp", + Scheme: "https", + Trace: trace, + Transport: eofTransport, + } + httpRequest := &httpRequestFunc{ + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } - httpRequest := &httpRequestFunc{} res := httpRequest.Apply(context.Background(), &httpTransport) if res.Error != io.EOF { t.Fatal("not the error we expected") @@ -117,14 +117,11 @@ func TestHTTPRequest(t *testing.T) { t.Run("with invalid method", func(t *testing.T) { httpTransport := HTTPTransport{ - Address: "1.2.3.4:567", - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Scheme: "https", - Trace: trace, - Transport: goodTransport, - ZeroTime: zeroTime, + Address: "1.2.3.4:567", + Network: "tcp", + Scheme: "https", + Trace: trace, + Transport: goodTransport, } httpRequest := &httpRequestFunc{ Method: "€", @@ -140,16 +137,15 @@ func TestHTTPRequest(t *testing.T) { t.Run("with port-less address", func(t *testing.T) { httpTransport := HTTPTransport{ - Address: "1.2.3.4", - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Scheme: "https", - Trace: trace, - Transport: goodTransport, - ZeroTime: zeroTime, + Address: "1.2.3.4", + Network: "tcp", + Scheme: "https", + Trace: trace, + Transport: goodTransport, + } + httpRequest := &httpRequestFunc{ + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } - httpRequest := &httpRequestFunc{} res := httpRequest.Apply(context.Background(), &httpTransport) if res.Error != nil { t.Fatal("expected error") @@ -193,16 +189,15 @@ func TestHTTPRequest(t *testing.T) { t.Run("with success (https)", func(t *testing.T) { httpTransport := HTTPTransport{ - Address: "1.2.3.4:443", - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Scheme: "https", - Trace: trace, - Transport: goodTransport, - ZeroTime: zeroTime, + Address: "1.2.3.4:443", + Network: "tcp", + Scheme: "https", + Trace: trace, + Transport: goodTransport, + } + httpRequest := &httpRequestFunc{ + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } - httpRequest := &httpRequestFunc{} res := httpRequest.Apply(context.Background(), &httpTransport) if res.Error != nil { t.Fatal("unexpected error") @@ -215,16 +210,15 @@ func TestHTTPRequest(t *testing.T) { t.Run("with success (http)", func(t *testing.T) { httpTransport := HTTPTransport{ - Address: "1.2.3.4:80", - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Scheme: "http", - Trace: trace, - Transport: goodTransport, - ZeroTime: zeroTime, + Address: "1.2.3.4:80", + Network: "tcp", + Scheme: "http", + Trace: trace, + Transport: goodTransport, + } + httpRequest := &httpRequestFunc{ + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } - httpRequest := &httpRequestFunc{} res := httpRequest.Apply(context.Background(), &httpTransport) if res.Error != nil { t.Fatal("unexpected error") @@ -237,21 +231,19 @@ func TestHTTPRequest(t *testing.T) { t.Run("with header options", func(t *testing.T) { httpTransport := HTTPTransport{ - Address: "1.2.3.4:567", - Domain: "domain.com", - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Scheme: "https", - Trace: trace, - Transport: goodTransport, - ZeroTime: zeroTime, + Address: "1.2.3.4:567", + Domain: "domain.com", + Network: "tcp", + Scheme: "https", + Trace: trace, + Transport: goodTransport, } httpRequest := &httpRequestFunc{ Accept: "text/html", AcceptLanguage: "de", Host: "host", Referer: "https://example.org", + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), URLPath: "/path/to/example", UserAgent: "Mozilla/5.0 Gecko/geckotrail Firefox/firefoxversion", } @@ -284,14 +276,16 @@ Test cases: */ func TestHTTPTCP(t *testing.T) { t.Run("Get httpTransportTCPFunc", func(t *testing.T) { - f := HTTPTransportTCP() + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + f := HTTPTransportTCP(rt) if _, ok := f.(*httpTransportTCPFunc); !ok { t.Fatal("unexpected type") } }) t.Run("Get composed function: TCP with HTTP", func(t *testing.T) { - f := HTTPRequestOverTCP() + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + f := HTTPRequestOverTCP(rt) if _, ok := f.(*compose2Func[*TCPConnection, *HTTPTransport, *HTTPResponse]); !ok { t.Fatal("unexpected type") } @@ -304,15 +298,14 @@ func TestHTTPTCP(t *testing.T) { trace := measurexlite.NewTrace(idGen.Add(1), zeroTime) address := "1.2.3.4:567" tcpConn := &TCPConnection{ - Address: address, - Conn: conn, - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Trace: trace, - ZeroTime: zeroTime, - } - f := httpTransportTCPFunc{} + Address: address, + Conn: conn, + Network: "tcp", + Trace: trace, + } + f := httpTransportTCPFunc{ + rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), + } res := f.Apply(context.Background(), tcpConn) if res.Error != nil { t.Fatalf("unexpected error: %s", res.Error) @@ -337,14 +330,16 @@ Test cases: */ func TestHTTPQUIC(t *testing.T) { t.Run("Get httpTransportQUICFunc", func(t *testing.T) { - f := HTTPTransportQUIC() + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + f := HTTPTransportQUIC(rt) if _, ok := f.(*httpTransportQUICFunc); !ok { t.Fatal("unexpected type") } }) t.Run("Get composed function: QUIC with HTTP", func(t *testing.T) { - f := HTTPRequestOverQUIC() + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + f := HTTPRequestOverQUIC(rt) if _, ok := f.(*compose2Func[*QUICConnection, *HTTPTransport, *HTTPResponse]); !ok { t.Fatal("unexpected type") } @@ -357,15 +352,14 @@ func TestHTTPQUIC(t *testing.T) { trace := measurexlite.NewTrace(idGen.Add(1), zeroTime) address := "1.2.3.4:567" quicConn := &QUICConnection{ - Address: address, - QUICConn: conn, - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "udp", - Trace: trace, - ZeroTime: zeroTime, - } - f := httpTransportQUICFunc{} + Address: address, + QUICConn: conn, + Network: "udp", + Trace: trace, + } + f := httpTransportQUICFunc{ + rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), + } res := f.Apply(context.Background(), quicConn) if res.Error != nil { t.Fatalf("unexpected error: %s", res.Error) @@ -390,14 +384,16 @@ Test cases: */ func TestHTTPTLS(t *testing.T) { t.Run("Get httpTransportTLSFunc", func(t *testing.T) { - f := HTTPTransportTLS() + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + f := HTTPTransportTLS(rt) if _, ok := f.(*httpTransportTLSFunc); !ok { t.Fatal("unexpected type") } }) t.Run("Get composed function: TLS with HTTP", func(t *testing.T) { - f := HTTPRequestOverTLS() + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + f := HTTPRequestOverTLS(rt) if _, ok := f.(*compose2Func[*TLSConnection, *HTTPTransport, *HTTPResponse]); !ok { t.Fatal("unexpected type") } @@ -410,15 +406,14 @@ func TestHTTPTLS(t *testing.T) { trace := measurexlite.NewTrace(idGen.Add(1), zeroTime) address := "1.2.3.4:567" tlsConn := &TLSConnection{ - Address: address, - Conn: conn, - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Trace: trace, - ZeroTime: zeroTime, - } - f := httpTransportTLSFunc{} + Address: address, + Conn: conn, + Network: "tcp", + Trace: trace, + } + f := httpTransportTLSFunc{ + rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), + } res := f.Apply(context.Background(), tlsConn) if res.Error != nil { t.Fatalf("unexpected error: %s", res.Error) diff --git a/internal/dslx/httpcore.go b/internal/dslx/httpcore.go index 2ffb4281ab..08ff04eb2b 100644 --- a/internal/dslx/httpcore.go +++ b/internal/dslx/httpcore.go @@ -10,7 +10,6 @@ import ( "net" "net/http" "net/url" - "sync/atomic" "time" "github.com/ooni/probe-cli/v3/internal/logx" @@ -32,12 +31,6 @@ type HTTPTransport struct { // Domain is the OPTIONAL domain from which the address was resolved. Domain string - // IDGenerator is the MANDATORY ID generator. - IDGenerator *atomic.Int64 - - // Logger is the MANDATORY logger to use. - Logger model.Logger - // Network is the MANDATORY network used by the underlying conn. Network string @@ -52,9 +45,6 @@ type HTTPTransport struct { // Transport is the MANDATORY HTTP transport we're using. Transport model.HTTPTransport - - // ZeroTime is the MANDATORY zero time of the measurement. - ZeroTime time.Time } // HTTPRequestOption is an option you can pass to HTTPRequest. @@ -110,8 +100,8 @@ func HTTPRequestOptionUserAgent(value string) HTTPRequestOption { } // HTTPRequest issues an HTTP request using a transport and returns a response. -func HTTPRequest(options ...HTTPRequestOption) Func[*HTTPTransport, *Maybe[*HTTPResponse]] { - f := &httpRequestFunc{} +func HTTPRequest(rt Runtime, options ...HTTPRequestOption) Func[*HTTPTransport, *Maybe[*HTTPResponse]] { + f := &httpRequestFunc{Rt: rt} for _, option := range options { option(f) } @@ -135,6 +125,9 @@ type httpRequestFunc struct { // Referer is the OPTIONAL referer header. Referer string + // Rt is the MANDATORY runtime. + Rt Runtime + // URLPath is the OPTIONAL URL path. URLPath string @@ -162,7 +155,7 @@ func (f *httpRequestFunc) Apply( // start the operation logger ol := logx.NewOperationLogger( - input.Logger, + f.Rt.Logger(), "[#%d] HTTPRequest %s with %s/%s host=%s", input.Trace.Index(), req.URL.String(), @@ -186,11 +179,8 @@ func (f *httpRequestFunc) Apply( HTTPRequest: req, // possibly nil HTTPResponse: resp, // possibly nil HTTPResponseBodySnapshot: body, // possibly nil - IDGenerator: input.IDGenerator, - Logger: input.Logger, Network: input.Network, Trace: input.Trace, - ZeroTime: input.ZeroTime, } return &Maybe[*HTTPResponse]{ @@ -262,7 +252,7 @@ func (f *httpRequestFunc) urlHost(input *HTTPTransport) string { } addr, port, err := net.SplitHostPort(input.Address) if err != nil { - input.Logger.Warnf("httpRequestFunc: cannot SplitHostPort for input.Address") + f.Rt.Logger().Warnf("httpRequestFunc: cannot SplitHostPort for input.Address") return input.Address } switch { @@ -369,19 +359,10 @@ type HTTPResponse struct { // HTTPResponseBodySnapshot is the response body or nil if Err != nil. HTTPResponseBodySnapshot []byte - // IDGenerator is the MANDATORY ID generator. - IDGenerator *atomic.Int64 - - // Logger is the MANDATORY logger to use. - Logger model.Logger - // Network is the MANDATORY network we're connected to. Network string // Trace is the MANDATORY trace we're using. The trace is drained // when you call the Observations method. Trace Trace - - // ZeroTime is the MANDATORY zero time of the measurement. - ZeroTime time.Time } diff --git a/internal/dslx/httpquic.go b/internal/dslx/httpquic.go index b2809bb2ca..6553f50d3f 100644 --- a/internal/dslx/httpquic.go +++ b/internal/dslx/httpquic.go @@ -11,24 +11,26 @@ import ( ) // HTTPRequestOverQUIC returns a Func that issues HTTP requests over QUIC. -func HTTPRequestOverQUIC(options ...HTTPRequestOption) Func[*QUICConnection, *Maybe[*HTTPResponse]] { - return Compose2(HTTPTransportQUIC(), HTTPRequest(options...)) +func HTTPRequestOverQUIC(rt Runtime, options ...HTTPRequestOption) Func[*QUICConnection, *Maybe[*HTTPResponse]] { + return Compose2(HTTPTransportQUIC(rt), HTTPRequest(rt, options...)) } // HTTPTransportQUIC converts a QUIC connection into an HTTP transport. -func HTTPTransportQUIC() Func[*QUICConnection, *Maybe[*HTTPTransport]] { - return &httpTransportQUICFunc{} +func HTTPTransportQUIC(rt Runtime) Func[*QUICConnection, *Maybe[*HTTPTransport]] { + return &httpTransportQUICFunc{rt} } // httpTransportQUICFunc is the function returned by HTTPTransportQUIC. -type httpTransportQUICFunc struct{} +type httpTransportQUICFunc struct { + rt Runtime +} // Apply implements Func. func (f *httpTransportQUICFunc) Apply( ctx context.Context, input *QUICConnection) *Maybe[*HTTPTransport] { // create transport httpTransport := netxlite.NewHTTP3Transport( - input.Logger, + f.rt.Logger(), netxlite.NewSingleUseQUICDialer(input.QUICConn), input.TLSConfig, ) @@ -36,14 +38,11 @@ func (f *httpTransportQUICFunc) Apply( state := &HTTPTransport{ Address: input.Address, Domain: input.Domain, - IDGenerator: input.IDGenerator, - Logger: input.Logger, Network: input.Network, Scheme: "https", TLSNegotiatedProtocol: input.TLSState.NegotiatedProtocol, Trace: input.Trace, Transport: httpTransport, - ZeroTime: input.ZeroTime, } return &Maybe[*HTTPTransport]{ Error: nil, diff --git a/internal/dslx/httptcp.go b/internal/dslx/httptcp.go index 2571322130..b0d670cec7 100644 --- a/internal/dslx/httptcp.go +++ b/internal/dslx/httptcp.go @@ -11,17 +11,19 @@ import ( ) // HTTPRequestOverTCP returns a Func that issues HTTP requests over TCP. -func HTTPRequestOverTCP(options ...HTTPRequestOption) Func[*TCPConnection, *Maybe[*HTTPResponse]] { - return Compose2(HTTPTransportTCP(), HTTPRequest(options...)) +func HTTPRequestOverTCP(rt Runtime, options ...HTTPRequestOption) Func[*TCPConnection, *Maybe[*HTTPResponse]] { + return Compose2(HTTPTransportTCP(rt), HTTPRequest(rt, options...)) } // HTTPTransportTCP converts a TCP connection into an HTTP transport. -func HTTPTransportTCP() Func[*TCPConnection, *Maybe[*HTTPTransport]] { - return &httpTransportTCPFunc{} +func HTTPTransportTCP(rt Runtime) Func[*TCPConnection, *Maybe[*HTTPTransport]] { + return &httpTransportTCPFunc{rt} } // httpTransportTCPFunc is the function returned by HTTPTransportTCP -type httpTransportTCPFunc struct{} +type httpTransportTCPFunc struct { + rt Runtime +} // Apply implements Func func (f *httpTransportTCPFunc) Apply( @@ -30,21 +32,18 @@ func (f *httpTransportTCPFunc) Apply( // function, but we can probably avoid using it, given that this code is // not using tracing and does not care about those quirks. httpTransport := netxlite.NewHTTPTransport( - input.Logger, + f.rt.Logger(), netxlite.NewSingleUseDialer(input.Conn), netxlite.NewNullTLSDialer(), ) state := &HTTPTransport{ Address: input.Address, Domain: input.Domain, - IDGenerator: input.IDGenerator, - Logger: input.Logger, Network: input.Network, Scheme: "http", TLSNegotiatedProtocol: "", Trace: input.Trace, Transport: httpTransport, - ZeroTime: input.ZeroTime, } return &Maybe[*HTTPTransport]{ Error: nil, diff --git a/internal/dslx/httptls.go b/internal/dslx/httptls.go index 9c1448ce0a..1c541afd1f 100644 --- a/internal/dslx/httptls.go +++ b/internal/dslx/httptls.go @@ -11,17 +11,19 @@ import ( ) // HTTPRequestOverTLS returns a Func that issues HTTP requests over TLS. -func HTTPRequestOverTLS(options ...HTTPRequestOption) Func[*TLSConnection, *Maybe[*HTTPResponse]] { - return Compose2(HTTPTransportTLS(), HTTPRequest(options...)) +func HTTPRequestOverTLS(rt Runtime, options ...HTTPRequestOption) Func[*TLSConnection, *Maybe[*HTTPResponse]] { + return Compose2(HTTPTransportTLS(rt), HTTPRequest(rt, options...)) } // HTTPTransportTLS converts a TLS connection into an HTTP transport. -func HTTPTransportTLS() Func[*TLSConnection, *Maybe[*HTTPTransport]] { - return &httpTransportTLSFunc{} +func HTTPTransportTLS(rt Runtime) Func[*TLSConnection, *Maybe[*HTTPTransport]] { + return &httpTransportTLSFunc{rt} } // httpTransportTLSFunc is the function returned by HTTPTransportTLS. -type httpTransportTLSFunc struct{} +type httpTransportTLSFunc struct { + rt Runtime +} // Apply implements Func. func (f *httpTransportTLSFunc) Apply( @@ -30,21 +32,18 @@ func (f *httpTransportTLSFunc) Apply( // function, but we can probably avoid using it, given that this code is // not using tracing and does not care about those quirks. httpTransport := netxlite.NewHTTPTransport( - input.Logger, + f.rt.Logger(), netxlite.NewNullDialer(), netxlite.NewSingleUseTLSDialer(input.Conn), ) state := &HTTPTransport{ Address: input.Address, Domain: input.Domain, - IDGenerator: input.IDGenerator, - Logger: input.Logger, Network: input.Network, Scheme: "https", TLSNegotiatedProtocol: input.TLSState.NegotiatedProtocol, Trace: input.Trace, Transport: httpTransport, - ZeroTime: input.ZeroTime, } return &Maybe[*HTTPTransport]{ Error: nil, diff --git a/internal/dslx/integration_test.go b/internal/dslx/integration_test.go index e199a59cea..1eddf74457 100644 --- a/internal/dslx/integration_test.go +++ b/internal/dslx/integration_test.go @@ -4,7 +4,6 @@ import ( "context" "net/http" "net/http/httptest" - "sync/atomic" "testing" "time" @@ -38,19 +37,16 @@ func TestMakeSureWeCollectSpeedSamples(t *testing.T) { // create a measuring function f0 := Compose3( TCPConnect(rt), - HTTPTransportTCP(), - HTTPRequest(), + HTTPTransportTCP(rt), + HTTPRequest(rt), ) // create the endpoint to measure epnt := &Endpoint{ - Address: server.Listener.Addr().String(), - Domain: "", - IDGenerator: &atomic.Int64{}, - Logger: model.DiscardLogger, - Network: "tcp", - Tags: []string{}, - ZeroTime: time.Now(), + Address: server.Listener.Addr().String(), + Domain: "", + Network: "tcp", + Tags: []string{}, } // measure the endpoint diff --git a/internal/dslx/quic.go b/internal/dslx/quic.go index bc25b54240..c643584e4a 100644 --- a/internal/dslx/quic.go +++ b/internal/dslx/quic.go @@ -10,7 +10,6 @@ import ( "crypto/x509" "io" "net" - "sync/atomic" "time" "github.com/ooni/probe-cli/v3/internal/logx" @@ -82,14 +81,14 @@ type quicHandshakeFunc struct { func (f *quicHandshakeFunc) Apply( ctx context.Context, input *Endpoint) *Maybe[*QUICConnection] { // create trace - trace := f.Rt.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) + trace := f.Rt.NewTrace(f.Rt.IDGenerator().Add(1), f.Rt.ZeroTime(), input.Tags...) // use defaults or user-configured overrides serverName := f.serverName(input) // start the operation logger ol := logx.NewOperationLogger( - input.Logger, + f.Rt.Logger(), "[#%d] QUICHandshake with %s SNI=%s", trace.Index(), input.Address, @@ -100,7 +99,7 @@ func (f *quicHandshakeFunc) Apply( udpListener := netxlite.NewUDPListener() quicDialer := f.dialer if quicDialer == nil { - quicDialer = trace.NewQUICDialerWithoutResolver(udpListener, input.Logger) + quicDialer = trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger()) } config := &tls.Config{ NextProtos: []string{"h3"}, @@ -129,16 +128,13 @@ func (f *quicHandshakeFunc) Apply( ol.Stop(err) state := &QUICConnection{ - Address: input.Address, - QUICConn: quicConn, // possibly nil - Domain: input.Domain, - IDGenerator: input.IDGenerator, - Logger: input.Logger, - Network: input.Network, - TLSConfig: config, - TLSState: tlsState, - Trace: trace, - ZeroTime: input.ZeroTime, + Address: input.Address, + QUICConn: quicConn, // possibly nil + Domain: input.Domain, + Network: input.Network, + TLSConfig: config, + TLSState: tlsState, + Trace: trace, } return &Maybe[*QUICConnection]{ @@ -163,7 +159,7 @@ func (f *quicHandshakeFunc) serverName(input *Endpoint) string { // Note: golang requires a ServerName and fails if it's empty. If the provided // ServerName is an IP address, however, golang WILL NOT emit any SNI extension // in the ClientHello, consistently with RFC 6066 Section 3 requirements. - input.Logger.Warn("TLSHandshake: cannot determine which SNI to use") + f.Rt.Logger().Warn("TLSHandshake: cannot determine which SNI to use") return "" } @@ -179,12 +175,6 @@ type QUICConnection struct { // Domain is the OPTIONAL domain we resolved. Domain string - // IDGenerator is the MANDATORY ID generator to use. - IDGenerator *atomic.Int64 - - // Logger is the MANDATORY logger to use. - Logger model.Logger - // Network is the MANDATORY network we tried to use when connecting. Network string @@ -197,9 +187,6 @@ type QUICConnection struct { // Trace is the MANDATORY trace we're using. Trace Trace - - // ZeroTime is the MANDATORY zero time of the measurement. - ZeroTime time.Time } type quicCloserConn struct { diff --git a/internal/dslx/quic_test.go b/internal/dslx/quic_test.go index 13b7c24560..40c4923812 100644 --- a/internal/dslx/quic_test.go +++ b/internal/dslx/quic_test.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "crypto/x509" "io" - "sync/atomic" "testing" "time" @@ -106,12 +105,9 @@ func TestQUICHandshake(t *testing.T) { ServerName: tt.sni, } endpoint := &Endpoint{ - Address: "1.2.3.4:567", - Network: "udp", - IDGenerator: &atomic.Int64{}, - Logger: model.DiscardLogger, - Tags: tt.tags, - ZeroTime: time.Time{}, + Address: "1.2.3.4:567", + Network: "udp", + Tags: tt.tags, } res := quicHandshake.Apply(context.Background(), endpoint) if res.Error != tt.expectErr { @@ -139,11 +135,8 @@ func TestQUICHandshake(t *testing.T) { t.Run("with nil dialer", func(t *testing.T) { quicHandshake := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), dialer: nil} endpoint := &Endpoint{ - Address: "1.2.3.4:567", - Network: "udp", - IDGenerator: &atomic.Int64{}, - Logger: model.DiscardLogger, - ZeroTime: time.Time{}, + Address: "1.2.3.4:567", + Network: "udp", } ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -171,7 +164,6 @@ func TestServerNameQUIC(t *testing.T) { sni := "sni" endpoint := &Endpoint{ Address: "example.com:123", - Logger: model.DiscardLogger, } f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), ServerName: sni} serverName := f.serverName(endpoint) @@ -185,7 +177,6 @@ func TestServerNameQUIC(t *testing.T) { endpoint := &Endpoint{ Address: "example.com:123", Domain: domain, - Logger: model.DiscardLogger, } f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} serverName := f.serverName(endpoint) @@ -198,7 +189,6 @@ func TestServerNameQUIC(t *testing.T) { hostaddr := "example.com" endpoint := &Endpoint{ Address: hostaddr + ":123", - Logger: model.DiscardLogger, } f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} serverName := f.serverName(endpoint) @@ -211,7 +201,6 @@ func TestServerNameQUIC(t *testing.T) { ip := "1.1.1.1" endpoint := &Endpoint{ Address: ip, - Logger: model.DiscardLogger, } f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} serverName := f.serverName(endpoint) diff --git a/internal/dslx/tcp.go b/internal/dslx/tcp.go index fff120c09f..af5dbcff3c 100644 --- a/internal/dslx/tcp.go +++ b/internal/dslx/tcp.go @@ -7,7 +7,6 @@ package dslx import ( "context" "net" - "sync/atomic" "time" "github.com/ooni/probe-cli/v3/internal/logx" @@ -32,11 +31,11 @@ func (f *tcpConnectFunc) Apply( ctx context.Context, input *Endpoint) *Maybe[*TCPConnection] { // create trace - trace := f.rt.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) + trace := f.rt.NewTrace(f.rt.IDGenerator().Add(1), f.rt.ZeroTime(), input.Tags...) // start the operation logger ol := logx.NewOperationLogger( - input.Logger, + f.rt.Logger(), "[#%d] TCPConnect %s", trace.Index(), input.Address, @@ -48,7 +47,7 @@ func (f *tcpConnectFunc) Apply( defer cancel() // obtain the dialer to use - dialer := f.dialerOrDefault(trace, input.Logger) + dialer := f.dialerOrDefault(trace, f.rt.Logger()) // connect conn, err := dialer.DialContext(ctx, "tcp", input.Address) @@ -60,14 +59,11 @@ func (f *tcpConnectFunc) Apply( ol.Stop(err) state := &TCPConnection{ - Address: input.Address, - Conn: conn, // possibly nil - Domain: input.Domain, - IDGenerator: input.IDGenerator, - Logger: input.Logger, - Network: input.Network, - Trace: trace, - ZeroTime: input.ZeroTime, + Address: input.Address, + Conn: conn, // possibly nil + Domain: input.Domain, + Network: input.Network, + Trace: trace, } return &Maybe[*TCPConnection]{ @@ -99,18 +95,9 @@ type TCPConnection struct { // Domain is the OPTIONAL domain from which we resolved the Address. Domain string - // IDGenerator is the MANDATORY ID generator. - IDGenerator *atomic.Int64 - - // Logger is the MANDATORY logger to use. - Logger model.Logger - // Network is the MANDATORY network we tried to use when connecting. Network string // Trace is the MANDATORY trace we're using. Trace Trace - - // ZeroTime is the MANDATORY zero time of the measurement. - ZeroTime time.Time } diff --git a/internal/dslx/tcp_test.go b/internal/dslx/tcp_test.go index 6a0bf8569b..1ec42ef88a 100644 --- a/internal/dslx/tcp_test.go +++ b/internal/dslx/tcp_test.go @@ -4,7 +4,6 @@ import ( "context" "io" "net" - "sync/atomic" "testing" "time" @@ -72,12 +71,9 @@ func TestTCPConnect(t *testing.T) { rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()) tcpConnect := &tcpConnectFunc{tt.dialer, rt} endpoint := &Endpoint{ - Address: "1.2.3.4:567", - Network: "tcp", - IDGenerator: &atomic.Int64{}, - Logger: model.DiscardLogger, - Tags: tt.tags, - ZeroTime: time.Time{}, + Address: "1.2.3.4:567", + Network: "tcp", + Tags: tt.tags, } res := tcpConnect.Apply(context.Background(), endpoint) if res.Error != tt.expectErr { diff --git a/internal/dslx/tls.go b/internal/dslx/tls.go index 0b310d4dfc..0c13215075 100644 --- a/internal/dslx/tls.go +++ b/internal/dslx/tls.go @@ -9,7 +9,6 @@ import ( "crypto/tls" "crypto/x509" "net" - "sync/atomic" "time" "github.com/ooni/probe-cli/v3/internal/logx" @@ -100,7 +99,7 @@ func (f *tlsHandshakeFunc) Apply( // start the operation logger ol := logx.NewOperationLogger( - input.Logger, + f.Rt.Logger(), "[#%d] TLSHandshake with %s SNI=%s ALPN=%v", trace.Index(), input.Address, @@ -109,7 +108,7 @@ func (f *tlsHandshakeFunc) Apply( ) // obtain the handshaker for use - handshaker := f.handshakerOrDefault(trace, input.Logger) + handshaker := f.handshakerOrDefault(trace, f.Rt.Logger()) // setup config := &tls.Config{ @@ -132,15 +131,12 @@ func (f *tlsHandshakeFunc) Apply( ol.Stop(err) state := &TLSConnection{ - Address: input.Address, - Conn: conn, // possibly nil - Domain: input.Domain, - IDGenerator: input.IDGenerator, - Logger: input.Logger, - Network: input.Network, - TLSState: netxlite.MaybeTLSConnectionState(conn), - Trace: trace, - ZeroTime: input.ZeroTime, + Address: input.Address, + Conn: conn, // possibly nil + Domain: input.Domain, + Network: input.Network, + TLSState: netxlite.MaybeTLSConnectionState(conn), + Trace: trace, } return &Maybe[*TLSConnection]{ @@ -174,7 +170,7 @@ func (f *tlsHandshakeFunc) serverName(input *TCPConnection) string { // Note: golang requires a ServerName and fails if it's empty. If the provided // ServerName is an IP address, however, golang WILL NOT emit any SNI extension // in the ClientHello, consistently with RFC 6066 Section 3 requirements. - input.Logger.Warn("TLSHandshake: cannot determine which SNI to use") + f.Rt.Logger().Warn("TLSHandshake: cannot determine which SNI to use") return "" } @@ -197,12 +193,6 @@ type TLSConnection struct { // Domain is the OPTIONAL domain we resolved. Domain string - // IDGenerator is the MANDATORY ID generator to use. - IDGenerator *atomic.Int64 - - // Logger is the MANDATORY logger to use. - Logger model.Logger - // Network is the MANDATORY network we tried to use when connecting. Network string @@ -211,7 +201,4 @@ type TLSConnection struct { // Trace is the MANDATORY trace we're using. Trace Trace - - // ZeroTime is the MANDATORY zero time of the measurement. - ZeroTime time.Time } diff --git a/internal/dslx/tls_test.go b/internal/dslx/tls_test.go index 40030f9187..2fd209661b 100644 --- a/internal/dslx/tls_test.go +++ b/internal/dslx/tls_test.go @@ -148,13 +148,10 @@ func TestTLSHandshake(t *testing.T) { address = "1.2.3.4:567" } tcpConn := TCPConnection{ - Address: address, - Conn: &tcpConn, - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Trace: trace, - ZeroTime: zeroTime, + Address: address, + Conn: &tcpConn, + Network: "tcp", + Trace: trace, } res := tlsHandshake.Apply(context.Background(), &tcpConn) if res.Error != tt.expectErr { @@ -185,7 +182,6 @@ func TestServerNameTLS(t *testing.T) { sni := "sni" tcpConn := TCPConnection{ Address: "example.com:123", - Logger: model.DiscardLogger, } f := &tlsHandshakeFunc{ Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), @@ -201,7 +197,6 @@ func TestServerNameTLS(t *testing.T) { tcpConn := TCPConnection{ Address: "example.com:123", Domain: domain, - Logger: model.DiscardLogger, } f := &tlsHandshakeFunc{ Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), @@ -215,7 +210,6 @@ func TestServerNameTLS(t *testing.T) { hostaddr := "example.com" tcpConn := TCPConnection{ Address: hostaddr + ":123", - Logger: model.DiscardLogger, } f := &tlsHandshakeFunc{ Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), @@ -229,7 +223,6 @@ func TestServerNameTLS(t *testing.T) { ip := "1.1.1.1" tcpConn := TCPConnection{ Address: ip, - Logger: model.DiscardLogger, } f := &tlsHandshakeFunc{ Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), diff --git a/internal/tutorial/dslx/chapter02/README.md b/internal/tutorial/dslx/chapter02/README.md index e9be23d453..359344e4c4 100644 --- a/internal/tutorial/dslx/chapter02/README.md +++ b/internal/tutorial/dslx/chapter02/README.md @@ -43,10 +43,7 @@ import ( "context" "errors" "net" - "sync/atomic" - "time" - "github.com/apex/log" "github.com/ooni/probe-cli/v3/internal/dslx" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/runtimex" @@ -135,7 +132,6 @@ of dslx pipelines a unique identifier). ```Go type Measurer struct { config Config - idGen atomic.Int64 } var _ model.ExperimentMeasurer = &Measurer{} @@ -180,15 +176,6 @@ So, this is where we will use `dslx` to implement the SNI blocking experiment. func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { ``` -### Define measurement parameters - -`sess` is the session of this measurement run. - -```Go - sess := args.Session - -``` - `measurement` contains metadata, the (required) input in form of the target SNI, and the nettest results (`TestKeys`). @@ -251,9 +238,6 @@ experiment's start time. ```Go dnsInput := dslx.NewDomainToResolve( dslx.DomainName(thaddrHost), - dslx.DNSLookupOptionIDGenerator(&m.idGen), - dslx.DNSLookupOptionLogger(sess.Logger()), - dslx.DNSLookupOptionZeroTime(measurement.MeasurementStartTimeSaved), ) ``` @@ -262,7 +246,7 @@ Next, we create a minimal runtime. This data structure helps us to manage open connections and close them when `rt.Close` is invoked. ```Go - rt := dslx.NewMinimalRuntime(log.Log, time.Now()) + rt := dslx.NewMinimalRuntime(args.Session.Logger(), args.Measurement.MeasurementStartTimeSaved) defer rt.Close() ``` @@ -333,9 +317,6 @@ the protocol, address, and port three-tuple.) dslx.EndpointNetwork("tcp"), dslx.EndpointPort(443), dslx.EndpointOptionDomain(m.config.TestHelperAddress), - dslx.EndpointOptionIDGenerator(&m.idGen), - dslx.EndpointOptionLogger(sess.Logger()), - dslx.EndpointOptionZeroTime(measurement.MeasurementStartTimeSaved), ) runtimex.Assert(len(endpoints) >= 1, "expected at least one endpoint here") endpoint := endpoints[0] diff --git a/internal/tutorial/dslx/chapter02/main.go b/internal/tutorial/dslx/chapter02/main.go index 50991a4012..ef0e3c62c4 100644 --- a/internal/tutorial/dslx/chapter02/main.go +++ b/internal/tutorial/dslx/chapter02/main.go @@ -44,10 +44,7 @@ import ( "context" "errors" "net" - "sync/atomic" - "time" - "github.com/apex/log" "github.com/ooni/probe-cli/v3/internal/dslx" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/runtimex" @@ -136,7 +133,6 @@ func (tk *Subresult) mergeObservations(obs []*dslx.Observations) { // ```Go type Measurer struct { config Config - idGen atomic.Int64 } var _ model.ExperimentMeasurer = &Measurer{} @@ -179,15 +175,6 @@ func (m *Measurer) GetSummaryKeys(measurement *model.Measurement) (interface{}, // // ```Go func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { - // ``` - // - // ### Define measurement parameters - // - // `sess` is the session of this measurement run. - // - // ```Go - sess := args.Session - // ``` // // `measurement` contains metadata, the (required) input in form of @@ -252,9 +239,6 @@ func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { // ```Go dnsInput := dslx.NewDomainToResolve( dslx.DomainName(thaddrHost), - dslx.DNSLookupOptionIDGenerator(&m.idGen), - dslx.DNSLookupOptionLogger(sess.Logger()), - dslx.DNSLookupOptionZeroTime(measurement.MeasurementStartTimeSaved), ) // ``` @@ -263,7 +247,7 @@ func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { // open connections and close them when `rt.Close` is invoked. // // ```Go - rt := dslx.NewMinimalRuntime(log.Log, time.Now()) + rt := dslx.NewMinimalRuntime(args.Session.Logger(), args.Measurement.MeasurementStartTimeSaved) defer rt.Close() // ``` @@ -334,9 +318,6 @@ func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { dslx.EndpointNetwork("tcp"), dslx.EndpointPort(443), dslx.EndpointOptionDomain(m.config.TestHelperAddress), - dslx.EndpointOptionIDGenerator(&m.idGen), - dslx.EndpointOptionLogger(sess.Logger()), - dslx.EndpointOptionZeroTime(measurement.MeasurementStartTimeSaved), ) runtimex.Assert(len(endpoints) >= 1, "expected at least one endpoint here") endpoint := endpoints[0] From a2ea6db9f221287ca52d4143a7a7c0102c5944c1 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 18 Oct 2023 09:43:13 +0200 Subject: [PATCH 05/10] refactor(dslx): prepare for improving testing Rather than using custom fields for testing, we can configure in the runtime a custom model.MeasuringNetwork. We're not doing this just to simplify the codebase, rather the underlying intent here is making sure we don't need to keep much state in each function, so we can refactor them to be pure functions wrapped by an adapter that produces the desired type. In turn, by doing that, we will be able to factor complexity around invoking functions and parsing their results. In turn, by doing that, we will be able to modify the signature of the functions and do the following: 1. allow the DSL model to include stages that take in input a Maybe value rather than a value, so we can observe failures more easily than we do now and we can write inline code to save into test keys; 2. allow the DSL model to much more easily be refactored to use channels, which in turn enables us to compose operations more naturally and increase the amount of overlapping (think, e.g., how this enables the possibility of waiting additional time for a DNS-over-UDP resolver to wait for late/duplicate replies). --- internal/dslx/runtimemeasurex.go | 25 +++++++++++++++++--- internal/dslx/runtimemeasurex_test.go | 24 +++++++++++++++++++ internal/dslx/runtimeminimal.go | 33 ++++++++++++++++++++------- internal/dslx/runtimeminimal_test.go | 12 ++++++++++ 4 files changed, 83 insertions(+), 11 deletions(-) create mode 100644 internal/dslx/runtimemeasurex_test.go diff --git a/internal/dslx/runtimemeasurex.go b/internal/dslx/runtimemeasurex.go index a075d085b5..a5b7fd31f9 100644 --- a/internal/dslx/runtimemeasurex.go +++ b/internal/dslx/runtimemeasurex.go @@ -5,23 +5,42 @@ import ( "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/netxlite" ) +// RuntimeMeasurexLiteOption is an option for initializing a [*RuntimeMeasurexLite]. +type RuntimeMeasurexLiteOption func(rt *RuntimeMeasurexLite) + +// RuntimeMeasurexLiteOptionMeasuringNetwork allows to configure which [model.MeasuringNetwork] to use. +func RuntimeMeasurexLiteOptionMeasuringNetwork(netx model.MeasuringNetwork) RuntimeMeasurexLiteOption { + return func(rt *RuntimeMeasurexLite) { + rt.netx = netx + } +} + // NewRuntimeMeasurexLite creates a [Runtime] using [measurexlite] to collect [*Observations]. -func NewRuntimeMeasurexLite(logger model.Logger, zeroTime time.Time) *RuntimeMeasurexLite { - return &RuntimeMeasurexLite{ +func NewRuntimeMeasurexLite(logger model.Logger, zeroTime time.Time, options ...RuntimeMeasurexLiteOption) *RuntimeMeasurexLite { + rt := &RuntimeMeasurexLite{ MinimalRuntime: NewMinimalRuntime(logger, zeroTime), + netx: &netxlite.Netx{Underlying: nil}, // implies using the host's network + } + for _, option := range options { + option(rt) } + return rt } // RuntimeMeasurexLite uses [measurexlite] to collect [*Observations.] type RuntimeMeasurexLite struct { *MinimalRuntime + netx model.MeasuringNetwork } // NewTrace implements Runtime. func (p *RuntimeMeasurexLite) NewTrace(index int64, zeroTime time.Time, tags ...string) Trace { - return measurexlite.NewTrace(index, zeroTime, tags...) + trace := measurexlite.NewTrace(index, zeroTime, tags...) + trace.Netx = p.netx + return trace } var _ Runtime = &RuntimeMeasurexLite{} diff --git a/internal/dslx/runtimemeasurex_test.go b/internal/dslx/runtimemeasurex_test.go new file mode 100644 index 0000000000..1deb5a4547 --- /dev/null +++ b/internal/dslx/runtimemeasurex_test.go @@ -0,0 +1,24 @@ +package dslx + +import ( + "testing" + "time" + + "github.com/ooni/probe-cli/v3/internal/measurexlite" + "github.com/ooni/probe-cli/v3/internal/mocks" + "github.com/ooni/probe-cli/v3/internal/model" +) + +func TestMeasurexLiteRuntime(t *testing.T) { + t.Run("we can configure a custom model.MeasuringNetwork", func(t *testing.T) { + netx := &mocks.MeasuringNetwork{} + rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(netx)) + if rt.netx != netx { + t.Fatal("did not set the measuring network") + } + trace := rt.NewTrace(rt.IDGenerator().Add(1), rt.ZeroTime()).(*measurexlite.Trace) + if trace.Netx != netx { + t.Fatal("did not set the measuring network") + } + }) +} diff --git a/internal/dslx/runtimeminimal.go b/internal/dslx/runtimeminimal.go index 7003c2ceec..522505ef24 100644 --- a/internal/dslx/runtimeminimal.go +++ b/internal/dslx/runtimeminimal.go @@ -10,17 +10,32 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) +// MinimalRuntimeOption is an option for configuring the [*MinimalRuntime]. +type MinimalRuntimeOption func(rt *MinimalRuntime) + +// MinimalRuntimeOptionMeasuringNetwork configures the [model.MeasuringNetwork] to use. +func MinimalRuntimeOptionMeasuringNetwork(netx model.MeasuringNetwork) MinimalRuntimeOption { + return func(rt *MinimalRuntime) { + rt.netx = netx + } +} + // NewMinimalRuntime creates a minimal [Runtime] implementation. // // This [Runtime] implementation does not collect any [*Observations]. -func NewMinimalRuntime(logger model.Logger, zeroTime time.Time) *MinimalRuntime { - return &MinimalRuntime{ +func NewMinimalRuntime(logger model.Logger, zeroTime time.Time, options ...MinimalRuntimeOption) *MinimalRuntime { + rt := &MinimalRuntime{ idg: &atomic.Int64{}, logger: logger, mu: sync.Mutex{}, + netx: &netxlite.Netx{Underlying: nil}, // implies using the host's network v: []io.Closer{}, zeroT: zeroTime, } + for _, option := range options { + option(rt) + } + return rt } var _ Runtime = &MinimalRuntime{} @@ -30,6 +45,7 @@ type MinimalRuntime struct { idg *atomic.Int64 logger model.Logger mu sync.Mutex + netx model.MeasuringNetwork v []io.Closer zeroT time.Time } @@ -74,11 +90,12 @@ func (p *MinimalRuntime) Close() error { // NewTrace implements Runtime. func (p *MinimalRuntime) NewTrace(index int64, zeroTime time.Time, tags ...string) Trace { - return &minimalTrace{idx: index, tags: tags, zt: zeroTime} + return &minimalTrace{idx: index, netx: p.netx, tags: tags, zt: zeroTime} } type minimalTrace struct { idx int64 + netx model.MeasuringNetwork tags []string zt time.Time } @@ -105,27 +122,27 @@ func (tx *minimalTrace) NetworkEvents() (out []*model.ArchivalNetworkEvent) { // NewDialerWithoutResolver implements Trace. func (tx *minimalTrace) NewDialerWithoutResolver(dl model.DebugLogger, wrappers ...model.DialerWrapper) model.Dialer { - return netxlite.NewDialerWithoutResolver(dl, wrappers...) + return tx.netx.NewDialerWithoutResolver(dl, wrappers...) } // NewParallelUDPResolver implements Trace. func (tx *minimalTrace) NewParallelUDPResolver(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver { - return netxlite.NewParallelUDPResolver(logger, dialer, address) + return tx.netx.NewParallelUDPResolver(logger, dialer, address) } // NewQUICDialerWithoutResolver implements Trace. func (tx *minimalTrace) NewQUICDialerWithoutResolver(listener model.UDPListener, dl model.DebugLogger, wrappers ...model.QUICDialerWrapper) model.QUICDialer { - return netxlite.NewQUICDialerWithoutResolver(listener, dl, wrappers...) + return tx.netx.NewQUICDialerWithoutResolver(listener, dl, wrappers...) } // NewStdlibResolver implements Trace. func (tx *minimalTrace) NewStdlibResolver(logger model.DebugLogger) model.Resolver { - return netxlite.NewStdlibResolver(logger) + return tx.netx.NewStdlibResolver(logger) } // NewTLSHandshakerStdlib implements Trace. func (tx *minimalTrace) NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker { - return netxlite.NewTLSHandshakerStdlib(dl) + return tx.netx.NewTLSHandshakerStdlib(dl) } // QUICHandshakes implements Trace. diff --git a/internal/dslx/runtimeminimal_test.go b/internal/dslx/runtimeminimal_test.go index 4699787fb9..f773ccf82d 100644 --- a/internal/dslx/runtimeminimal_test.go +++ b/internal/dslx/runtimeminimal_test.go @@ -237,4 +237,16 @@ func TestMinimalRuntime(t *testing.T) { } }) }) + + t.Run("we can use a custom model.MeasuringNetwork", func(t *testing.T) { + netx := &mocks.MeasuringNetwork{} + rt := NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(netx)) + if rt.netx != netx { + t.Fatal("did not set the measuring network") + } + trace := rt.NewTrace(rt.IDGenerator().Add(1), rt.ZeroTime()).(*minimalTrace) + if trace.netx != netx { + t.Fatal("did not set the measuring network") + } + }) } From 6e43e82f27a8c2ab419816aabb5ab3da5518ac74 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 18 Oct 2023 10:24:48 +0200 Subject: [PATCH 06/10] refactor(dslx): use model.MeasuringNetwork for testing This diff modifies dslx functions to always use the MeasuringNetwork for testing rather than using specific func fields. By doing this, we open up the possibility of simplifying the state of each func, with the ultimate goal of making them pure functions. By making them pure functions, we make the code more manageable and easy to modify, which opens up for additional refactorings. --- internal/dslx/dns.go | 36 ++++++++------------ internal/dslx/dns_test.go | 67 +++++++++++++++++++++++++++++--------- internal/dslx/quic.go | 8 +---- internal/dslx/quic_test.go | 25 +++----------- internal/dslx/tcp.go | 17 ++-------- internal/dslx/tcp_test.go | 21 ++++-------- internal/dslx/tls.go | 15 +-------- internal/dslx/tls_test.go | 26 ++++----------- 8 files changed, 87 insertions(+), 128 deletions(-) diff --git a/internal/dslx/dns.go b/internal/dslx/dns.go index 9da9c9ed0f..62b86293f6 100644 --- a/internal/dslx/dns.go +++ b/internal/dslx/dns.go @@ -9,7 +9,6 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -73,13 +72,12 @@ type ResolvedAddresses struct { // DNSLookupGetaddrinfo returns a function that resolves a domain name to // IP addresses using libc's getaddrinfo function. func DNSLookupGetaddrinfo(rt Runtime) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { - return &dnsLookupGetaddrinfoFunc{nil, rt} + return &dnsLookupGetaddrinfoFunc{rt} } // dnsLookupGetaddrinfoFunc is the function returned by DNSLookupGetaddrinfo. type dnsLookupGetaddrinfoFunc struct { - resolver model.Resolver // for testing - rt Runtime + rt Runtime } // Apply implements Func. @@ -102,10 +100,8 @@ func (f *dnsLookupGetaddrinfoFunc) Apply( ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - resolver := f.resolver - if resolver == nil { - resolver = trace.NewStdlibResolver(f.rt.Logger()) - } + // create the resolver + resolver := trace.NewStdlibResolver(f.rt.Logger()) // lookup addrs, err := resolver.LookupHost(ctx, input.Domain) @@ -131,18 +127,16 @@ func (f *dnsLookupGetaddrinfoFunc) Apply( // IP addresses using the given DNS-over-UDP resolver. func DNSLookupUDP(rt Runtime, resolver string) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { return &dnsLookupUDPFunc{ - Resolver: resolver, - mockResolver: nil, - rt: rt, + Resolver: resolver, + rt: rt, } } // dnsLookupUDPFunc is the function returned by DNSLookupUDP. type dnsLookupUDPFunc struct { // Resolver is the MANDATORY endpointed of the resolver to use. - Resolver string - mockResolver model.Resolver // for testing - rt Runtime + Resolver string + rt Runtime } // Apply implements Func. @@ -166,14 +160,12 @@ func (f *dnsLookupUDPFunc) Apply( ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - resolver := f.mockResolver - if resolver == nil { - resolver = trace.NewParallelUDPResolver( - f.rt.Logger(), - trace.NewDialerWithoutResolver(f.rt.Logger()), - f.Resolver, - ) - } + // create the resolver + resolver := trace.NewParallelUDPResolver( + f.rt.Logger(), + trace.NewDialerWithoutResolver(f.rt.Logger()), + f.Resolver, + ) // lookup addrs, err := resolver.LookupHost(ctx, input.Domain) diff --git a/internal/dslx/dns_test.go b/internal/dslx/dns_test.go index 15f08e155e..5dd6f79bd1 100644 --- a/internal/dslx/dns_test.go +++ b/internal/dslx/dns_test.go @@ -3,6 +3,7 @@ package dslx import ( "context" "errors" + "net" "sync/atomic" "testing" "time" @@ -84,10 +85,15 @@ func TestGetaddrinfo(t *testing.T) { t.Run("with lookup error", func(t *testing.T) { mockedErr := errors.New("mocked") f := dnsLookupGetaddrinfoFunc{ - resolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, mockedErr - }}, - rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), + rt: NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewStdlibResolver: func(logger model.DebugLogger) model.Resolver { + return &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, mockedErr + }, + } + }, + })), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -106,10 +112,15 @@ func TestGetaddrinfo(t *testing.T) { t.Run("with success", func(t *testing.T) { f := dnsLookupGetaddrinfoFunc{ - resolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return []string{"93.184.216.34"}, nil - }}, - rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()), + rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewStdlibResolver: func(logger model.DebugLogger) model.Resolver { + return &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return []string{"93.184.216.34"}, nil + }, + } + }, + })), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -171,10 +182,22 @@ func TestLookupUDP(t *testing.T) { mockedErr := errors.New("mocked") f := dnsLookupUDPFunc{ Resolver: "1.1.1.1:53", - mockResolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, mockedErr - }}, - rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), + rt: NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewParallelUDPResolver: func(logger model.DebugLogger, dialer model.Dialer, endpoint string) model.Resolver { + return &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, mockedErr + }, + } + }, + MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer { + return &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + panic("should not be called") + }, + } + }, + })), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -194,10 +217,22 @@ func TestLookupUDP(t *testing.T) { t.Run("with success", func(t *testing.T) { f := dnsLookupUDPFunc{ Resolver: "1.1.1.1:53", - mockResolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return []string{"93.184.216.34"}, nil - }}, - rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()), + rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewParallelUDPResolver: func(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver { + return &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return []string{"93.184.216.34"}, nil + }, + } + }, + MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer { + return &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + panic("should not be called") + }, + } + }, + })), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { diff --git a/internal/dslx/quic.go b/internal/dslx/quic.go index c643584e4a..3acf675ac9 100644 --- a/internal/dslx/quic.go +++ b/internal/dslx/quic.go @@ -13,7 +13,6 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/quic-go/quic-go" ) @@ -73,8 +72,6 @@ type quicHandshakeFunc struct { // ServerName is the ServerName to handshake for. ServerName string - - dialer model.QUICDialer // for testing } // Apply implements Func. @@ -97,10 +94,7 @@ func (f *quicHandshakeFunc) Apply( // setup udpListener := netxlite.NewUDPListener() - quicDialer := f.dialer - if quicDialer == nil { - quicDialer = trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger()) - } + quicDialer := trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger()) config := &tls.Config{ NextProtos: []string{"h3"}, InsecureSkipVerify: f.InsecureSkipVerify, diff --git a/internal/dslx/quic_test.go b/internal/dslx/quic_test.go index 40c4923812..2d34954bae 100644 --- a/internal/dslx/quic_test.go +++ b/internal/dslx/quic_test.go @@ -98,10 +98,13 @@ func TestQUICHandshake(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()) + rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewQUICDialerWithoutResolver: func(listener model.UDPListener, logger model.DebugLogger, w ...model.QUICDialerWrapper) model.QUICDialer { + return tt.dialer + }, + })) quicHandshake := &quicHandshakeFunc{ Rt: rt, - dialer: tt.dialer, ServerName: tt.sni, } endpoint := &Endpoint{ @@ -131,24 +134,6 @@ func TestQUICHandshake(t *testing.T) { }) wasClosed = false } - - t.Run("with nil dialer", func(t *testing.T) { - quicHandshake := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), dialer: nil} - endpoint := &Endpoint{ - Address: "1.2.3.4:567", - Network: "udp", - } - ctx, cancel := context.WithCancel(context.Background()) - cancel() - res := quicHandshake.Apply(ctx, endpoint) - - if res.Error == nil { - t.Fatalf("expected an error here") - } - if res.State.QUICConn != nil { - t.Fatalf("unexpected conn: %s", res.State.QUICConn) - } - }) }) } diff --git a/internal/dslx/tcp.go b/internal/dslx/tcp.go index af5dbcff3c..eaa54c2d30 100644 --- a/internal/dslx/tcp.go +++ b/internal/dslx/tcp.go @@ -10,20 +10,18 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) // TCPConnect returns a function that establishes TCP connections. func TCPConnect(rt Runtime) Func[*Endpoint, *Maybe[*TCPConnection]] { - f := &tcpConnectFunc{nil, rt} + f := &tcpConnectFunc{rt} return f } // tcpConnectFunc is a function that establishes TCP connections. type tcpConnectFunc struct { - dialer model.Dialer // for testing - rt Runtime + rt Runtime } // Apply applies the function to its arguments. @@ -47,7 +45,7 @@ func (f *tcpConnectFunc) Apply( defer cancel() // obtain the dialer to use - dialer := f.dialerOrDefault(trace, f.rt.Logger()) + dialer := trace.NewDialerWithoutResolver(f.rt.Logger()) // connect conn, err := dialer.DialContext(ctx, "tcp", input.Address) @@ -74,15 +72,6 @@ func (f *tcpConnectFunc) Apply( } } -// dialerOrDefault is the function used to obtain a dialer -func (f *tcpConnectFunc) dialerOrDefault(trace Trace, logger model.Logger) model.Dialer { - dialer := f.dialer - if dialer == nil { - dialer = trace.NewDialerWithoutResolver(logger) - } - return dialer -} - // TCPConnection is an established TCP connection. If you initialize // manually, init at least the ones marked as MANDATORY. type TCPConnection struct { diff --git a/internal/dslx/tcp_test.go b/internal/dslx/tcp_test.go index 1ec42ef88a..8748b634bb 100644 --- a/internal/dslx/tcp_test.go +++ b/internal/dslx/tcp_test.go @@ -8,7 +8,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" ) @@ -68,8 +67,12 @@ func TestTCPConnect(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()) - tcpConnect := &tcpConnectFunc{tt.dialer, rt} + rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer { + return tt.dialer + }, + })) + tcpConnect := &tcpConnectFunc{rt} endpoint := &Endpoint{ Address: "1.2.3.4:567", Network: "tcp", @@ -99,15 +102,3 @@ func TestTCPConnect(t *testing.T) { } }) } - -// Make sure we get a valid dialer if no mocked dialer is configured -func TestDialerOrDefault(t *testing.T) { - f := &tcpConnectFunc{ - rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - dialer: nil, - } - dialer := f.dialerOrDefault(measurexlite.NewTrace(0, time.Now()), model.DiscardLogger) - if dialer == nil { - t.Fatal("expected non-nil dialer here") - } -} diff --git a/internal/dslx/tls.go b/internal/dslx/tls.go index 0c13215075..af67d59f61 100644 --- a/internal/dslx/tls.go +++ b/internal/dslx/tls.go @@ -12,7 +12,6 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -82,9 +81,6 @@ type tlsHandshakeFunc struct { // ServerName is the ServerName to handshake for. ServerName string - - // for testing - handshaker model.TLSHandshaker } // Apply implements Func. @@ -108,7 +104,7 @@ func (f *tlsHandshakeFunc) Apply( ) // obtain the handshaker for use - handshaker := f.handshakerOrDefault(trace, f.Rt.Logger()) + handshaker := trace.NewTLSHandshakerStdlib(f.Rt.Logger()) // setup config := &tls.Config{ @@ -147,15 +143,6 @@ func (f *tlsHandshakeFunc) Apply( } } -// handshakerOrDefault is the function used to obtain an handshaker -func (f *tlsHandshakeFunc) handshakerOrDefault(trace Trace, logger model.Logger) model.TLSHandshaker { - handshaker := f.handshaker - if handshaker == nil { - handshaker = trace.NewTLSHandshakerStdlib(logger) - } - return handshaker -} - func (f *tlsHandshakeFunc) serverName(input *TCPConnection) string { if f.ServerName != "" { return f.ServerName diff --git a/internal/dslx/tls_test.go b/internal/dslx/tls_test.go index 2fd209661b..4dcce7e0e3 100644 --- a/internal/dslx/tls_test.go +++ b/internal/dslx/tls_test.go @@ -10,7 +10,6 @@ import ( "testing" "time" - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" ) @@ -133,16 +132,19 @@ func TestTLSHandshake(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + rt := NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewTLSHandshakerStdlib: func(logger model.DebugLogger) model.TLSHandshaker { + return tt.handshaker + }, + })) tlsHandshake := &tlsHandshakeFunc{ NextProto: tt.config.nextProtos, Rt: rt, ServerName: tt.config.sni, - handshaker: tt.handshaker, } idGen := &atomic.Int64{} zeroTime := time.Time{} - trace := measurexlite.NewTrace(idGen.Add(1), zeroTime) + trace := rt.NewTrace(idGen.Add(1), zeroTime) address := tt.config.address if address == "" { address = "1.2.3.4:567" @@ -233,19 +235,3 @@ func TestServerNameTLS(t *testing.T) { } }) } - -// Make sure we get a valid handshaker if no mocked handshaker is configured -func TestHandshakerOrDefault(t *testing.T) { - f := &tlsHandshakeFunc{ - InsecureSkipVerify: false, - NextProto: []string{}, - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - RootCAs: &x509.CertPool{}, - ServerName: "", - handshaker: nil, - } - handshaker := f.handshakerOrDefault(measurexlite.NewTrace(0, time.Now()), model.DiscardLogger) - if handshaker == nil { - t.Fatal("expected non-nil handshaker here") - } -} From 7c32b7ed2083cddbbe1127bc4c9a1cfe9baf137a Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 18 Oct 2023 10:53:15 +0200 Subject: [PATCH 07/10] refactor(dslx): start using pure functions Introduce an adapter type that converts a pure function into a Func and start using it whenever it's easy, rather than rolling out structs that implement the Func type when we actually don't have state. This change is an improvement because we're creating the necessary conditions from moving complexity out of the functions that actually do something, which should be simpler, and adapters, which should contain the same, equal logic for creating pipelines. The ultimate goal here is to be able to have stages that accept a Maybe[A] as input, to be able to add inline evaluators that set test keys. I want to reach this goal by making the necessary transformation at the adapters level rather than changing each of the Func prototype at once. --- internal/dslx/dns.go | 193 +++++++++++++++++--------------------- internal/dslx/dns_test.go | 51 ++++------ internal/dslx/fxcore.go | 8 ++ internal/dslx/tcp.go | 99 +++++++++---------- internal/dslx/tcp_test.go | 11 +-- 5 files changed, 155 insertions(+), 207 deletions(-) diff --git a/internal/dslx/dns.go b/internal/dslx/dns.go index 62b86293f6..88d3bd65ec 100644 --- a/internal/dslx/dns.go +++ b/internal/dslx/dns.go @@ -72,117 +72,92 @@ type ResolvedAddresses struct { // DNSLookupGetaddrinfo returns a function that resolves a domain name to // IP addresses using libc's getaddrinfo function. func DNSLookupGetaddrinfo(rt Runtime) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { - return &dnsLookupGetaddrinfoFunc{rt} -} - -// dnsLookupGetaddrinfoFunc is the function returned by DNSLookupGetaddrinfo. -type dnsLookupGetaddrinfoFunc struct { - rt Runtime -} - -// Apply implements Func. -func (f *dnsLookupGetaddrinfoFunc) Apply( - ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] { - - // create trace - trace := f.rt.NewTrace(f.rt.IDGenerator().Add(1), f.rt.ZeroTime(), input.Tags...) - - // start the operation logger - ol := logx.NewOperationLogger( - f.rt.Logger(), - "[#%d] DNSLookup[getaddrinfo] %s", - trace.Index(), - input.Domain, - ) - - // setup - const timeout = 4 * time.Second - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - // create the resolver - resolver := trace.NewStdlibResolver(f.rt.Logger()) - - // lookup - addrs, err := resolver.LookupHost(ctx, input.Domain) - - // stop the operation logger - ol.Stop(err) - - state := &ResolvedAddresses{ - Addresses: addrs, // maybe empty - Domain: input.Domain, - Trace: trace, - } - - return &Maybe[*ResolvedAddresses]{ - Error: err, - Observations: maybeTraceToObservations(trace), - Operation: netxlite.ResolveOperation, - State: state, - } + return FuncAdapter[*DomainToResolve, *Maybe[*ResolvedAddresses]](func(ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] { + // create trace + trace := rt.NewTrace(rt.IDGenerator().Add(1), rt.ZeroTime(), input.Tags...) + + // start the operation logger + ol := logx.NewOperationLogger( + rt.Logger(), + "[#%d] DNSLookup[getaddrinfo] %s", + trace.Index(), + input.Domain, + ) + + // setup + const timeout = 4 * time.Second + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // create the resolver + resolver := trace.NewStdlibResolver(rt.Logger()) + + // lookup + addrs, err := resolver.LookupHost(ctx, input.Domain) + + // stop the operation logger + ol.Stop(err) + + state := &ResolvedAddresses{ + Addresses: addrs, // maybe empty + Domain: input.Domain, + Trace: trace, + } + + return &Maybe[*ResolvedAddresses]{ + Error: err, + Observations: maybeTraceToObservations(trace), + Operation: netxlite.ResolveOperation, + State: state, + } + }) } // DNSLookupUDP returns a function that resolves a domain name to // IP addresses using the given DNS-over-UDP resolver. -func DNSLookupUDP(rt Runtime, resolver string) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { - return &dnsLookupUDPFunc{ - Resolver: resolver, - rt: rt, - } -} - -// dnsLookupUDPFunc is the function returned by DNSLookupUDP. -type dnsLookupUDPFunc struct { - // Resolver is the MANDATORY endpointed of the resolver to use. - Resolver string - rt Runtime -} - -// Apply implements Func. -func (f *dnsLookupUDPFunc) Apply( - ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] { - - // create trace - trace := f.rt.NewTrace(f.rt.IDGenerator().Add(1), f.rt.ZeroTime(), input.Tags...) - - // start the operation logger - ol := logx.NewOperationLogger( - f.rt.Logger(), - "[#%d] DNSLookup[%s/udp] %s", - trace.Index(), - f.Resolver, - input.Domain, - ) - - // setup - const timeout = 4 * time.Second - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - // create the resolver - resolver := trace.NewParallelUDPResolver( - f.rt.Logger(), - trace.NewDialerWithoutResolver(f.rt.Logger()), - f.Resolver, - ) - - // lookup - addrs, err := resolver.LookupHost(ctx, input.Domain) - - // stop the operation logger - ol.Stop(err) - - state := &ResolvedAddresses{ - Addresses: addrs, // maybe empty - Domain: input.Domain, - Trace: trace, - } - - return &Maybe[*ResolvedAddresses]{ - Error: err, - Observations: maybeTraceToObservations(trace), - Operation: netxlite.ResolveOperation, - State: state, - } +func DNSLookupUDP(rt Runtime, endpoint string) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { + return FuncAdapter[*DomainToResolve, *Maybe[*ResolvedAddresses]](func(ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] { + // create trace + trace := rt.NewTrace(rt.IDGenerator().Add(1), rt.ZeroTime(), input.Tags...) + + // start the operation logger + ol := logx.NewOperationLogger( + rt.Logger(), + "[#%d] DNSLookup[%s/udp] %s", + trace.Index(), + endpoint, + input.Domain, + ) + + // setup + const timeout = 4 * time.Second + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // create the resolver + resolver := trace.NewParallelUDPResolver( + rt.Logger(), + trace.NewDialerWithoutResolver(rt.Logger()), + endpoint, + ) + + // lookup + addrs, err := resolver.LookupHost(ctx, input.Domain) + + // stop the operation logger + ol.Stop(err) + + state := &ResolvedAddresses{ + Addresses: addrs, // maybe empty + Domain: input.Domain, + Trace: trace, + } + + return &Maybe[*ResolvedAddresses]{ + Error: err, + Observations: maybeTraceToObservations(trace), + Operation: netxlite.ResolveOperation, + State: state, + } + }) } diff --git a/internal/dslx/dns_test.go b/internal/dslx/dns_test.go index 5dd6f79bd1..2ef3b82d3f 100644 --- a/internal/dslx/dns_test.go +++ b/internal/dslx/dns_test.go @@ -54,13 +54,6 @@ Test cases: - with success */ func TestGetaddrinfo(t *testing.T) { - t.Run("Get dnsLookupGetaddrinfoFunc", func(t *testing.T) { - f := DNSLookupGetaddrinfo(NewMinimalRuntime(model.DiscardLogger, time.Now())) - if _, ok := f.(*dnsLookupGetaddrinfoFunc); !ok { - t.Fatal("unexpected type, want dnsLookupGetaddrinfoFunc") - } - }) - t.Run("Apply dnsLookupGetaddrinfoFunc", func(t *testing.T) { domain := &DomainToResolve{ Domain: "example.com", @@ -68,9 +61,9 @@ func TestGetaddrinfo(t *testing.T) { } t.Run("with nil resolver", func(t *testing.T) { - f := dnsLookupGetaddrinfoFunc{ - rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - } + f := DNSLookupGetaddrinfo( + NewMinimalRuntime(model.DiscardLogger, time.Now()), + ) ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately cancel the lookup res := f.Apply(ctx, domain) @@ -84,8 +77,8 @@ func TestGetaddrinfo(t *testing.T) { t.Run("with lookup error", func(t *testing.T) { mockedErr := errors.New("mocked") - f := dnsLookupGetaddrinfoFunc{ - rt: NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + f := DNSLookupGetaddrinfo( + NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{ MockNewStdlibResolver: func(logger model.DebugLogger) model.Resolver { return &mocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { @@ -94,7 +87,7 @@ func TestGetaddrinfo(t *testing.T) { } }, })), - } + ) res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { t.Fatal("unexpected empty observations") @@ -111,8 +104,8 @@ func TestGetaddrinfo(t *testing.T) { }) t.Run("with success", func(t *testing.T) { - f := dnsLookupGetaddrinfoFunc{ - rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + f := DNSLookupGetaddrinfo( + NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{ MockNewStdlibResolver: func(logger model.DebugLogger) model.Resolver { return &mocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { @@ -121,7 +114,7 @@ func TestGetaddrinfo(t *testing.T) { } }, })), - } + ) res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { t.Fatal("unexpected empty observations") @@ -151,14 +144,6 @@ Test cases: - with success */ func TestLookupUDP(t *testing.T) { - t.Run("Get dnsLookupUDPFunc", func(t *testing.T) { - rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) - f := DNSLookupUDP(rt, "1.1.1.1:53") - if _, ok := f.(*dnsLookupUDPFunc); !ok { - t.Fatal("unexpected type, want dnsLookupUDPFunc") - } - }) - t.Run("Apply dnsLookupGetaddrinfoFunc", func(t *testing.T) { domain := &DomainToResolve{ Domain: "example.com", @@ -166,7 +151,7 @@ func TestLookupUDP(t *testing.T) { } t.Run("with nil resolver", func(t *testing.T) { - f := dnsLookupUDPFunc{Resolver: "1.1.1.1:53", rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} + f := DNSLookupUDP(NewMinimalRuntime(model.DiscardLogger, time.Now()), "1.1.1.1:53") ctx, cancel := context.WithCancel(context.Background()) cancel() res := f.Apply(ctx, domain) @@ -180,9 +165,8 @@ func TestLookupUDP(t *testing.T) { t.Run("with lookup error", func(t *testing.T) { mockedErr := errors.New("mocked") - f := dnsLookupUDPFunc{ - Resolver: "1.1.1.1:53", - rt: NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + f := DNSLookupUDP( + NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{ MockNewParallelUDPResolver: func(logger model.DebugLogger, dialer model.Dialer, endpoint string) model.Resolver { return &mocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { @@ -198,7 +182,8 @@ func TestLookupUDP(t *testing.T) { } }, })), - } + "1.1.1.1:53", + ) res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { t.Fatal("unexpected empty observations") @@ -215,9 +200,8 @@ func TestLookupUDP(t *testing.T) { }) t.Run("with success", func(t *testing.T) { - f := dnsLookupUDPFunc{ - Resolver: "1.1.1.1:53", - rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + f := DNSLookupUDP( + NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{ MockNewParallelUDPResolver: func(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver { return &mocks.Resolver{ MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { @@ -233,7 +217,8 @@ func TestLookupUDP(t *testing.T) { } }, })), - } + "1.1.1.1:53", + ) res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { t.Fatal("unexpected empty observations") diff --git a/internal/dslx/fxcore.go b/internal/dslx/fxcore.go index 64f95bbc0d..329a8fc23c 100644 --- a/internal/dslx/fxcore.go +++ b/internal/dslx/fxcore.go @@ -17,6 +17,14 @@ type Func[A, B any] interface { Apply(ctx context.Context, a A) B } +// FuncAdapter adapts a func to be a Func. +type FuncAdapter[A, B any] func(ctx context.Context, a A) B + +// Apply implements Func. +func (fa FuncAdapter[A, B]) Apply(ctx context.Context, a A) B { + return fa(ctx, a) +} + // Maybe is the result of an operation implemented by this package // that may fail such as [TCPConnect] or [TLSHandshake]. type Maybe[State any] struct { diff --git a/internal/dslx/tcp.go b/internal/dslx/tcp.go index eaa54c2d30..d576b84fce 100644 --- a/internal/dslx/tcp.go +++ b/internal/dslx/tcp.go @@ -15,61 +15,50 @@ import ( // TCPConnect returns a function that establishes TCP connections. func TCPConnect(rt Runtime) Func[*Endpoint, *Maybe[*TCPConnection]] { - f := &tcpConnectFunc{rt} - return f -} - -// tcpConnectFunc is a function that establishes TCP connections. -type tcpConnectFunc struct { - rt Runtime -} - -// Apply applies the function to its arguments. -func (f *tcpConnectFunc) Apply( - ctx context.Context, input *Endpoint) *Maybe[*TCPConnection] { - - // create trace - trace := f.rt.NewTrace(f.rt.IDGenerator().Add(1), f.rt.ZeroTime(), input.Tags...) - - // start the operation logger - ol := logx.NewOperationLogger( - f.rt.Logger(), - "[#%d] TCPConnect %s", - trace.Index(), - input.Address, - ) - - // setup - const timeout = 15 * time.Second - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - // obtain the dialer to use - dialer := trace.NewDialerWithoutResolver(f.rt.Logger()) - - // connect - conn, err := dialer.DialContext(ctx, "tcp", input.Address) - - // possibly register established conn for late close - f.rt.MaybeTrackConn(conn) - - // stop the operation logger - ol.Stop(err) - - state := &TCPConnection{ - Address: input.Address, - Conn: conn, // possibly nil - Domain: input.Domain, - Network: input.Network, - Trace: trace, - } - - return &Maybe[*TCPConnection]{ - Error: err, - Observations: maybeTraceToObservations(trace), - Operation: netxlite.ConnectOperation, - State: state, - } + return FuncAdapter[*Endpoint, *Maybe[*TCPConnection]](func(ctx context.Context, input *Endpoint) *Maybe[*TCPConnection] { + // create trace + trace := rt.NewTrace(rt.IDGenerator().Add(1), rt.ZeroTime(), input.Tags...) + + // start the operation logger + ol := logx.NewOperationLogger( + rt.Logger(), + "[#%d] TCPConnect %s", + trace.Index(), + input.Address, + ) + + // setup + const timeout = 15 * time.Second + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // obtain the dialer to use + dialer := trace.NewDialerWithoutResolver(rt.Logger()) + + // connect + conn, err := dialer.DialContext(ctx, "tcp", input.Address) + + // possibly register established conn for late close + rt.MaybeTrackConn(conn) + + // stop the operation logger + ol.Stop(err) + + state := &TCPConnection{ + Address: input.Address, + Conn: conn, // possibly nil + Domain: input.Domain, + Network: input.Network, + Trace: trace, + } + + return &Maybe[*TCPConnection]{ + Error: err, + Observations: maybeTraceToObservations(trace), + Operation: netxlite.ConnectOperation, + State: state, + } + }) } // TCPConnection is an established TCP connection. If you initialize diff --git a/internal/dslx/tcp_test.go b/internal/dslx/tcp_test.go index 8748b634bb..f5f1e28532 100644 --- a/internal/dslx/tcp_test.go +++ b/internal/dslx/tcp_test.go @@ -13,15 +13,6 @@ import ( ) func TestTCPConnect(t *testing.T) { - t.Run("Get tcpConnectFunc", func(t *testing.T) { - f := TCPConnect( - NewMinimalRuntime(model.DiscardLogger, time.Now()), - ) - if _, ok := f.(*tcpConnectFunc); !ok { - t.Fatal("unexpected type. Expected: tcpConnectFunc") - } - }) - t.Run("Apply tcpConnectFunc", func(t *testing.T) { wasClosed := false plainConn := &mocks.Conn{ @@ -72,7 +63,7 @@ func TestTCPConnect(t *testing.T) { return tt.dialer }, })) - tcpConnect := &tcpConnectFunc{rt} + tcpConnect := TCPConnect(rt) endpoint := &Endpoint{ Address: "1.2.3.4:567", Network: "tcp", From 57243bb101195247cd16c0d619c1eb673b6c9821 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 18 Oct 2023 11:41:31 +0200 Subject: [PATCH 08/10] refactor(dslx): reduce tlsHandshakeFunc state This diff reduces the state kept by the tlsHandshakeFunc struct so that we can apply a transformation similar to the one we applied for TCPConnect() and implement TLSHandshake() using a pure func. The overall objective is that of factoring away completixity to enable manipulating this code more easily. While there, let's note that the changes applied here mean that we can reuse this code for configuring tls.Config for the QUICHandshake. --- internal/dslx/tls.go | 110 ++++++++++++++--------------- internal/dslx/tls_test.go | 143 ++++++++++++++++---------------------- 2 files changed, 114 insertions(+), 139 deletions(-) diff --git a/internal/dslx/tls.go b/internal/dslx/tls.go index af67d59f61..ee75086b4d 100644 --- a/internal/dslx/tls.go +++ b/internal/dslx/tls.go @@ -12,75 +12,58 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) // TLSHandshakeOption is an option you can pass to TLSHandshake. -type TLSHandshakeOption func(*tlsHandshakeFunc) +type TLSHandshakeOption func(config *tls.Config) // TLSHandshakeOptionInsecureSkipVerify controls whether TLS verification is enabled. func TLSHandshakeOptionInsecureSkipVerify(value bool) TLSHandshakeOption { - return func(thf *tlsHandshakeFunc) { - thf.InsecureSkipVerify = value + return func(config *tls.Config) { + config.InsecureSkipVerify = value } } // TLSHandshakeOptionNextProto allows to configure the ALPN protocols. func TLSHandshakeOptionNextProto(value []string) TLSHandshakeOption { - return func(thf *tlsHandshakeFunc) { - thf.NextProto = value + return func(config *tls.Config) { + config.NextProtos = value } } // TLSHandshakeOptionRootCAs allows to configure custom root CAs. func TLSHandshakeOptionRootCAs(value *x509.CertPool) TLSHandshakeOption { - return func(thf *tlsHandshakeFunc) { - thf.RootCAs = value + return func(config *tls.Config) { + config.RootCAs = value } } // TLSHandshakeOptionServerName allows to configure the SNI to use. func TLSHandshakeOptionServerName(value string) TLSHandshakeOption { - return func(thf *tlsHandshakeFunc) { - thf.ServerName = value + return func(config *tls.Config) { + config.ServerName = value } } // TLSHandshake returns a function performing TSL handshakes. func TLSHandshake(rt Runtime, options ...TLSHandshakeOption) Func[ *TCPConnection, *Maybe[*TLSConnection]] { - // See https://github.com/ooni/probe/issues/2413 to understand - // why we're using nil to force netxlite to use the cached - // default Mozilla cert pool. f := &tlsHandshakeFunc{ - InsecureSkipVerify: false, - NextProto: []string{}, - RootCAs: nil, - Rt: rt, - ServerName: "", - } - for _, option := range options { - option(f) + Options: options, + Rt: rt, } return f } // tlsHandshakeFunc performs TLS handshakes. type tlsHandshakeFunc struct { - // InsecureSkipVerify allows to skip TLS verification. - InsecureSkipVerify bool - - // NextProto contains the ALPNs to negotiate. - NextProto []string + // Options contains the options. + Options []TLSHandshakeOption - // RootCAs contains the Root CAs to use. - RootCAs *x509.CertPool - - // Pool is the Pool that owns us. + // Rt is the runtime that owns us. Rt Runtime - - // ServerName is the ServerName to handshake for. - ServerName string } // Apply implements Func. @@ -89,9 +72,8 @@ func (f *tlsHandshakeFunc) Apply( // keep using the same trace trace := input.Trace - // use defaults or user-configured overrides - serverName := f.serverName(input) - nextProto := f.nextProto() + // create a suitable TLS configuration + config := tlsNewConfig(input.Address, []string{"h2", "http/1.1"}, input.Domain, f.Rt.Logger(), f.Options...) // start the operation logger ol := logx.NewOperationLogger( @@ -99,20 +81,14 @@ func (f *tlsHandshakeFunc) Apply( "[#%d] TLSHandshake with %s SNI=%s ALPN=%v", trace.Index(), input.Address, - serverName, - nextProto, + config.ServerName, + config.NextProtos, ) // obtain the handshaker for use handshaker := trace.NewTLSHandshakerStdlib(f.Rt.Logger()) // setup - config := &tls.Config{ - NextProtos: nextProto, - InsecureSkipVerify: f.InsecureSkipVerify, - RootCAs: f.RootCAs, - ServerName: serverName, - } const timeout = 10 * time.Second ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() @@ -143,31 +119,51 @@ func (f *tlsHandshakeFunc) Apply( } } -func (f *tlsHandshakeFunc) serverName(input *TCPConnection) string { - if f.ServerName != "" { - return f.ServerName +// tlsNewConfig is an utility function to create a new TLS config. +// +// Arguments: +// +// - address is the endpoint address (e.g., 1.1.1.1:443); +// +// - defaultALPN contains the default to be used for configuring ALPN; +// +// - domain is the possibly empty domain to use; +// +// - logger is the logger to use; +// +// - options contains options to modify the TLS handshake defaults. +func tlsNewConfig(address string, defaultALPN []string, domain string, logger model.Logger, options ...TLSHandshakeOption) *tls.Config { + // See https://github.com/ooni/probe/issues/2413 to understand + // why we're using nil to force netxlite to use the cached + // default Mozilla cert pool. + config := &tls.Config{ + NextProtos: append([]string{}, defaultALPN...), + InsecureSkipVerify: false, + RootCAs: nil, + ServerName: tlsServerName(address, domain, logger), } - if input.Domain != "" { - return input.Domain + for _, option := range options { + option(config) + } + return config +} + +// tlsServerName is an utility function to obtina the server name from a TCPConnection. +func tlsServerName(address, domain string, logger model.Logger) string { + if domain != "" { + return domain } - addr, _, err := net.SplitHostPort(input.Address) + addr, _, err := net.SplitHostPort(address) if err == nil { return addr } // Note: golang requires a ServerName and fails if it's empty. If the provided // ServerName is an IP address, however, golang WILL NOT emit any SNI extension // in the ClientHello, consistently with RFC 6066 Section 3 requirements. - f.Rt.Logger().Warn("TLSHandshake: cannot determine which SNI to use") + logger.Warn("TLSHandshake: cannot determine which SNI to use") return "" } -func (f *tlsHandshakeFunc) nextProto() []string { - if len(f.NextProto) > 0 { - return f.NextProto - } - return []string{"h2", "http/1.1"} -} - // TLSConnection is an established TLS connection. If you initialize // manually, init at least the ones marked as MANDATORY. type TLSConnection struct { diff --git a/internal/dslx/tls_test.go b/internal/dslx/tls_test.go index 4dcce7e0e3..36df4a79ca 100644 --- a/internal/dslx/tls_test.go +++ b/internal/dslx/tls_test.go @@ -10,51 +10,66 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" ) -/* -Test cases: -- Get tlsHandshakeFunc with options -- Apply tlsHandshakeFunc: - - with EOF - - with invalid address - - with success - - with sni - - with options -*/ -func TestTLSHandshake(t *testing.T) { - t.Run("Get tlsHandshakeFunc with options", func(t *testing.T) { +func TestTLSNewConfig(t *testing.T) { + t.Run("without options", func(t *testing.T) { + config := tlsNewConfig("1.1.1.1:443", []string{"h2", "http/1.1"}, "sni", model.DiscardLogger) + + if config.InsecureSkipVerify { + t.Fatalf("unexpected %s, expected %v, got %v", "InsecureSkipVerify", false, config.InsecureSkipVerify) + } + if diff := cmp.Diff([]string{"h2", "http/1.1"}, config.NextProtos); diff != "" { + t.Fatal(diff) + } + if config.ServerName != "sni" { + t.Fatalf("unexpected %s, expected %s, got %s", "ServerName", "sni", config.ServerName) + } + if !config.RootCAs.Equal(nil) { + t.Fatalf("unexpected %s, expected %v, got %v", "RootCAs", nil, config.RootCAs) + } + }) + + t.Run("with options", func(t *testing.T) { certpool := x509.NewCertPool() certpool.AddCert(&x509.Certificate{}) - f := TLSHandshake( - NewMinimalRuntime(model.DiscardLogger, time.Now()), + config := tlsNewConfig( + "1.1.1.1:443", []string{"h2", "http/1.1"}, "sni", model.DiscardLogger, TLSHandshakeOptionInsecureSkipVerify(true), TLSHandshakeOptionNextProto([]string{"h2"}), - TLSHandshakeOptionServerName("sni"), + TLSHandshakeOptionServerName("example.domain"), TLSHandshakeOptionRootCAs(certpool), ) - var handshakeFunc *tlsHandshakeFunc - var ok bool - if handshakeFunc, ok = f.(*tlsHandshakeFunc); !ok { - t.Fatal("unexpected type. Expected: tlsHandshakeFunc") - } - if !handshakeFunc.InsecureSkipVerify { - t.Fatalf("unexpected %s, expected %v, got %v", "InsecureSkipVerify", true, false) + + if !config.InsecureSkipVerify { + t.Fatalf("unexpected %s, expected %v, got %v", "InsecureSkipVerify", true, config.InsecureSkipVerify) } - if len(handshakeFunc.NextProto) != 1 || handshakeFunc.NextProto[0] != "h2" { - t.Fatalf("unexpected %s, expected %v, got %v", "NextProto", []string{"h2"}, handshakeFunc.NextProto) + if diff := cmp.Diff([]string{"h2"}, config.NextProtos); diff != "" { + t.Fatal(diff) } - if handshakeFunc.ServerName != "sni" { - t.Fatalf("unexpected %s, expected %s, got %s", "ServerName", "sni", handshakeFunc.ServerName) + if config.ServerName != "example.domain" { + t.Fatalf("unexpected %s, expected %s, got %s", "ServerName", "example.domain", config.ServerName) } - if !handshakeFunc.RootCAs.Equal(certpool) { - t.Fatalf("unexpected %s, expected %v, got %v", "RootCAs", certpool, handshakeFunc.RootCAs) + if !config.RootCAs.Equal(certpool) { + t.Fatalf("unexpected %s, expected %v, got %v", "RootCAs", nil, config.RootCAs) } }) +} +/* +Test cases: +- Apply tlsHandshakeFunc: + - with EOF + - with invalid address + - with success + - with sni + - with options +*/ +func TestTLSHandshake(t *testing.T) { t.Run("Apply tlsHandshakeFunc", func(t *testing.T) { wasClosed := false @@ -137,11 +152,10 @@ func TestTLSHandshake(t *testing.T) { return tt.handshaker }, })) - tlsHandshake := &tlsHandshakeFunc{ - NextProto: tt.config.nextProtos, - Rt: rt, - ServerName: tt.config.sni, - } + tlsHandshake := TLSHandshake(rt, + TLSHandshakeOptionNextProto(tt.config.nextProtos), + TLSHandshakeOptionServerName(tt.config.sni), + ) idGen := &atomic.Int64{} zeroTime := time.Time{} trace := rt.NewTrace(idGen.Add(1), zeroTime) @@ -174,62 +188,27 @@ func TestTLSHandshake(t *testing.T) { /* Test cases: -- With input SNI -- With input domain -- With input host address -- With input IP address +- With domain +- With host address +- With IP address */ -func TestServerNameTLS(t *testing.T) { - t.Run("With input SNI", func(t *testing.T) { - sni := "sni" - tcpConn := TCPConnection{ - Address: "example.com:123", - } - f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - ServerName: sni, - } - serverName := f.serverName(&tcpConn) - if serverName != sni { +func TestTLSServerName(t *testing.T) { + t.Run("With domain", func(t *testing.T) { + serverName := tlsServerName("example.com:123", "domain", model.DiscardLogger) + if serverName != "domain" { t.Fatalf("unexpected server name: %s", serverName) } }) - t.Run("With input domain", func(t *testing.T) { - domain := "domain" - tcpConn := TCPConnection{ - Address: "example.com:123", - Domain: domain, - } - f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - } - serverName := f.serverName(&tcpConn) - if serverName != domain { - t.Fatalf("unexpected server name: %s", serverName) - } - }) - t.Run("With input host address", func(t *testing.T) { - hostaddr := "example.com" - tcpConn := TCPConnection{ - Address: hostaddr + ":123", - } - f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - } - serverName := f.serverName(&tcpConn) - if serverName != hostaddr { + + t.Run("With host address", func(t *testing.T) { + serverName := tlsServerName("1.1.1.1:443", "", model.DiscardLogger) + if serverName != "1.1.1.1" { t.Fatalf("unexpected server name: %s", serverName) } }) - t.Run("With input IP address", func(t *testing.T) { - ip := "1.1.1.1" - tcpConn := TCPConnection{ - Address: ip, - } - f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - } - serverName := f.serverName(&tcpConn) + + t.Run("With IP address", func(t *testing.T) { + serverName := tlsServerName("1.1.1.1", "", model.DiscardLogger) if serverName != "" { t.Fatalf("unexpected server name: %s", serverName) } From e1f5bc9b9902db584da4f36173987bbc708c3061 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 18 Oct 2023 11:47:26 +0200 Subject: [PATCH 09/10] refactor(dslx): rewrite QUICHandshake using TLSHandshakeOption This diff takes advantage of the fact that now the TLSHandshakeOption are independent of the tlsHandshakeFunc structure, so we can use the same options for configuring the QUIC handshake. --- internal/dslx/quic.go | 85 +++++--------------------------------- internal/dslx/quic_test.go | 80 +---------------------------------- 2 files changed, 12 insertions(+), 153 deletions(-) diff --git a/internal/dslx/quic.go b/internal/dslx/quic.go index 3acf675ac9..91d0b8eebc 100644 --- a/internal/dslx/quic.go +++ b/internal/dslx/quic.go @@ -7,9 +7,7 @@ package dslx import ( "context" "crypto/tls" - "crypto/x509" "io" - "net" "time" "github.com/ooni/probe-cli/v3/internal/logx" @@ -17,61 +15,23 @@ import ( "github.com/quic-go/quic-go" ) -// QUICHandshakeOption is an option you can pass to QUICHandshake. -type QUICHandshakeOption func(*quicHandshakeFunc) - -// QUICHandshakeOptionInsecureSkipVerify controls whether QUIC verification is enabled. -func QUICHandshakeOptionInsecureSkipVerify(value bool) QUICHandshakeOption { - return func(thf *quicHandshakeFunc) { - thf.InsecureSkipVerify = value - } -} - -// QUICHandshakeOptionRootCAs allows to configure custom root CAs. -func QUICHandshakeOptionRootCAs(value *x509.CertPool) QUICHandshakeOption { - return func(thf *quicHandshakeFunc) { - thf.RootCAs = value - } -} - -// QUICHandshakeOptionServerName allows to configure the SNI to use. -func QUICHandshakeOptionServerName(value string) QUICHandshakeOption { - return func(thf *quicHandshakeFunc) { - thf.ServerName = value - } -} - // QUICHandshake returns a function performing QUIC handshakes. -func QUICHandshake(rt Runtime, options ...QUICHandshakeOption) Func[ +func QUICHandshake(rt Runtime, options ...TLSHandshakeOption) Func[ *Endpoint, *Maybe[*QUICConnection]] { - // See https://github.com/ooni/probe/issues/2413 to understand - // why we're using nil to force netxlite to use the cached - // default Mozilla cert pool. f := &quicHandshakeFunc{ - InsecureSkipVerify: false, - RootCAs: nil, - Rt: rt, - ServerName: "", - } - for _, option := range options { - option(f) + Options: options, + Rt: rt, } return f } // quicHandshakeFunc performs QUIC handshakes. type quicHandshakeFunc struct { - // InsecureSkipVerify allows to skip TLS verification. - InsecureSkipVerify bool - - // RootCAs contains the Root CAs to use. - RootCAs *x509.CertPool + // Options contains the options. + Options []TLSHandshakeOption - // Rt is the Runtime that owns us. + // Rt is the runtime that owns us. Rt Runtime - - // ServerName is the ServerName to handshake for. - ServerName string } // Apply implements Func. @@ -80,27 +40,22 @@ func (f *quicHandshakeFunc) Apply( // create trace trace := f.Rt.NewTrace(f.Rt.IDGenerator().Add(1), f.Rt.ZeroTime(), input.Tags...) - // use defaults or user-configured overrides - serverName := f.serverName(input) + // create a suitable TLS configuration + config := tlsNewConfig(input.Address, []string{"h3"}, input.Domain, f.Rt.Logger(), f.Options...) // start the operation logger ol := logx.NewOperationLogger( f.Rt.Logger(), - "[#%d] QUICHandshake with %s SNI=%s", + "[#%d] QUICHandshake with %s SNI=%s ALPN=%v", trace.Index(), input.Address, - serverName, + config.ServerName, + config.NextProtos, ) // setup udpListener := netxlite.NewUDPListener() quicDialer := trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger()) - config := &tls.Config{ - NextProtos: []string{"h3"}, - InsecureSkipVerify: f.InsecureSkipVerify, - RootCAs: f.RootCAs, - ServerName: serverName, - } const timeout = 10 * time.Second ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() @@ -139,24 +94,6 @@ func (f *quicHandshakeFunc) Apply( } } -func (f *quicHandshakeFunc) serverName(input *Endpoint) string { - if f.ServerName != "" { - return f.ServerName - } - if input.Domain != "" { - return input.Domain - } - addr, _, err := net.SplitHostPort(input.Address) - if err == nil { - return addr - } - // Note: golang requires a ServerName and fails if it's empty. If the provided - // ServerName is an IP address, however, golang WILL NOT emit any SNI extension - // in the ClientHello, consistently with RFC 6066 Section 3 requirements. - f.Rt.Logger().Warn("TLSHandshake: cannot determine which SNI to use") - return "" -} - // QUICConnection is an established QUIC connection. If you initialize // manually, init at least the ones marked as MANDATORY. type QUICConnection struct { diff --git a/internal/dslx/quic_test.go b/internal/dslx/quic_test.go index 2d34954bae..9512783a71 100644 --- a/internal/dslx/quic_test.go +++ b/internal/dslx/quic_test.go @@ -3,7 +3,6 @@ package dslx import ( "context" "crypto/tls" - "crypto/x509" "io" "testing" "time" @@ -16,28 +15,12 @@ import ( /* Test cases: -- Get quicHandshakeFunc with options - Apply quicHandshakeFunc: - with EOF - success - with sni */ func TestQUICHandshake(t *testing.T) { - t.Run("Get quicHandshakeFunc with options", func(t *testing.T) { - certpool := x509.NewCertPool() - certpool.AddCert(&x509.Certificate{}) - - f := QUICHandshake( - NewMinimalRuntime(model.DiscardLogger, time.Now()), - QUICHandshakeOptionInsecureSkipVerify(true), - QUICHandshakeOptionServerName("sni"), - QUICHandshakeOptionRootCAs(certpool), - ) - if _, ok := f.(*quicHandshakeFunc); !ok { - t.Fatal("unexpected type. Expected: quicHandshakeFunc") - } - }) - t.Run("Apply quicHandshakeFunc", func(t *testing.T) { wasClosed := false plainConn := &mocks.QUICEarlyConnection{ @@ -103,10 +86,7 @@ func TestQUICHandshake(t *testing.T) { return tt.dialer }, })) - quicHandshake := &quicHandshakeFunc{ - Rt: rt, - ServerName: tt.sni, - } + quicHandshake := QUICHandshake(rt, TLSHandshakeOptionServerName(tt.sni)) endpoint := &Endpoint{ Address: "1.2.3.4:567", Network: "udp", @@ -136,61 +116,3 @@ func TestQUICHandshake(t *testing.T) { } }) } - -/* -Test cases: -- With input SNI -- With input domain -- With input host address -- With input IP address -*/ -func TestServerNameQUIC(t *testing.T) { - t.Run("With input SNI", func(t *testing.T) { - sni := "sni" - endpoint := &Endpoint{ - Address: "example.com:123", - } - f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), ServerName: sni} - serverName := f.serverName(endpoint) - if serverName != sni { - t.Fatalf("unexpected server name: %s", serverName) - } - }) - - t.Run("With input domain", func(t *testing.T) { - domain := "domain" - endpoint := &Endpoint{ - Address: "example.com:123", - Domain: domain, - } - f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} - serverName := f.serverName(endpoint) - if serverName != domain { - t.Fatalf("unexpected server name: %s", serverName) - } - }) - - t.Run("With input host address", func(t *testing.T) { - hostaddr := "example.com" - endpoint := &Endpoint{ - Address: hostaddr + ":123", - } - f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} - serverName := f.serverName(endpoint) - if serverName != hostaddr { - t.Fatalf("unexpected server name: %s", serverName) - } - }) - - t.Run("With input IP address", func(t *testing.T) { - ip := "1.1.1.1" - endpoint := &Endpoint{ - Address: ip, - } - f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} - serverName := f.serverName(endpoint) - if serverName != "" { - t.Fatalf("unexpected server name: %s", serverName) - } - }) -} From 88715b4099aeccc3df065d76d059a9cee4ed1d2d Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 25 Oct 2023 09:22:12 +0200 Subject: [PATCH 10/10] x --- internal/dslx/quic.go | 2 +- internal/dslx/quic_test.go | 14 ++++++++++++++ internal/dslx/tls.go | 2 +- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/internal/dslx/quic.go b/internal/dslx/quic.go index 91d0b8eebc..327b6d4be6 100644 --- a/internal/dslx/quic.go +++ b/internal/dslx/quic.go @@ -30,7 +30,7 @@ type quicHandshakeFunc struct { // Options contains the options. Options []TLSHandshakeOption - // Rt is the runtime that owns us. + // Rt is the Runtime that owns us. Rt Runtime } diff --git a/internal/dslx/quic_test.go b/internal/dslx/quic_test.go index 9512783a71..d8a8b066ac 100644 --- a/internal/dslx/quic_test.go +++ b/internal/dslx/quic_test.go @@ -3,6 +3,7 @@ package dslx import ( "context" "crypto/tls" + "crypto/x509" "io" "testing" "time" @@ -15,12 +16,25 @@ import ( /* Test cases: +- Get quicHandshakeFunc with options - Apply quicHandshakeFunc: - with EOF - success - with sni */ func TestQUICHandshake(t *testing.T) { + t.Run("Get quicHandshakeFunc with options", func(t *testing.T) { + certpool := x509.NewCertPool() + certpool.AddCert(&x509.Certificate{}) + + f := QUICHandshake( + NewMinimalRuntime(model.DiscardLogger, time.Now()), + ) + if _, ok := f.(*quicHandshakeFunc); !ok { + t.Fatal("unexpected type. Expected: quicHandshakeFunc") + } + }) + t.Run("Apply quicHandshakeFunc", func(t *testing.T) { wasClosed := false plainConn := &mocks.QUICEarlyConnection{ diff --git a/internal/dslx/tls.go b/internal/dslx/tls.go index ee75086b4d..6ed1a63cdc 100644 --- a/internal/dslx/tls.go +++ b/internal/dslx/tls.go @@ -62,7 +62,7 @@ type tlsHandshakeFunc struct { // Options contains the options. Options []TLSHandshakeOption - // Rt is the runtime that owns us. + // Rt is the Runtime that owns us. Rt Runtime }