Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ainghazal committed Apr 2, 2024
1 parent 19b3fa9 commit 2fc7f09
Show file tree
Hide file tree
Showing 4 changed files with 404 additions and 109 deletions.
10 changes: 10 additions & 0 deletions internal/experiment/openvpn/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,13 @@ func extractBase64Blob(val string) (string, error) {
}
return string(dec), nil
}

func isValidProtocol(s string) bool {
if strings.HasPrefix(s, "openvpn://") {
return true
}
if strings.HasPrefix(s, "openvpn+obfs4://") {
return true
}
return false
}
18 changes: 18 additions & 0 deletions internal/experiment/openvpn/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,3 +380,21 @@ func Test_extractBase64Blob(t *testing.T) {
}
})
}

func Test_IsValidProtocol(t *testing.T) {
t.Run("openvpn is valid", func(t *testing.T) {
if !isValidProtocol("openvpn://foobar.bar") {
t.Error("openvpn:// should be a valid protocol")
}
})
t.Run("openvpn+obfs4 is valid", func(t *testing.T) {
if !isValidProtocol("openvpn+obfs4://foobar.bar") {
t.Error("openvpn+obfs4:// should be a valid protocol")
}
})
t.Run("openvpn+other is not valid", func(t *testing.T) {
if isValidProtocol("openvpn+ss://foobar.bar") {
t.Error("openvpn+ss:// should not be a valid protocol")
}
})
}
216 changes: 128 additions & 88 deletions internal/experiment/openvpn/openvpn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ package openvpn
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"

"github.com/ooni/probe-cli/v3/internal/measurexlite"
Expand All @@ -21,6 +21,10 @@ const (
openVPNProcol = "openvpn"
)

var (
ErrBadAuth = errors.New("bad provider authentication")
)

// Config contains the experiment config.
//
// This contains all the settings that user can set to modify the behaviour
Expand Down Expand Up @@ -107,14 +111,124 @@ var (
ErrInvalidInput = errors.New("invalid input")
)

