Skip to content

Commit

Permalink
provider tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ainghazal committed Apr 1, 2024
1 parent 0708d50 commit 19b3fa9
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 38 deletions.
24 changes: 10 additions & 14 deletions internal/experiment/openvpn/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ func newEndpointFromInputString(uri string) (*endpoint, error) {
}

address := params.Get("address")
if provider == "" {
return nil, fmt.Errorf("%w: please specify a provider as part of the input", ErrInvalidInput)
if address == "" {
return nil, fmt.Errorf("%w: please specify an address as part of the input", ErrInvalidInput)
}
ip, port, err := net.SplitHostPort(address)
if err != nil {
Expand All @@ -102,7 +102,8 @@ func newEndpointFromInputString(uri string) (*endpoint, error) {
}

// String implements Stringer. This is a compact representation of the endpoint,
// which differs from the input URI scheme.
// which differs from the input URI scheme. This is the canonical representation, that can be used
// to deterministically slice a list of endpoints, sort them lexicographically, etc.
func (e *endpoint) String() string {
var proto string
if e.Obfuscation == "obfs4" {
Expand Down Expand Up @@ -179,19 +180,16 @@ func isValidProvider(provider string) bool {
return ok
}

// getVPNConfig gets a properly configured [*vpnconfig.Config] object for the given endpoint.
// getOpenVPNConfig gets a properly configured [*vpnconfig.Config] object for the given endpoint.
// To obtain that, we merge the endpoint specific configuration with base options.
// These base options are for the moment hardcoded. In the future we will want to be smarter
// about getting information for different providers.
func getVPNConfig(tracer *vpntracex.Tracer, endpoint *endpoint, creds *vpnconfig.OpenVPNOptions) (*vpnconfig.Config, error) {

// Base options are hardcoded for the moment, for comparability among different providers.
// We can add them to the OONI API and as extra cli options if ever needed.
func getOpenVPNConfig(tracer *vpntracex.Tracer, endpoint *endpoint, creds *vpnconfig.OpenVPNOptions) (*vpnconfig.Config, error) {
// TODO(ainghazal): use merge ability in vpnconfig.OpenVPNOptions merge (pending PR)

provider := endpoint.Provider
if !isValidProvider(provider) {
return nil, fmt.Errorf("%w: unknown provider: %s", ErrInvalidInput, provider)
}

baseOptions := defaultOptionsByProvider[provider]

cfg := vpnconfig.NewConfig(
Expand All @@ -214,10 +212,11 @@ func getVPNConfig(tracer *vpntracex.Tracer, endpoint *endpoint, creds *vpnconfig
),
vpnconfig.WithHandshakeTracer(tracer))

// TODO: validate options here and return an error.
// TODO: sanity check (Remote, Port, Proto etc + missing certs)
return cfg, nil
}

// extractBase64Blob is used to pass credentials as command-line options.
func extractBase64Blob(val string) (string, error) {
s := strings.TrimPrefix(val, "base64:")
if len(s) == len(val) {
Expand All @@ -227,8 +226,5 @@ func extractBase64Blob(val string) (string, error) {
if err != nil {
return "", fmt.Errorf("%w: %s", ErrBadBase64Blob, err)
}
if len(dec) == 0 {
return "", nil
}
return string(dec), nil
}
241 changes: 240 additions & 1 deletion internal/experiment/openvpn/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@ package openvpn

import (
"errors"
"sort"
"testing"
"time"

"github.com/google/go-cmp/cmp"
vpnconfig "github.com/ooni/minivpn/pkg/config"
vpntracex "github.com/ooni/minivpn/pkg/tracex"
)

func Test_newEndpointFromInputString(t *testing.T) {
Expand All @@ -30,6 +34,25 @@ func Test_newEndpointFromInputString(t *testing.T) {
},
wantErr: nil,
},
{
name: "bad url fails",
args: args{"://address=1.1.1.1:1194&transport=tcp"},
want: nil,
wantErr: ErrInvalidInput,
},
{
name: "openvpn+obfs4 does not fail",
args: args{"openvpn+obfs4://riseup.corp/?address=1.1.1.1:1194&transport=tcp"},
want: &endpoint{
IPAddr: "1.1.1.1",
Obfuscation: "obfs4",
Port: "1194",
Protocol: "openvpn",
Provider: "riseup",
Transport: "tcp",
},
wantErr: nil,
},
{
name: "unknown proto fails",
args: args{"unknown://riseup.corp/?address=1.1.1.1:1194&transport=tcp"},
Expand All @@ -42,6 +65,12 @@ func Test_newEndpointFromInputString(t *testing.T) {
want: nil,
wantErr: ErrInvalidInput,
},
{
name: "empty provider fails",
args: args{"openvpn://.corp/?address=1.1.1.1:1194&transport=tcp"},
want: nil,
wantErr: ErrInvalidInput,
},
{
name: "non-registered provider fails",
args: args{"openvpn://nsavpn.corp/?address=1.1.1.1:1194&transport=tcp"},
Expand Down Expand Up @@ -78,6 +107,18 @@ func Test_newEndpointFromInputString(t *testing.T) {
want: nil,
wantErr: ErrInvalidInput,
},
{
name: "endpoint with no address fails",
args: args{"openvpn://riseup.corp/?transport=tcp"},
want: nil,
wantErr: ErrInvalidInput,
},
{
name: "endpoint with empty address fails",
args: args{"openvpn://riseup.corp/?address=&transport=tcp"},
want: nil,
wantErr: ErrInvalidInput,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -130,6 +171,20 @@ func Test_EndpointToInputURI(t *testing.T) {
},
want: "openvpn+obfs4://shady.corp/?address=1.1.1.1:443&transport=udp",
},
{
name: "empty provider is marked as unknown",
args: args{
endpoint{
IPAddr: "1.1.1.1",
Obfuscation: "obfs4",
Port: "443",
Protocol: "openvpn",
Provider: "",
Transport: "udp",
},
},
want: "openvpn+obfs4://unknown.corp/?address=1.1.1.1:443&transport=udp",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -140,4 +195,188 @@ func Test_EndpointToInputURI(t *testing.T) {
}
}

// TODO: test the endpoint uri string too.
func Test_endpoint_String(t *testing.T) {
type fields struct {
IPAddr string
Obfuscation string
Port string
Protocol string
Provider string
Transport string
}
tests := []struct {
name string
fields fields
want string
}{
{
name: "well formed endpoint returns a well formed endpoint string",
fields: fields{
IPAddr: "1.1.1.1",
Obfuscation: "none",
Port: "1194",
Protocol: "openvpn",
Provider: "unknown",
Transport: "tcp",
},
want: "openvpn://1.1.1.1:1194/tcp",
},
{
name: "well formed endpoint, openvpn+obfs4",
fields: fields{
IPAddr: "1.1.1.1",
Obfuscation: "obfs4",
Port: "1194",
Protocol: "openvpn",
Provider: "unknown",
Transport: "tcp",
},
want: "openvpn+obfs4://1.1.1.1:1194/tcp",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &endpoint{
IPAddr: tt.fields.IPAddr,
Obfuscation: tt.fields.Obfuscation,
Port: tt.fields.Port,
Protocol: tt.fields.Protocol,
Provider: tt.fields.Provider,
Transport: tt.fields.Transport,
}
if got := e.String(); got != tt.want {
t.Errorf("endpoint.String() = %v, want %v", got, tt.want)
}
})
}
}

func Test_endpointList_Shuffle(t *testing.T) {
shuffled := DefaultEndpoints.Shuffle()
sort.Slice(shuffled, func(i, j int) bool {
return shuffled[i].IPAddr < shuffled[j].IPAddr
})
if diff := cmp.Diff(shuffled, DefaultEndpoints); diff != "" {
t.Error(diff)
}
}

func Test_isValidProvider(t *testing.T) {
if valid := isValidProvider("riseup"); !valid {
t.Fatal("riseup is the only valid provider now")
}
if valid := isValidProvider("nsa"); valid {
t.Fatal("nsa will nevel be a provider")
}
}

func Test_getVPNConfig(t *testing.T) {
tracer := vpntracex.NewTracer(time.Now())
e := &endpoint{
Provider: "riseup",
IPAddr: "1.1.1.1",
Port: "443",
Transport: "udp",
}
creds := &vpnconfig.OpenVPNOptions{
CA: []byte("ca"),
Cert: []byte("cert"),
Key: []byte("key"),
}

cfg, err := getOpenVPNConfig(tracer, e, creds)
if err != nil {
t.Fatalf("did not expect error, got: %v", err)
}
if cfg.Tracer() != tracer {
t.Fatal("config tracer is not what passed")
}
if auth := cfg.OpenVPNOptions().Auth; auth != "SHA512" {
t.Errorf("expected auth %s, got %s", "SHA512", auth)
}
if cipher := cfg.OpenVPNOptions().Cipher; cipher != "AES-256-GCM" {
t.Errorf("expected cipher %s, got %s", "AES-256-GCM", cipher)
}
if remote := cfg.OpenVPNOptions().Remote; remote != e.IPAddr {
t.Errorf("expected remote %s, got %s", e.IPAddr, remote)
}
if port := cfg.OpenVPNOptions().Port; port != e.Port {
t.Errorf("expected port %s, got %s", e.Port, port)
}
if transport := cfg.OpenVPNOptions().Proto; string(transport) != e.Transport {
t.Errorf("expected transport %s, got %s", e.Transport, transport)
}
if transport := cfg.OpenVPNOptions().Proto; string(transport) != e.Transport {
t.Errorf("expected transport %s, got %s", e.Transport, transport)
}
if diff := cmp.Diff(cfg.OpenVPNOptions().CA, creds.CA); diff != "" {
t.Error(diff)
}
if diff := cmp.Diff(cfg.OpenVPNOptions().Cert, creds.Cert); diff != "" {
t.Error(diff)
}
if diff := cmp.Diff(cfg.OpenVPNOptions().Key, creds.Key); diff != "" {
t.Error(diff)
}
}

func Test_getVPNConfig_with_unknown_provider(t *testing.T) {
tracer := vpntracex.NewTracer(time.Now())
e := &endpoint{
Provider: "nsa",
IPAddr: "1.1.1.1",
Port: "443",
Transport: "udp",
}
creds := &vpnconfig.OpenVPNOptions{
CA: []byte("ca"),
Cert: []byte("cert"),
Key: []byte("key"),
}
_, err := getOpenVPNConfig(tracer, e, creds)
if !errors.Is(err, ErrInvalidInput) {
t.Fatalf("expected invalid input error, got: %v", err)
}

}

func Test_extractBase64Blob(t *testing.T) {
t.Run("decode good blob", func(t *testing.T) {
blob := "base64:dGhlIGJsdWUgb2N0b3B1cyBpcyB3YXRjaGluZw=="
decoded, err := extractBase64Blob(blob)
if decoded != "the blue octopus is watching" {
t.Fatal("could not decoded blob correctly")
}
if err != nil {
t.Fatal("should not fail with first blob")
}
})
t.Run("try decode without prefix", func(t *testing.T) {
blob := "dGhlIGJsdWUgb2N0b3B1cyBpcyB3YXRjaGluZw=="
_, err := extractBase64Blob(blob)
if !errors.Is(err, ErrBadBase64Blob) {
t.Fatal("should fail without prefix")
}
})
t.Run("bad base64 blob should fail", func(t *testing.T) {
blob := "base64:dGhlIGJsdWUgb2N0b3B1cyBpcyB3YXRjaGluZw"
_, err := extractBase64Blob(blob)
if !errors.Is(err, ErrBadBase64Blob) {
t.Fatal("bad blob should fail without prefix")
}
})
t.Run("decode empty blob", func(t *testing.T) {
blob := "base64:"
_, err := extractBase64Blob(blob)
if err != nil {
t.Fatal("empty blob should not fail")
}
})
t.Run("illegal base64 data should fail", func(t *testing.T) {
blob := "base64:=="
_, err := extractBase64Blob(blob)
if !errors.Is(err, ErrBadBase64Blob) {
t.Fatal("bad base64 data should fail")
}
})
}
2 changes: 1 addition & 1 deletion internal/experiment/openvpn/openvpn.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func (m *Measurer) connectAndHandshake(ctx context.Context, index int64, zeroTim
return nil, err
}

openvpnConfig, err := getVPNConfig(handshakeTracer, endpoint, credentials)
openvpnConfig, err := getOpenVPNConfig(handshakeTracer, endpoint, credentials)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 19b3fa9

Please sign in to comment.