From d9be6fe233d9ddbcd71ffc8c00ea21f30c11bbb3 Mon Sep 17 00:00:00 2001 From: Ignasi Barrera Date: Mon, 12 Feb 2024 09:40:24 +0100 Subject: [PATCH] Properly validate configuration (#2) * Properly validate configuration * add tests * add more unit tests * make format --------- Co-authored-by: Sergi Castro --- Makefile | 2 +- internal/config.go | 9 +- internal/config_test.go | 74 ++++++++-- internal/logging_test.go | 11 ++ internal/server/authz_test.go | 204 ++++++++++++++++++++++++++ internal/server/server.go | 3 - internal/server/server_test.go | 89 +++++++---- internal/testdata/invalid-config.json | 3 + internal/testdata/invalid-values.json | 17 +++ internal/testdata/mock.json | 2 + 10 files changed, 366 insertions(+), 48 deletions(-) create mode 100644 internal/server/authz_test.go create mode 100644 internal/testdata/invalid-config.json create mode 100644 internal/testdata/invalid-values.json diff --git a/Makefile b/Makefile index 16dcc0f..e7f202f 100644 --- a/Makefile +++ b/Makefile @@ -102,7 +102,7 @@ config/lint: ## Lint the Config Proto generated code test: ## Run all the tests @go test $(TEST_OPTS) $(TEST_PKGS) -COVERAGE_OPTS ?= +COVERAGE_OPTS ?= .PHONY: coverage coverage: ## Creates coverage report for all projects @echo "Running test coverage" diff --git a/internal/config.go b/internal/config.go index 2219d70..2e74d8c 100644 --- a/internal/config.go +++ b/internal/config.go @@ -57,5 +57,12 @@ func (l *LocalConfigFile) Validate() error { return err } - return protojson.Unmarshal(content, &l.Config) + if err = protojson.Unmarshal(content, &l.Config); err != nil { + return err + } + + // Set reasonable defaults for non-supported values + l.Config.Threads = 1 + + return l.Config.ValidateAll() } diff --git a/internal/config_test.go b/internal/config_test.go index 3ab2aef..917d2f2 100644 --- a/internal/config_test.go +++ b/internal/config_test.go @@ -19,33 +19,79 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/tetratelabs/run" + "github.com/tetratelabs/telemetry" + "google.golang.org/protobuf/proto" + + configv1 "github.com/tetrateio/authservice-go/config/gen/go/v1" + mockv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/mock" ) +type errCheck struct { + is error + as error + msg string +} + +func (e errCheck) Check(t *testing.T, err error) { + switch { + case e.as != nil: + require.ErrorAs(t, err, &e.as) + case e.msg != "": + require.ErrorContains(t, err, e.msg) + default: + require.ErrorIs(t, err, e.is) + } +} + func TestLoadConfig(t *testing.T) { tests := []struct { - name string - path string - err error + name string + path string + check errCheck }{ - {"empty", "", ErrInvalidPath}, - {"invalid", "unexisting", os.ErrNotExist}, - {"valid", "testdata/mock.json", nil}, + {"empty", "", errCheck{is: ErrInvalidPath}}, + {"unexisting", "unexisting", errCheck{is: os.ErrNotExist}}, + {"invalid-config", "testdata/invalid-config.json", errCheck{msg: `unknown field "foo"`}}, + {"invalid-values", "testdata/invalid-values.json", errCheck{as: &configv1.ConfigMultiError{}}}, + {"valid", "testdata/mock.json", errCheck{is: nil}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cfg := LocalConfigFile{path: tt.path} - require.ErrorIs(t, cfg.Validate(), tt.err) + err := (&LocalConfigFile{path: tt.path}).Validate() + tt.check.Check(t, err) }) } } func TestLoadMock(t *testing.T) { - cfg := LocalConfigFile{path: "testdata/mock.json"} + want := &configv1.Config{ + ListenAddress: "0.0.0.0", + ListenPort: 8080, + LogLevel: "debug", + Threads: 1, + Chains: []*configv1.FilterChain{ + { + Name: "mock", + Filters: []*configv1.Filter{ + { + Type: &configv1.Filter_Mock{ + Mock: &mockv1.MockConfig{ + Allow: true, + }, + }, + }, + }, + }, + }, + } + + var cfg LocalConfigFile + g := run.Group{Logger: telemetry.NoopLogger()} + g.Register(&cfg) + err := g.Run("", "--config-path", "testdata/mock.json") - require.NoError(t, cfg.Validate()) - require.Len(t, cfg.Config.Chains, 1) - require.Equal(t, "mock", cfg.Config.Chains[0].Name) - require.Len(t, cfg.Config.Chains[0].Filters, 1) - require.True(t, cfg.Config.Chains[0].Filters[0].GetMock().Allow) + require.NoError(t, err) + require.True(t, proto.Equal(want, &cfg.Config)) } diff --git a/internal/logging_test.go b/internal/logging_test.go index a1b080f..8aacfc7 100644 --- a/internal/logging_test.go +++ b/internal/logging_test.go @@ -25,6 +25,15 @@ import ( configv1 "github.com/tetrateio/authservice-go/config/gen/go/v1" ) +func TestGetLogger(t *testing.T) { + l1 := scope.Register("l1", "test logger one") + + NewLogSystem(telemetry.NoopLogger(), nil) + + require.Equal(t, l1, Logger("l1")) + require.Equal(t, telemetry.NoopLogger(), Logger("l2")) +} + func TestLoggingSetup(t *testing.T) { l1 := scope.Register("l1", "test logger one") l2 := scope.Register("l2", "test logger two") @@ -50,6 +59,8 @@ func TestLoggingSetup(t *testing.T) { {",", telemetry.LevelInfo, telemetry.LevelInfo, true}, {":", telemetry.LevelInfo, telemetry.LevelInfo, true}, {"invalid", telemetry.LevelInfo, telemetry.LevelInfo, true}, + {"l1:,l2:info", telemetry.LevelInfo, telemetry.LevelInfo, true}, + {"l1:debug,l2:invalid", telemetry.LevelInfo, telemetry.LevelInfo, true}, } for _, tt := range tests { diff --git a/internal/server/authz_test.go b/internal/server/authz_test.go new file mode 100644 index 0000000..39a3e41 --- /dev/null +++ b/internal/server/authz_test.go @@ -0,0 +1,204 @@ +// Copyright 2024 Tetrate +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Copyright (c) Tetrate, Inc 2024 All Rights Reserved. + +package server + +import ( + "context" + "testing" + + envoy "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + + configv1 "github.com/tetrateio/authservice-go/config/gen/go/v1" + mockv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/mock" +) + +func TestUnmatchedRequests(t *testing.T) { + tests := []struct { + name string + allow bool + want codes.Code + }{ + {"allow", true, codes.OK}, + {"deny", false, codes.PermissionDenied}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := NewExtAuthZFilter(&configv1.Config{AllowUnmatchedRequests: tt.allow}) + got, err := e.Check(context.Background(), &envoy.CheckRequest{}) + require.NoError(t, err) + require.Equal(t, int32(tt.want), got.Status.Code) + }) + } +} + +func TestFiltersMatch(t *testing.T) { + tests := []struct { + name string + filters []*configv1.Filter + want codes.Code + }{ + {"no-filters", nil, codes.OK}, + {"all-filters-match", []*configv1.Filter{mock(true), mock(true)}, codes.OK}, + {"one-filter-deny", []*configv1.Filter{mock(true), mock(false)}, codes.PermissionDenied}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &configv1.Config{Chains: []*configv1.FilterChain{{Filters: tt.filters}}} + e := NewExtAuthZFilter(cfg) + + got, err := e.Check(context.Background(), &envoy.CheckRequest{}) + require.NoError(t, err) + require.Equal(t, int32(tt.want), got.Status.Code) + }) + } +} + +func TestUseFirstMatchingChain(t *testing.T) { + cfg := &configv1.Config{ + Chains: []*configv1.FilterChain{ + { + // Chain to be ignored + Match: eq("no-match"), + Filters: []*configv1.Filter{mock(false)}, + }, + { + // Chain to be used + Match: eq("match"), + Filters: []*configv1.Filter{mock(true)}, + }, + { + // Always matches but should not be used as the previous + // chain already matched + Filters: []*configv1.Filter{mock(false)}, + }, + }, + } + + e := NewExtAuthZFilter(cfg) + + got, err := e.Check(context.Background(), header("match")) + require.NoError(t, err) + require.Equal(t, int32(codes.OK), got.Status.Code) +} + +func TestCheckMock(t *testing.T) { + tests := []struct { + name string + allow bool + want bool + }{ + {"allow", true, true}, + {"deny", false, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &ExtAuthZFilter{} + got, err := e.checkMock( + context.Background(), + &envoy.CheckRequest{}, + &mockv1.MockConfig{Allow: tt.allow}, + ) + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } + +} + +func TestMatch(t *testing.T) { + tests := []struct { + name string + match *configv1.Match + req *envoy.CheckRequest + want bool + }{ + {"no-headers", eq("test"), &envoy.CheckRequest{}, false}, + {"no-match-condition", nil, &envoy.CheckRequest{}, true}, + {"equality-match", eq("test"), header("test"), true}, + {"equality-no-match", eq("test"), header("no-match"), false}, + {"prefix-match", prefix("test"), header("test-123"), true}, + {"prefix-no-match", prefix("test"), header("no-match"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, matches(tt.match, tt.req)) + }) + } +} + +func TestGrpcNoChainsMatched(t *testing.T) { + e := NewExtAuthZFilter(&configv1.Config{}) + s := NewTestServer(e.Register) + go func() { require.NoError(t, s.Start()) }() + t.Cleanup(s.Stop) + + conn, err := s.GRPCConn() + require.NoError(t, err) + client := envoy.NewAuthorizationClient(conn) + + ok, err := client.Check(context.Background(), &envoy.CheckRequest{}) + require.NoError(t, err) + require.Equal(t, int32(codes.PermissionDenied), ok.Status.Code) +} + +func mock(allow bool) *configv1.Filter { + return &configv1.Filter{ + Type: &configv1.Filter_Mock{ + Mock: &mockv1.MockConfig{ + Allow: allow, + }, + }, + } +} + +func eq(value string) *configv1.Match { + return &configv1.Match{ + Header: "X-Test-Headers", + Criteria: &configv1.Match_Equality{ + Equality: value, + }, + } +} + +func prefix(value string) *configv1.Match { + return &configv1.Match{ + Header: "X-Test-Headers", + Criteria: &configv1.Match_Prefix{ + Prefix: value, + }, + } +} + +func header(value string) *envoy.CheckRequest { + return &envoy.CheckRequest{ + Attributes: &envoy.AttributeContext{ + Request: &envoy.AttributeContext_Request{ + Http: &envoy.AttributeContext_HttpRequest{ + Headers: map[string]string{ + "x-test-headers": value, + }, + }, + }, + }, + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 1dfed6c..e153432 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -15,7 +15,6 @@ package server import ( - "errors" "fmt" "net" @@ -38,8 +37,6 @@ var ( _ run.Service = (*Server)(nil) ) -var ErrInvalidAddress = errors.New("invalid address") - // Server that runs as a unit in a run.Group. type Server struct { log telemetry.Logger diff --git a/internal/server/server_test.go b/internal/server/server_test.go index b8ee70e..021941a 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -16,8 +16,8 @@ package server import ( "context" + "errors" "net" - "sync" "testing" "github.com/stretchr/testify/require" @@ -31,40 +31,71 @@ import ( "google.golang.org/grpc/test/bufconn" ) -func TestServer(t *testing.T) { +func TestGrpcServer(t *testing.T) { + s := NewTestServer(func(s *grpc.Server) { + testgrpc.RegisterTestServiceServer(s, interop.NewTestServer()) + }) + go func() { require.NoError(t, s.Start()) }() + t.Cleanup(s.Stop) + + conn, err := s.GRPCConn() + require.NoError(t, err) + + client := testgrpc.NewTestServiceClient(conn) + interop.DoEmptyUnaryCall(client) // this method will panic if fails +} + +func TestListenFails(t *testing.T) { + err := errors.New("listen failed") + s := New(nil) + s.Listen = func() (net.Listener, error) { return nil, err } + require.ErrorIs(t, s.Serve(), err) +} + +// TestServer that uses an in-memory listener for connections. +type TestServer struct { + g run.Group + l *bufconn.Listener + dialOpts []grpc.DialOption + shutdown func() +} + +// NewTestServer creates a new test server. +func NewTestServer(handlers ...func(s *grpc.Server)) *TestServer { var ( - g = run.Group{Logger: telemetry.NoopLogger()} - irq = test.NewIRQService(func() {}) - l = bufconn.Listen(1024) - s = New(nil, func(s *grpc.Server) { - testgrpc.RegisterTestServiceServer(s, interop.NewTestServer()) - }) + g = run.Group{Logger: telemetry.NoopLogger()} + irq = test.NewIRQService(func() {}) + l = bufconn.Listen(1024) + dialOpts = []grpc.DialOption{ + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return l.Dial() }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + } + + s = New(nil, handlers...) ) - s.log = telemetry.NoopLogger() + s.Listen = func() (net.Listener, error) { return l, nil } g.Register(s, irq) - // Start the server - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - require.NoError(t, g.Run()) - wg.Done() - }() - - conn, err := grpc.Dial("bufnet", - grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { return l.Dial() }), - grpc.WithTransportCredentials(insecure.NewCredentials()), - ) - require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, conn.Close()) }) + return &TestServer{ + g: g, + l: l, + dialOpts: dialOpts, + shutdown: func() { _ = irq.Close() }, + } +} - client := testgrpc.NewTestServiceClient(conn) - interop.DoEmptyUnaryCall(client) // this method will panic if fails +// GRPCConn returns a gRPC connection that connects to the test server. +func (s *TestServer) GRPCConn() (*grpc.ClientConn, error) { + return grpc.Dial("bufnet", s.dialOpts...) +} - // Signal server termination - require.NoError(t, irq.Close()) +// Start starts the server. This blocks until the server is stopped. +func (s *TestServer) Start() error { + return s.g.Run() +} - // Wait for the server to stop - wg.Wait() +// Stop the test server. +func (s *TestServer) Stop() { + s.shutdown() } diff --git a/internal/testdata/invalid-config.json b/internal/testdata/invalid-config.json new file mode 100644 index 0000000..c8c4105 --- /dev/null +++ b/internal/testdata/invalid-config.json @@ -0,0 +1,3 @@ +{ + "foo": "bar" +} diff --git a/internal/testdata/invalid-values.json b/internal/testdata/invalid-values.json new file mode 100644 index 0000000..85b7a26 --- /dev/null +++ b/internal/testdata/invalid-values.json @@ -0,0 +1,17 @@ +{ + "listen_address": "INVALID", + "listen_port": 999999999, + "log_level": "debug", + "chains": [ + { + "name": "mock", + "filters": [ + { + "mock": { + "allow": true + } + } + ] + } + ] +} diff --git a/internal/testdata/mock.json b/internal/testdata/mock.json index eee0934..45f6b2e 100644 --- a/internal/testdata/mock.json +++ b/internal/testdata/mock.json @@ -1,4 +1,6 @@ { + "listen_address": "0.0.0.0", + "listen_port": 8080, "log_level": "debug", "chains": [ {