func isValidProtocol(s string) bool {
if strings.HasPrefix(s, "openvpn://") {
return true
func parseEndpoint(m *model.Measurement) (*endpoint, error) {
if m.Input != "" {
if ok := isValidProtocol(string(m.Input)); !ok {
return nil, ErrInvalidInput
}
return newEndpointFromInputString(string(m.Input))
}
// The current InputPolicy should ensure we have a hardcoded input,
// so this error should only be raised if by mistake we change the InputPolicy.
return nil, fmt.Errorf("%w: %s", ErrInvalidInput, "input is mandatory")
}

// AuthMethod is the authentication method used by a provider.
type AuthMethod string

var (
AuthCertificate = AuthMethod("cert")
AuthUserPass = AuthMethod("userpass")
)

var providerAuthentication = map[string]AuthMethod{
"riseup": AuthCertificate,
"tunnelbear": AuthUserPass,
"surfshark": AuthUserPass,
}

func hasCredentialsInOptions(cfg Config, method AuthMethod) bool {
switch method {
case AuthCertificate:
ok := cfg.SafeCA != "" && cfg.SafeCert != "" && cfg.SafeKey != ""
return ok
default:
return false
}
if strings.HasPrefix(s, "openvpn+obfs4://") {
return true
}

// MaybeGetCredentialsFromOptions overrides authentication info with what user provided in options.
// Each certificate/key can be encoded in base64 so that a single option can be safely represented as command line options.
// This function returns no error if there are no credentials in the passed options, only if failing to parse them.
func MaybeGetCredentialsFromOptions(cfg Config, opts *vpnconfig.OpenVPNOptions, method AuthMethod) (bool, error) {
if ok := hasCredentialsInOptions(cfg, method); !ok {
return false, nil
}
ca, err := extractBase64Blob(cfg.SafeCA)
if err != nil {
return false, err
}
opts.CA = []byte(ca)

key, err := extractBase64Blob(cfg.SafeKey)
if err != nil {
return false, err
}
opts.Key = []byte(key)

cert, err := extractBase64Blob(cfg.SafeCert)
if err != nil {
return false, err
}
opts.Cert = []byte(cert)
return true, nil
}

func (m *Measurer) getCredentialsFromAPI(
ctx context.Context,
sess model.ExperimentSession,
provider string,
opts *vpnconfig.OpenVPNOptions) error {
// We expect the credentials from the API response to be encoded as the direct PEM serialization.
apiCreds, err := m.FetchProviderCredentials(ctx, sess, provider)
// TODO(ainghazal): validate credentials have the info we expect, certs are not expired etc.
if err != nil {
sess.Logger().Warnf("Error fetching credentials from API: %s", err.Error())
return err
}
return false
sess.Logger().Infof("Got credentials from provider: %s", provider)

opts.CA = []byte(apiCreds.Config.CA)
opts.Cert = []byte(apiCreds.Config.Cert)
opts.Key = []byte(apiCreds.Config.Key)
return nil
}

// GetCredentialsFromOptionsOrAPI attempts to find valid credentials for the given provider, either
// from the passed Options (cli, oonirun), or from a remote call to the OONI API endpoint.
func (m *Measurer) GetCredentialsFromOptionsOrAPI(
ctx context.Context,
sess model.ExperimentSession,
provider string) (*vpnconfig.OpenVPNOptions, error) {

method, ok := providerAuthentication[provider]
if !ok {
return nil, fmt.Errorf("%w: provider auth unknown: %s", ErrInvalidInput, provider)
}

// Empty options object to fill with credentials.
creds := &vpnconfig.OpenVPNOptions{}

switch method {
case AuthCertificate:
ok, err := MaybeGetCredentialsFromOptions(m.config, creds, method)
if err != nil {
return nil, err
}
if ok {
return creds, nil
}
// No options passed, so let's get the credentials that inputbuilder should have cached
// for us after hitting the OONI API.
if err := m.getCredentialsFromAPI(ctx, sess, provider, creds); err != nil {
return nil, err
}
return creds, nil

default:
return nil, fmt.Errorf("%w: method not implemented (%s)", ErrInvalidInput, method)
}

}

// Run implements model.ExperimentMeasurer.Run.
Expand All @@ -125,25 +239,9 @@ func (m Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error {
measurement := args.Measurement
sess := args.Session

var endpoint *endpoint

if measurement.Input != "" {
if ok := isValidProtocol(string(measurement.Input)); !ok {
return ErrInvalidInput
}
var err error
endpoint, err = newEndpointFromInputString(string(measurement.Input))
if err != nil {
return err
}
} else {
// if input is null, we get one from the hardcoded list of inputs.
// TODO(ainghazal): can input be empty at this stage?
// InputPolicy should ensure we have a hardcoded input,
// so this is probably non-reachable code. Move the shuffling there.
sess.Logger().Info("No input given, picking one hardcoded endpoint at random")
endpoint = DefaultEndpoints.Shuffle()[0]
measurement.Input = model.MeasurementTarget(endpoint.AsInputURI())
endpoint, err := parseEndpoint(measurement)
if err != nil {
return err
}

tk := NewTestKeys()
Expand Down Expand Up @@ -174,65 +272,6 @@ func (m Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error {
return nil
}

// getCredentialsFromOptionsOrAPI attempts to find valid credentials for the given provider, either
// from the passed Options (cli, oonirun), or from a remote call to the OONI API endpoint.
func (m *Measurer) getCredentialsFromOptionsOrAPI(
ctx context.Context,
sess model.ExperimentSession,
provider string) (*vpnconfig.OpenVPNOptions, error) {

// TODO(ainghazal): Ideally, we need to know which authentication methods each provider uses, and this is
// information that the experiment could hardcode. Sticking to Certificate-based auth for riseupvpn.

// get an empty options object to fill with credentials
creds := &vpnconfig.OpenVPNOptions{}

cfg := m.config

if cfg.SafeCA != "" && cfg.SafeCert != "" && cfg.SafeKey != "" {
// We override authentication info with what user provided in options.
// We expect the options to be encoded in base64 so that a single optin can be safely represented as command line options.
ca, err := extractBase64Blob(cfg.SafeCA)
if err != nil {
return nil, err
}
creds.CA = []byte(ca)

key, err := extractBase64Blob(cfg.SafeKey)
if err != nil {
return nil, err
}
creds.Key = []byte(key)

cert, err := extractBase64Blob(cfg.SafeCert)
if err != nil {
return nil, err
}
creds.Key = []byte(cert)

// return options-based credentials
return creds, nil
}

// No options passed, so let's get the credentials that inputbuilder should have cached
// for us after hitting the OONI API.
// We expect the credentials from the API response to be encoded as the direct PEM serialization.
apiCreds, err := m.fetchProviderCredentials(ctx, sess, provider)
// TODO(ainghazal): validate credentials have the info we expect, certs are not expired etc.

if err != nil {
sess.Logger().Warnf("Error fetching credentials from API: %s", err.Error())
return nil, err
}
sess.Logger().Infof("Got credentials from provider: %s", provider)

creds.CA = []byte(apiCreds.Config.CA)
creds.Cert = []byte(apiCreds.Config.Cert)
creds.Key = []byte(apiCreds.Config.Key)

return creds, nil
}

// connectAndHandshake dials a connection and attempts an OpenVPN handshake using that dialer.
func (m *Measurer) connectAndHandshake(ctx context.Context, index int64, zeroTime time.Time, sess model.ExperimentSession, endpoint *endpoint) (*SingleConnection, error) {

Expand All @@ -241,13 +280,13 @@ func (m *Measurer) connectAndHandshake(ctx context.Context, index int64, zeroTim
// create a trace for the network dialer
trace := measurexlite.NewTrace(index, zeroTime)

// TODO(ainghazal): can I pass tags to this tracer?
dialer := trace.NewDialerWithoutResolver(logger)

// create a vpn tun Device that attempts to dial and performs the handshake
handshakeTracer := vpntracex.NewTracerWithTransactionID(zeroTime, index)

credentials, err := m.getCredentialsFromOptionsOrAPI(ctx, sess, endpoint.Provider)
// TODO -- move to outer function ------
credentials, err := m.GetCredentialsFromOptionsOrAPI(ctx, sess, endpoint.Provider)
if err != nil {
return nil, err
}
Expand All @@ -256,6 +295,7 @@ func (m *Measurer) connectAndHandshake(ctx context.Context, index int64, zeroTim
if err != nil {
return nil, err
}
// TODO -- move to outer function ------

tun, err := tunnel.Start(ctx, dialer, openvpnConfig)

Expand Down Expand Up @@ -308,7 +348,7 @@ func (m *Measurer) connectAndHandshake(ctx context.Context, index int64, zeroTim
}

// TODO: get cached from session instead of fetching every time
func (m *Measurer) fetchProviderCredentials(
func (m *Measurer) FetchProviderCredentials(
ctx context.Context,
sess model.ExperimentSession,
provider string) (*model.OOAPIVPNProviderConfig, error) {
Expand Down
Loading

0 comments on commit 2fc7f09

Please sign in to comment.