diff --git a/internal/experiment/openvpn/endpoint.go b/internal/experiment/openvpn/endpoint.go index fad2438d8..42f4a35a0 100644 --- a/internal/experiment/openvpn/endpoint.go +++ b/internal/experiment/openvpn/endpoint.go @@ -228,3 +228,13 @@ func extractBase64Blob(val string) (string, error) { } return string(dec), nil } + +func isValidProtocol(s string) bool { + if strings.HasPrefix(s, "openvpn://") { + return true + } + if strings.HasPrefix(s, "openvpn+obfs4://") { + return true + } + return false +} diff --git a/internal/experiment/openvpn/endpoint_test.go b/internal/experiment/openvpn/endpoint_test.go index d0195dbb6..3ff61b39d 100644 --- a/internal/experiment/openvpn/endpoint_test.go +++ b/internal/experiment/openvpn/endpoint_test.go @@ -380,3 +380,21 @@ func Test_extractBase64Blob(t *testing.T) { } }) } + +func Test_IsValidProtocol(t *testing.T) { + t.Run("openvpn is valid", func(t *testing.T) { + if !isValidProtocol("openvpn://foobar.bar") { + t.Error("openvpn:// should be a valid protocol") + } + }) + t.Run("openvpn+obfs4 is valid", func(t *testing.T) { + if !isValidProtocol("openvpn+obfs4://foobar.bar") { + t.Error("openvpn+obfs4:// should be a valid protocol") + } + }) + t.Run("openvpn+other is not valid", func(t *testing.T) { + if isValidProtocol("openvpn+ss://foobar.bar") { + t.Error("openvpn+ss:// should not be a valid protocol") + } + }) +} diff --git a/internal/experiment/openvpn/openvpn.go b/internal/experiment/openvpn/openvpn.go index c302ed7ce..b97cc3f4a 100644 --- a/internal/experiment/openvpn/openvpn.go +++ b/internal/experiment/openvpn/openvpn.go @@ -4,8 +4,8 @@ package openvpn import ( "context" "errors" + "fmt" "strconv" - "strings" "time" "github.com/ooni/probe-cli/v3/internal/measurexlite" @@ -21,6 +21,10 @@ const ( openVPNProcol = "openvpn" ) +var ( + ErrBadAuth = errors.New("bad provider authentication") +) + // Config contains the experiment config. // // This contains all the settings that user can set to modify the behaviour @@ -107,14 +111,124 @@ var ( ErrInvalidInput = errors.New("invalid input") ) -func isValidProtocol(s string) bool { - if strings.HasPrefix(s, "openvpn://") { - return true +func parseEndpoint(m *model.Measurement) (*endpoint, error) { + if m.Input != "" { + if ok := isValidProtocol(string(m.Input)); !ok { + return nil, ErrInvalidInput + } + return newEndpointFromInputString(string(m.Input)) + } + // The current InputPolicy should ensure we have a hardcoded input, + // so this error should only be raised if by mistake we change the InputPolicy. + return nil, fmt.Errorf("%w: %s", ErrInvalidInput, "input is mandatory") +} + +// AuthMethod is the authentication method used by a provider. +type AuthMethod string + +var ( + AuthCertificate = AuthMethod("cert") + AuthUserPass = AuthMethod("userpass") +) + +var providerAuthentication = map[string]AuthMethod{ + "riseup": AuthCertificate, + "tunnelbear": AuthUserPass, + "surfshark": AuthUserPass, +} + +func hasCredentialsInOptions(cfg Config, method AuthMethod) bool { + switch method { + case AuthCertificate: + ok := cfg.SafeCA != "" && cfg.SafeCert != "" && cfg.SafeKey != "" + return ok + default: + return false } - if strings.HasPrefix(s, "openvpn+obfs4://") { - return true +} + +// MaybeGetCredentialsFromOptions overrides authentication info with what user provided in options. +// Each certificate/key can be encoded in base64 so that a single option can be safely represented as command line options. +// This function returns no error if there are no credentials in the passed options, only if failing to parse them. +func MaybeGetCredentialsFromOptions(cfg Config, opts *vpnconfig.OpenVPNOptions, method AuthMethod) (bool, error) { + if ok := hasCredentialsInOptions(cfg, method); !ok { + return false, nil + } + ca, err := extractBase64Blob(cfg.SafeCA) + if err != nil { + return false, err + } + opts.CA = []byte(ca) + + key, err := extractBase64Blob(cfg.SafeKey) + if err != nil { + return false, err + } + opts.Key = []byte(key) + + cert, err := extractBase64Blob(cfg.SafeCert) + if err != nil { + return false, err + } + opts.Cert = []byte(cert) + return true, nil +} + +func (m *Measurer) getCredentialsFromAPI( + ctx context.Context, + sess model.ExperimentSession, + provider string, + opts *vpnconfig.OpenVPNOptions) error { + // We expect the credentials from the API response to be encoded as the direct PEM serialization. + apiCreds, err := m.FetchProviderCredentials(ctx, sess, provider) + // TODO(ainghazal): validate credentials have the info we expect, certs are not expired etc. + if err != nil { + sess.Logger().Warnf("Error fetching credentials from API: %s", err.Error()) + return err } - return false + sess.Logger().Infof("Got credentials from provider: %s", provider) + + opts.CA = []byte(apiCreds.Config.CA) + opts.Cert = []byte(apiCreds.Config.Cert) + opts.Key = []byte(apiCreds.Config.Key) + return nil +} + +// GetCredentialsFromOptionsOrAPI attempts to find valid credentials for the given provider, either +// from the passed Options (cli, oonirun), or from a remote call to the OONI API endpoint. +func (m *Measurer) GetCredentialsFromOptionsOrAPI( + ctx context.Context, + sess model.ExperimentSession, + provider string) (*vpnconfig.OpenVPNOptions, error) { + + method, ok := providerAuthentication[provider] + if !ok { + return nil, fmt.Errorf("%w: provider auth unknown: %s", ErrInvalidInput, provider) + } + + // Empty options object to fill with credentials. + creds := &vpnconfig.OpenVPNOptions{} + + switch method { + case AuthCertificate: + ok, err := MaybeGetCredentialsFromOptions(m.config, creds, method) + if err != nil { + return nil, err + } + if ok { + return creds, nil + } + // No options passed, so let's get the credentials that inputbuilder should have cached + // for us after hitting the OONI API. + if err := m.getCredentialsFromAPI(ctx, sess, provider, creds); err != nil { + return nil, err + } + return creds, nil + + default: + return nil, fmt.Errorf("%w: method not implemented (%s)", ErrInvalidInput, method) + } + } // Run implements model.ExperimentMeasurer.Run. @@ -125,25 +239,9 @@ func (m Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { measurement := args.Measurement sess := args.Session - var endpoint *endpoint - - if measurement.Input != "" { - if ok := isValidProtocol(string(measurement.Input)); !ok { - return ErrInvalidInput - } - var err error - endpoint, err = newEndpointFromInputString(string(measurement.Input)) - if err != nil { - return err - } - } else { - // if input is null, we get one from the hardcoded list of inputs. - // TODO(ainghazal): can input be empty at this stage? - // InputPolicy should ensure we have a hardcoded input, - // so this is probably non-reachable code. Move the shuffling there. - sess.Logger().Info("No input given, picking one hardcoded endpoint at random") - endpoint = DefaultEndpoints.Shuffle()[0] - measurement.Input = model.MeasurementTarget(endpoint.AsInputURI()) + endpoint, err := parseEndpoint(measurement) + if err != nil { + return err } tk := NewTestKeys() @@ -174,65 +272,6 @@ func (m Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { return nil } -// getCredentialsFromOptionsOrAPI attempts to find valid credentials for the given provider, either -// from the passed Options (cli, oonirun), or from a remote call to the OONI API endpoint. -func (m *Measurer) getCredentialsFromOptionsOrAPI( - ctx context.Context, - sess model.ExperimentSession, - provider string) (*vpnconfig.OpenVPNOptions, error) { - - // TODO(ainghazal): Ideally, we need to know which authentication methods each provider uses, and this is - // information that the experiment could hardcode. Sticking to Certificate-based auth for riseupvpn. - - // get an empty options object to fill with credentials - creds := &vpnconfig.OpenVPNOptions{} - - cfg := m.config - - if cfg.SafeCA != "" && cfg.SafeCert != "" && cfg.SafeKey != "" { - // We override authentication info with what user provided in options. - // We expect the options to be encoded in base64 so that a single optin can be safely represented as command line options. - ca, err := extractBase64Blob(cfg.SafeCA) - if err != nil { - return nil, err - } - creds.CA = []byte(ca) - - key, err := extractBase64Blob(cfg.SafeKey) - if err != nil { - return nil, err - } - creds.Key = []byte(key) - - cert, err := extractBase64Blob(cfg.SafeCert) - if err != nil { - return nil, err - } - creds.Key = []byte(cert) - - // return options-based credentials - return creds, nil - } - - // No options passed, so let's get the credentials that inputbuilder should have cached - // for us after hitting the OONI API. - // We expect the credentials from the API response to be encoded as the direct PEM serialization. - apiCreds, err := m.fetchProviderCredentials(ctx, sess, provider) - // TODO(ainghazal): validate credentials have the info we expect, certs are not expired etc. - - if err != nil { - sess.Logger().Warnf("Error fetching credentials from API: %s", err.Error()) - return nil, err - } - sess.Logger().Infof("Got credentials from provider: %s", provider) - - creds.CA = []byte(apiCreds.Config.CA) - creds.Cert = []byte(apiCreds.Config.Cert) - creds.Key = []byte(apiCreds.Config.Key) - - return creds, nil -} - // connectAndHandshake dials a connection and attempts an OpenVPN handshake using that dialer. func (m *Measurer) connectAndHandshake(ctx context.Context, index int64, zeroTime time.Time, sess model.ExperimentSession, endpoint *endpoint) (*SingleConnection, error) { @@ -241,13 +280,13 @@ func (m *Measurer) connectAndHandshake(ctx context.Context, index int64, zeroTim // create a trace for the network dialer trace := measurexlite.NewTrace(index, zeroTime) - // TODO(ainghazal): can I pass tags to this tracer? dialer := trace.NewDialerWithoutResolver(logger) // create a vpn tun Device that attempts to dial and performs the handshake handshakeTracer := vpntracex.NewTracerWithTransactionID(zeroTime, index) - credentials, err := m.getCredentialsFromOptionsOrAPI(ctx, sess, endpoint.Provider) + // TODO -- move to outer function ------ + credentials, err := m.GetCredentialsFromOptionsOrAPI(ctx, sess, endpoint.Provider) if err != nil { return nil, err } @@ -256,6 +295,7 @@ func (m *Measurer) connectAndHandshake(ctx context.Context, index int64, zeroTim if err != nil { return nil, err } + // TODO -- move to outer function ------ tun, err := tunnel.Start(ctx, dialer, openvpnConfig) @@ -308,7 +348,7 @@ func (m *Measurer) connectAndHandshake(ctx context.Context, index int64, zeroTim } // TODO: get cached from session instead of fetching every time -func (m *Measurer) fetchProviderCredentials( +func (m *Measurer) FetchProviderCredentials( ctx context.Context, sess model.ExperimentSession, provider string) (*model.OOAPIVPNProviderConfig, error) { diff --git a/internal/experiment/openvpn/openvpn_test.go b/internal/experiment/openvpn/openvpn_test.go index 7e7e05229..aa984f6bc 100644 --- a/internal/experiment/openvpn/openvpn_test.go +++ b/internal/experiment/openvpn/openvpn_test.go @@ -2,16 +2,17 @@ package openvpn_test import ( "context" + "errors" "fmt" "testing" "time" "github.com/google/go-cmp/cmp" + vpnconfig "github.com/ooni/minivpn/pkg/config" + vpntracex "github.com/ooni/minivpn/pkg/tracex" "github.com/ooni/probe-cli/v3/internal/experiment/openvpn" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" - - vpntracex "github.com/ooni/minivpn/pkg/tracex" ) func makeMockSession() *mocks.Session { @@ -66,7 +67,200 @@ func TestNewTestKeys(t *testing.T) { } } -// TODO refactoring tests ----------------------------------------------- +func TestMaybeGetCredentialsFromOptions(t *testing.T) { + t.Run("cert auth returns false if cert, key and ca are not all provided", func(t *testing.T) { + cfg := openvpn.Config{ + SafeCA: "base64:Zm9v", + SafeCert: "base64:Zm9v", + } + ok, err := openvpn.MaybeGetCredentialsFromOptions(cfg, &vpnconfig.OpenVPNOptions{}, openvpn.AuthCertificate) + if err != nil { + t.Fatal("should not raise error") + } + if ok { + t.Fatal("expected false") + } + }) + t.Run("cert auth returns ok if cert, key and ca are all provided", func(t *testing.T) { + cfg := openvpn.Config{ + SafeCA: "base64:Zm9v", + SafeCert: "base64:Zm9v", + SafeKey: "base64:Zm9v", + } + opts := &vpnconfig.OpenVPNOptions{} + ok, err := openvpn.MaybeGetCredentialsFromOptions(cfg, opts, openvpn.AuthCertificate) + if err != nil { + t.Fatalf("expected err=nil, got %v", err) + } + if !ok { + t.Fatal("expected true") + } + if diff := cmp.Diff(opts.CA, []byte("foo")); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(opts.Cert, []byte("foo")); diff != "" { + t.Fatal(diff) + } + if diff := cmp.Diff(opts.Key, []byte("foo")); diff != "" { + t.Fatal(diff) + } + }) + t.Run("cert auth returns false and error if CA base64 is bad blob", func(t *testing.T) { + cfg := openvpn.Config{ + SafeCA: "base64:Zm9vaaa", + SafeCert: "base64:Zm9v", + SafeKey: "base64:Zm9v", + } + opts := &vpnconfig.OpenVPNOptions{} + ok, err := openvpn.MaybeGetCredentialsFromOptions(cfg, opts, openvpn.AuthCertificate) + if ok { + t.Fatal("expected false") + } + if !errors.Is(err, openvpn.ErrBadBase64Blob) { + t.Fatalf("expected err=ErrBase64Blob, got %v", err) + } + }) + t.Run("cert auth returns false and error if key base64 is bad blob", func(t *testing.T) { + cfg := openvpn.Config{ + SafeCA: "base64:Zm9v", + SafeCert: "base64:Zm9v", + SafeKey: "base64:Zm9vaaa", + } + opts := &vpnconfig.OpenVPNOptions{} + ok, err := openvpn.MaybeGetCredentialsFromOptions(cfg, opts, openvpn.AuthCertificate) + if ok { + t.Fatal("expected false") + } + if !errors.Is(err, openvpn.ErrBadBase64Blob) { + t.Fatalf("expected err=ErrBase64Blob, got %v", err) + } + }) + t.Run("cert auth returns false and error if cert base64 is bad blob", func(t *testing.T) { + cfg := openvpn.Config{ + SafeCA: "base64:Zm9v", + SafeCert: "base64:Zm9vaaa", + SafeKey: "base64:Zm9v", + } + opts := &vpnconfig.OpenVPNOptions{} + ok, err := openvpn.MaybeGetCredentialsFromOptions(cfg, opts, openvpn.AuthCertificate) + if ok { + t.Fatal("expected false") + } + if !errors.Is(err, openvpn.ErrBadBase64Blob) { + t.Fatalf("expected err=ErrBase64Blob, got %v", err) + } + }) + t.Run("userpass auth returns error, not yet implemented", func(t *testing.T) { + cfg := openvpn.Config{} + ok, err := openvpn.MaybeGetCredentialsFromOptions(cfg, &vpnconfig.OpenVPNOptions{}, openvpn.AuthUserPass) + if ok { + t.Fatal("expected false") + } + if err != nil { + t.Fatalf("expected err=nil, got %v", err) + } + }) + +} + +func TestGetCredentialsFromOptionsOrAPI(t *testing.T) { + t.Run("non-registered provider raises error", func(t *testing.T) { + m := openvpn.NewExperimentMeasurer(openvpn.Config{}, "openvpn").(openvpn.Measurer) + ctx := context.Background() + sess := makeMockSession() + opts, err := m.GetCredentialsFromOptionsOrAPI(ctx, sess, "nsa") + if !errors.Is(err, openvpn.ErrInvalidInput) { + t.Fatalf("expected err=ErrInvalidInput, got %v", err) + } + if opts != nil { + t.Fatal("expected opts=nil") + } + }) + t.Run("providers with userpass auth method raise error, not yet implemented", func(t *testing.T) { + m := openvpn.NewExperimentMeasurer(openvpn.Config{}, "openvpn").(openvpn.Measurer) + ctx := context.Background() + sess := makeMockSession() + opts, err := m.GetCredentialsFromOptionsOrAPI(ctx, sess, "tunnelbear") + if !errors.Is(err, openvpn.ErrInvalidInput) { + t.Fatalf("expected err=ErrInvalidInput, got %v", err) + } + if opts != nil { + t.Fatal("expected opts=nil") + } + }) + t.Run("known cert auth provider and creds in options is ok", func(t *testing.T) { + config := openvpn.Config{ + SafeCA: "base64:Zm9v", + SafeCert: "base64:Zm9v", + SafeKey: "base64:Zm9v", + } + m := openvpn.NewExperimentMeasurer(config, "openvpn").(openvpn.Measurer) + ctx := context.Background() + sess := makeMockSession() + opts, err := m.GetCredentialsFromOptionsOrAPI(ctx, sess, "riseup") + if err != nil { + t.Fatalf("expected err=nil, got %v", err) + } + if opts == nil { + t.Fatal("expected non-nil options") + } + }) + t.Run("known cert auth provider and bad creds in options returns error", func(t *testing.T) { + config := openvpn.Config{ + SafeCA: "base64:Zm9v", + SafeCert: "base64:Zm9v", + SafeKey: "base64:Zm9vaaa", + } + m := openvpn.NewExperimentMeasurer(config, "openvpn").(openvpn.Measurer) + ctx := context.Background() + sess := makeMockSession() + opts, err := m.GetCredentialsFromOptionsOrAPI(ctx, sess, "riseup") + if !errors.Is(err, openvpn.ErrBadBase64Blob) { + t.Fatalf("expected err=ErrBadBase64, got %v", err) + } + if opts != nil { + t.Fatal("expected nil opts") + } + }) + t.Run("known cert auth provider with null options hits the api", func(t *testing.T) { + config := openvpn.Config{} + m := openvpn.NewExperimentMeasurer(config, "openvpn").(openvpn.Measurer) + ctx := context.Background() + sess := makeMockSession() + opts, err := m.GetCredentialsFromOptionsOrAPI(ctx, sess, "riseup") + if err != nil { + t.Fatalf("expected err=nil, got %v", err) + } + if opts == nil { + t.Fatalf("expected not-nil options, got %v", opts) + } + }) + t.Run("known cert auth provider with null options hits the api and raises error if api fails", func(t *testing.T) { + config := openvpn.Config{} + m := openvpn.NewExperimentMeasurer(config, "openvpn").(openvpn.Measurer) + ctx := context.Background() + + someError := errors.New("some error") + sess := makeMockSession() + sess.MockFetchOpenVPNConfig = func(context.Context, string, string) (*model.OOAPIVPNProviderConfig, error) { + return nil, someError + } + + opts, err := m.GetCredentialsFromOptionsOrAPI(ctx, sess, "riseup") + if !errors.Is(err, someError) { + t.Fatalf("expected err=someError, got %v", err) + } + if opts != nil { + t.Fatalf("expected nil options, got %v", opts) + } + }) + /* + sess.MockFetchOpenVPNConfig = func(context.Context, string, string) (*model.OOAPIVPNProviderConfig, error) { + return nil, someError + } + */ + +} func TestAddConnectionTestKeys(t *testing.T) { t.Run("append connection result to empty keys", func(t *testing.T) { @@ -150,11 +344,29 @@ func TestAllConnectionsSuccessful(t *testing.T) { }) } +func TestBadInputFailure(t *testing.T) { + m := openvpn.NewExperimentMeasurer(openvpn.Config{}, "openvpn") + ctx := context.Background() + sess := makeMockSession() + callbacks := model.NewPrinterCallbacks(sess.Logger()) + measurement := new(model.Measurement) + measurement.Input = "openvpn://badprovider/?address=aa" + args := &model.ExperimentArgs{ + Callbacks: callbacks, + Measurement: measurement, + Session: sess, + } + err := m.Run(ctx, args) + if !errors.Is(err, openvpn.ErrInvalidInput) { + t.Fatalf("expected ErrInvalidInput, got %v", err) + } +} + func TestVPNInput(t *testing.T) { if testing.Short() { t.Skip("skip test in short mode") } - // TODO -- do a real test + // TODO -- do a real test, get credentials etc. } func TestSuccess(t *testing.T) { @@ -176,22 +388,37 @@ func TestSuccess(t *testing.T) { //} } -// TODO -- test incorrect certs failure. -func TestBadInputFailure(t *testing.T) { - m := openvpn.NewExperimentMeasurer(openvpn.Config{}, "openvpn") - ctx := context.Background() - sess := makeMockSession() - callbacks := model.NewPrinterCallbacks(sess.Logger()) - args := &model.ExperimentArgs{ - Callbacks: callbacks, - Measurement: new(model.Measurement), - Session: sess, - } - fmt.Println(m, ctx, args) - /* - err := m.Run(ctx, args) - if !errors.Is(err, example.ErrFailure) { - t.Fatal("expected an error here") +func TestMeasurer_FetchProviderCredentials(t *testing.T) { + t.Run("Measurer.FetchProviderCredentials calls method in session", func(t *testing.T) { + m := openvpn.NewExperimentMeasurer( + openvpn.Config{}, + "openvpn").(openvpn.Measurer) + + sess := makeMockSession() + _, err := m.FetchProviderCredentials( + context.Background(), + sess, "riseup") + if err != nil { + t.Fatal("expected no error") } - */ + }) + t.Run("Measurer.FetchProviderCredentials raises error if API calls fail", func(t *testing.T) { + someError := errors.New("unexpected") + + m := openvpn.NewExperimentMeasurer( + openvpn.Config{}, + "openvpn").(openvpn.Measurer) + + sess := makeMockSession() + sess.MockFetchOpenVPNConfig = func(context.Context, string, string) (*model.OOAPIVPNProviderConfig, error) { + return nil, someError + } + _, err := m.FetchProviderCredentials( + context.Background(), + sess, "riseup") + if !errors.Is(err, someError) { + t.Fatalf("expected error %v, got %v", someError, err) + } + }) + }