From 1e9ecdfc38b17551f2e23f8c5f70b7e6908b21be Mon Sep 17 00:00:00 2001 From: Ignasi Barrera Date: Thu, 15 Feb 2024 08:35:10 +0100 Subject: [PATCH] Added token response persistence in the Redis store (#12) --- .github/codecov.yml | 2 - .github/workflows/ci.yaml | 6 +- Makefile | 5 + e2e/Makefile | 2 +- e2e/mock/Makefile | 32 +---- e2e/mock/docker-compose.yaml | 4 +- e2e/redis/Makefile | 15 +++ e2e/redis/docker-compose.yaml | 22 ++++ e2e/redis/store_test.go | 77 +++++++++++ e2e/suite.mk | 49 +++++++ go.mod | 6 + go.sum | 22 ++++ internal/authz/mock_test.go | 3 +- internal/authz/oidc_test.go | 28 ++++ internal/config.go | 26 ++-- internal/config_test.go | 20 +-- internal/oidc/memory.go | 35 +++-- internal/oidc/memory_test.go | 68 ++++++---- internal/oidc/redis.go | 187 ++++++++++++++++++++++++++- internal/oidc/redis_test.go | 147 +++++++++++++++++++++ internal/oidc/session.go | 33 +++-- internal/oidc/session_test.go | 44 ++++++- internal/oidc/token.go | 20 ++- internal/oidc/token_test.go | 52 ++++++++ internal/server/logging.go | 4 +- internal/testdata/invalid-redis.json | 30 +++++ internal/testdata/oidc-override.json | 3 + internal/testdata/oidc.json | 3 + 28 files changed, 821 insertions(+), 124 deletions(-) create mode 100644 e2e/redis/Makefile create mode 100644 e2e/redis/docker-compose.yaml create mode 100644 e2e/redis/store_test.go create mode 100644 e2e/suite.mk create mode 100644 internal/authz/oidc_test.go create mode 100644 internal/oidc/redis_test.go create mode 100644 internal/oidc/token_test.go create mode 100644 internal/testdata/invalid-redis.json diff --git a/.github/codecov.yml b/.github/codecov.yml index b1e997a..d55b309 100644 --- a/.github/codecov.yml +++ b/.github/codecov.yml @@ -14,11 +14,9 @@ ignore: coverage: status: - # require coverage to not be worse than previously project: default: target: auto - # allow a potential drop of up to 5% threshold: 5% patch: default: diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d684518..30560a4 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -62,7 +62,7 @@ jobs: with: go-version-file: go.mod - run: make coverage - - uses: codecov/codecov-action@v3 + - uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} @@ -70,10 +70,10 @@ jobs: needs: check runs-on: ubuntu-latest steps: - - uses: docker/setup-qemu-action@v2 + - uses: docker/setup-qemu-action@v3 with: platforms: amd64,arm64 - - uses: docker/setup-buildx-action@v2 + - uses: docker/setup-buildx-action@v3 - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: diff --git a/Makefile b/Makefile index e7f202f..55865c5 100644 --- a/Makefile +++ b/Makefile @@ -118,6 +118,11 @@ coverage: ## Creates coverage report for all projects e2e: ## Runt he e2e tests @$(MAKE) -C e2e e2e +e2e/%: force-e2e + @$(MAKE) -C e2e $(@) + +.PHONY: force-e2e +force-e2e: ##@ Docker targets diff --git a/e2e/Makefile b/e2e/Makefile index ea3f405..25c8da8 100644 --- a/e2e/Makefile +++ b/e2e/Makefile @@ -13,7 +13,7 @@ # limitations under the License. -SUITES := mock +SUITES := mock redis .PHONY: e2e e2e: $(SUITES:%=e2e/%) ## Run all e2e tests diff --git a/e2e/mock/Makefile b/e2e/mock/Makefile index 3d65cd5..66ab349 100644 --- a/e2e/mock/Makefile +++ b/e2e/mock/Makefile @@ -12,34 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Force run of the e2e tests -E2E_TEST_OPTS ?= -count=1 - - -.PHONY: e2e -e2e: e2e-pre - @$(MAKE) e2e-test e2e-post - -.PHONY: e2e-test -e2e-test: - @go test $(E2E_TEST_OPTS) ./... || ( $(MAKE) e2e-post-error; exit 1 ) - -.PHONY: e2e-pre -e2e-pre: - @docker compose up --detach --wait --force-recreate --remove-orphans || ($(MAKE) e2e-post-error; exit 1) - -.PHONY: e2e-post -e2e-post: - @docker compose down - -.PHONY: e2e-post-error -e2e-post-error: capture-logs - -.PHONY: capture-logs -capture-logs: - @mkdir -p ./logs - @docker compose logs > logs/docker-compose-logs.log - -.PHONY: clean -clean: - @rm -rf ./logs +include ../suite.mk diff --git a/e2e/mock/docker-compose.yaml b/e2e/mock/docker-compose.yaml index c1bb9b1..dc74c4d 100644 --- a/e2e/mock/docker-compose.yaml +++ b/e2e/mock/docker-compose.yaml @@ -17,10 +17,10 @@ version: "3.9" services: envoy: image: envoyproxy/envoy:v1.29-latest - platform: linux/arm64 + platform: linux/${ARCH:-amd64} command: -c /etc/envoy/envoy-config.yaml --log-level warning ports: - - "8080:80" # Make it accessible from the host (HTTP traffic) + - "8080:80" volumes: - type: bind source: envoy-config.yaml diff --git a/e2e/redis/Makefile b/e2e/redis/Makefile new file mode 100644 index 0000000..66ab349 --- /dev/null +++ b/e2e/redis/Makefile @@ -0,0 +1,15 @@ +# Copyright 2024 Tetrate +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +include ../suite.mk diff --git a/e2e/redis/docker-compose.yaml b/e2e/redis/docker-compose.yaml new file mode 100644 index 0000000..e8a95db --- /dev/null +++ b/e2e/redis/docker-compose.yaml @@ -0,0 +1,22 @@ +# Copyright 2024 Tetrate +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +version: "3.9" + +services: + redis: + image: redis:7.2.4 + platform: linux/${ARCH:-amd64} + ports: + - "6379:6379" diff --git a/e2e/redis/store_test.go b/e2e/redis/store_test.go new file mode 100644 index 0000000..b326ea3 --- /dev/null +++ b/e2e/redis/store_test.go @@ -0,0 +1,77 @@ +// Copyright 2024 Tetrate +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "testing" + "time" + + "github.com/lestrrat-go/jwx/jwa" + "github.com/lestrrat-go/jwx/jwt" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + + "github.com/tetrateio/authservice-go/internal/oidc" +) + +const redisURL = "redis://localhost:6379" + +func TestRedisTokenResponse(t *testing.T) { + opts, err := redis.ParseURL(redisURL) + require.NoError(t, err) + client := redis.NewClient(opts) + + store, err := oidc.NewRedisStore(&oidc.Clock{}, client, 0, 1*time.Minute) + require.NoError(t, err) + + ctx := context.Background() + + tr, err := store.GetTokenResponse(ctx, "s1") + require.NoError(t, err) + require.Nil(t, tr) + + // Create a session and verify it's added and accessed time + tr = &oidc.TokenResponse{ + IDToken: newToken(), + AccessToken: newToken(), + AccessTokenExpiresAt: time.Now().Add(30 * time.Minute), + } + require.NoError(t, store.SetTokenResponse(ctx, "s1", tr)) + + // Verify we can retrieve the token + got, err := store.GetTokenResponse(ctx, "s1") + require.NoError(t, err) + // The testify library doesn't properly compare times, so we need to do it manually + // then set the times in the returned object so that we can compare the rest of the + // fields normally + require.True(t, tr.AccessTokenExpiresAt.Equal(got.AccessTokenExpiresAt)) + got.AccessTokenExpiresAt = tr.AccessTokenExpiresAt + require.Equal(t, tr, got) + + // Verify that the token TTL has been set + ttl := client.TTL(ctx, "s1").Val() + require.Greater(t, ttl, time.Duration(0)) +} + +func newToken() string { + token, _ := jwt.NewBuilder(). + Issuer("authservice"). + Subject("user"). + Expiration(time.Now().Add(time.Hour)). + Build() + signed, _ := jwt.Sign(token, jwa.HS256, []byte("key")) + return string(signed) +} diff --git a/e2e/suite.mk b/e2e/suite.mk new file mode 100644 index 0000000..f93f5f6 --- /dev/null +++ b/e2e/suite.mk @@ -0,0 +1,49 @@ +# Copyright 2024 Tetrate +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Force run of the e2e tests +E2E_TEST_OPTS ?= -count=1 + +export ARCH := $(shell uname -m) +ifeq ($(ARCH),x86_64) +export ARCH := amd64 +endif + +.PHONY: e2e +e2e: e2e-pre + @$(MAKE) e2e-test e2e-post + +.PHONY: e2e-test +e2e-test: + @go test $(E2E_TEST_OPTS) ./... || ( $(MAKE) e2e-post-error; exit 1 ) + +.PHONY: e2e-pre +e2e-pre: + @docker compose up --detach --wait --force-recreate --remove-orphans || ($(MAKE) e2e-post-error; exit 1) + +.PHONY: e2e-post +e2e-post: + @docker compose down + +.PHONY: e2e-post-error +e2e-post-error: capture-logs + +.PHONY: capture-logs +capture-logs: + @mkdir -p ./logs + @docker compose logs > logs/docker-compose-logs.log + +.PHONY: clean +clean: + @rm -rf ./logs diff --git a/go.mod b/go.mod index 0ea7d21..a8b64a0 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,11 @@ module github.com/tetrateio/authservice-go go 1.21.6 require ( + github.com/alicebob/miniredis/v2 v2.31.1 github.com/envoyproxy/go-control-plane v0.12.0 github.com/envoyproxy/protoc-gen-validate v1.0.4 github.com/lestrrat-go/jwx v1.2.28 + github.com/redis/go-redis/v9 v9.4.0 github.com/stretchr/testify v1.8.4 github.com/tetratelabs/log v0.2.3 github.com/tetratelabs/run v0.2.2 @@ -18,9 +20,12 @@ require ( require ( cloud.google.com/go/compute v1.23.3 // indirect cloud.google.com/go/compute/metadata v0.2.3 // indirect + github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cncf/xds/go v0.0.0-20231128003011-0fa0005c9caa // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect @@ -36,6 +41,7 @@ require ( github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/tetratelabs/multierror v1.1.0 // indirect + github.com/yuin/gopher-lua v1.1.0 // indirect golang.org/x/crypto v0.18.0 // indirect golang.org/x/net v0.20.0 // indirect golang.org/x/oauth2 v0.16.0 // indirect diff --git a/go.sum b/go.sum index 6a2c8a7..8a3b23c 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,20 @@ cloud.google.com/go/compute v1.23.3 h1:6sVlXXBmbd7jNX0Ipq0trII3e4n1/MsADLK6a+aiV cloud.google.com/go/compute v1.23.3/go.mod h1:VCgBUoMnIVIR0CscqQiPJLAG25E3ZRZMzcFZeQ+h8CI= cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= +github.com/DmitriyVTitov/size v1.5.0/go.mod h1:le6rNI4CoLQV1b9gzp1+3d7hMAD/uu2QcJ+aYbNgiU0= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis/v2 v2.31.1 h1:7XAt0uUg3DtwEKW5ZAGa+K7FZV2DdKQo5K/6TTnfX8Y= +github.com/alicebob/miniredis/v2 v2.31.1/go.mod h1:UB/T2Uztp7MlFSDakaX1sTXUv5CASoprx0wulRT6HBg= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/cncf/xds/go v0.0.0-20231128003011-0fa0005c9caa h1:jQCWAUqqlij9Pgj2i/PB79y4KOPYVyFYdROxgaCwdTQ= github.com/cncf/xds/go v0.0.0-20231128003011-0fa0005c9caa/go.mod h1:x/1Gn8zydmfq8dk6e9PdstVsDgu9RuyIIJqAaF//0IM= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -11,12 +25,15 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/envoyproxy/go-control-plane v0.12.0 h1:4X+VP1GHd1Mhj6IB5mMeGbLCleqxjletLK6K0rbxyZI= github.com/envoyproxy/go-control-plane v0.12.0/go.mod h1:ZBTaoJ23lqITozF0M6G4/IragXCQKCnYbmlmtHvwRG0= github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU4zdyUgIUNhlgg0A= github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= @@ -49,6 +66,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.4.0 h1:Yzoz33UZw9I/mFhx4MNrB6Fk+XHO1VukNcCa1+lwyKk= +github.com/redis/go-redis/v9 v9.4.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= @@ -70,6 +89,8 @@ github.com/tetratelabs/run v0.2.2/go.mod h1:22PbpPVSMGbForFumdO0sbSlhxaYWopg5hkF github.com/tetratelabs/telemetry v0.8.2 h1:VXwSVpfX1yRMo6UdhsLP80GuPzavVWuoJrVM+2lMOSk= github.com/tetratelabs/telemetry v0.8.2/go.mod h1:jDUcf1A2u4F5V1io5RdipM/bKz/hFCsx/RAgGopC37s= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yuin/gopher-lua v1.1.0 h1:BojcDhfyDWgU2f2TOzYK/g5p2gxMrku8oupLDqlnSqE= +github.com/yuin/gopher-lua v1.1.0/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= @@ -89,6 +110,7 @@ golang.org/x/oauth2 v0.16.0/go.mod h1:hqZ+0LWXsiVoZpeld6jVt06P3adbS2Uu911W1SsJv2 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/authz/mock_test.go b/internal/authz/mock_test.go index 7ccbb4c..9be0309 100644 --- a/internal/authz/mock_test.go +++ b/internal/authz/mock_test.go @@ -20,7 +20,6 @@ import ( envoy "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3" "github.com/stretchr/testify/require" - "github.com/tetratelabs/telemetry" "google.golang.org/grpc/codes" mockv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/mock" @@ -39,7 +38,7 @@ func TestProcessMock(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var ( - m = &mockHandler{log: telemetry.NoopLogger(), config: &mockv1.MockConfig{Allow: tt.allow}} + m = NewMockHandler(&mockv1.MockConfig{Allow: tt.allow}) req = &envoy.CheckRequest{} resp = &envoy.CheckResponse{} ) diff --git a/internal/authz/oidc_test.go b/internal/authz/oidc_test.go new file mode 100644 index 0000000..d5faf82 --- /dev/null +++ b/internal/authz/oidc_test.go @@ -0,0 +1,28 @@ +// Copyright 2024 Tetrate +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package authz + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestProcessOIDC(t *testing.T) { + h, err := NewOIDCHandler(nil, nil, nil) + require.NoError(t, err) + require.NoError(t, h.Process(context.Background(), nil, nil)) +} diff --git a/internal/config.go b/internal/config.go index 1133440..e3d24da 100644 --- a/internal/config.go +++ b/internal/config.go @@ -19,6 +19,7 @@ import ( "fmt" "os" + "github.com/redis/go-redis/v9" "github.com/tetratelabs/run" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" @@ -30,14 +31,11 @@ import ( var ( _ run.Config = (*LocalConfigFile)(nil) - // ErrInvalidPath is returned when the configuration file path is invalid. - ErrInvalidPath = errors.New("invalid path") - // ErrInvalidOIDCOverride is returned when the OIDC override is invalid. + ErrInvalidPath = errors.New("invalid path") ErrInvalidOIDCOverride = errors.New("invalid OIDC override") - // ErrDuplicateOIDCConfig is returned when the OIDC configuration is duplicated. ErrDuplicateOIDCConfig = errors.New("duplicate OIDC configuration") - // ErrMultipleOIDCConfig is returned ultiple OIDC configurations are set in the same filter chain. - ErrMultipleOIDCConfig = errors.New("multiple OIDC configurations") + ErrMultipleOIDCConfig = errors.New("multiple OIDC configurations") + ErrInvalidRedisURL = errors.New("invalid Redis URL") ) // LocalConfigFile is a run.Config that loads the configuration file. @@ -99,15 +97,17 @@ func (l *LocalConfigFile) Validate() error { // Merge the OIDC overrides with the default OIDC configuration so that // we can properly validate the settings and all filters have only one // location where the OIDC configuration is defined. - mergeOIDCConfigs(&l.Config) + if err = mergeAndValidateOIDCConfigs(&l.Config); err != nil { + return err + } // Now that all defaults are set and configurations are merged, validate all final settings return l.Config.ValidateAll() } -// mergeOIDCConfigs merges the OIDC overrides with the default OIDC configuration so that +// mergeAndValidateOIDCConfigs merges the OIDC overrides with the default OIDC configuration so that // all filters have only one location where the OIDC configuration is defined. -func mergeOIDCConfigs(cfg *configv1.Config) { +func mergeAndValidateOIDCConfigs(cfg *configv1.Config) error { for _, fc := range cfg.Chains { for _, f := range fc.Filters { // Merge the OIDC overrides and populate the normal OIDC field instead so that @@ -117,11 +117,19 @@ func mergeOIDCConfigs(cfg *configv1.Config) { proto.Merge(oidc, f.GetOidcOverride()) f.Type = &configv1.Filter_Oidc{Oidc: oidc} } + + if redisURL := f.GetOidc().GetRedisSessionStoreConfig().GetServerUri(); redisURL != "" { + if _, err := redis.ParseURL(redisURL); err != nil { + return fmt.Errorf("%w: invalid redis URL in chain %q", ErrInvalidRedisURL, fc.Name) + } + } } } // Clear the default config as it has already been merged. This way there is only one // location for the OIDC settings. cfg.DefaultOidcConfig = nil + + return nil } func ConfigToJSONString(c *configv1.Config) string { diff --git a/internal/config_test.go b/internal/config_test.go index b0b27e9..68e8d88 100644 --- a/internal/config_test.go +++ b/internal/config_test.go @@ -59,6 +59,7 @@ func TestValidateConfig(t *testing.T) { {"duplicate-oidc", "testdata/duplicate-oidc.json", errCheck{is: ErrDuplicateOIDCConfig}}, {"invalid-oidc-override", "testdata/invalid-oidc-override.json", errCheck{is: ErrInvalidOIDCOverride}}, {"multiple-oidc", "testdata/multiple-oidc.json", errCheck{is: ErrMultipleOIDCConfig}}, + {"invalid-redis", "testdata/invalid-redis.json", errCheck{is: ErrInvalidRedisURL}}, {"valid", "testdata/mock.json", errCheck{is: nil}}, } @@ -114,15 +115,16 @@ func TestLoadOIDC(t *testing.T) { { Type: &configv1.Filter_Oidc{ Oidc: &oidcv1.OIDCConfig{ - AuthorizationUri: "http://fake", - TokenUri: "http://fake", - CallbackUri: "http://fake", - JwksConfig: &oidcv1.OIDCConfig_Jwks{Jwks: "fake-jwks"}, - ClientId: "fake-client-id", - ClientSecret: "fake-client-secret", - CookieNamePrefix: "", - IdToken: &oidcv1.TokenConfig{Preamble: "Bearer", Header: "authorization"}, - ProxyUri: "http://fake", + AuthorizationUri: "http://fake", + TokenUri: "http://fake", + CallbackUri: "http://fake", + JwksConfig: &oidcv1.OIDCConfig_Jwks{Jwks: "fake-jwks"}, + ClientId: "fake-client-id", + ClientSecret: "fake-client-secret", + CookieNamePrefix: "", + IdToken: &oidcv1.TokenConfig{Preamble: "Bearer", Header: "authorization"}, + ProxyUri: "http://fake", + RedisSessionStoreConfig: &oidcv1.RedisConfig{ServerUri: "redis://localhost:6379/0"}, }, }, }, diff --git a/internal/oidc/memory.go b/internal/oidc/memory.go index e02c051..78ff4a1 100644 --- a/internal/oidc/memory.go +++ b/internal/oidc/memory.go @@ -15,6 +15,7 @@ package oidc import ( + "context" "sync" "time" @@ -28,7 +29,7 @@ var _ SessionStore = (*memoryStore)(nil) // memoryStore is an in-memory implementation of the SessionStore interface. type memoryStore struct { log telemetry.Logger - clock Clock + clock *Clock absoluteSessionTimeout time.Duration idleSessionTimeout time.Duration @@ -37,7 +38,7 @@ type memoryStore struct { } // NewMemoryStore creates a new in-memory session store. -func NewMemoryStore(clock Clock, absoluteSessionTimeout, idleSessionTimeout time.Duration) SessionStore { +func NewMemoryStore(clock *Clock, absoluteSessionTimeout, idleSessionTimeout time.Duration) SessionStore { return &memoryStore{ log: internal.Logger(internal.Session).With("type", "memory"), clock: clock, @@ -47,45 +48,47 @@ func NewMemoryStore(clock Clock, absoluteSessionTimeout, idleSessionTimeout time } } -func (m *memoryStore) SetTokenResponse(sessionID string, tokenResponse *TokenResponse) { +func (m *memoryStore) SetTokenResponse(_ context.Context, sessionID string, tokenResponse *TokenResponse) error { m.set(sessionID, func(s *session) { s.tokenResponse = tokenResponse }) + return nil } -func (m *memoryStore) GetTokenResponse(sessionID string) *TokenResponse { +func (m *memoryStore) GetTokenResponse(_ context.Context, sessionID string) (*TokenResponse, error) { m.mu.Lock() defer m.mu.Unlock() s := m.sessions[sessionID] if s == nil { - return nil + return nil, nil } s.accessed = m.clock.Now() - return s.tokenResponse + return s.tokenResponse, nil } -func (m *memoryStore) SetAuthorizationState(sessionID string, authorizationState *AuthorizationState) { +func (m *memoryStore) SetAuthorizationState(_ context.Context, sessionID string, authorizationState *AuthorizationState) error { m.set(sessionID, func(s *session) { s.authorizationState = authorizationState }) + return nil } -func (m *memoryStore) GetAuthorizationState(sessionID string) *AuthorizationState { +func (m *memoryStore) GetAuthorizationState(_ context.Context, sessionID string) (*AuthorizationState, error) { m.mu.Lock() defer m.mu.Unlock() s := m.sessions[sessionID] if s == nil { - return nil + return nil, nil } s.accessed = m.clock.Now() - return s.authorizationState + return s.authorizationState, nil } -func (m *memoryStore) ClearAuthorizationState(sessionID string) { +func (m *memoryStore) ClearAuthorizationState(_ context.Context, sessionID string) error { m.mu.Lock() defer m.mu.Unlock() @@ -93,16 +96,20 @@ func (m *memoryStore) ClearAuthorizationState(sessionID string) { s.accessed = m.clock.Now() s.authorizationState = nil } + + return nil } -func (m *memoryStore) RemoveSession(sessionID string) { +func (m *memoryStore) RemoveSession(_ context.Context, sessionID string) error { m.mu.Lock() defer m.mu.Unlock() delete(m.sessions, sessionID) + + return nil } -func (m *memoryStore) RemoveAllExpired() { +func (m *memoryStore) RemoveAllExpired(context.Context) error { var ( earliestTimeAddedToKeep = m.clock.Now().Add(-m.absoluteSessionTimeout) earliestTimeIdleToKeep = m.clock.Now().Add(-m.idleSessionTimeout) @@ -121,6 +128,8 @@ func (m *memoryStore) RemoveAllExpired() { delete(m.sessions, sessionID) } } + + return nil } // set the given session with the given setter function and record the access time. diff --git a/internal/oidc/memory_test.go b/internal/oidc/memory_test.go index 29c4b1a..e617159 100644 --- a/internal/oidc/memory_test.go +++ b/internal/oidc/memory_test.go @@ -15,83 +15,97 @@ package oidc import ( + "context" "testing" "time" "github.com/stretchr/testify/require" ) -func TestTokenResponse(t *testing.T) { - m := NewMemoryStore(Clock{}, 0, 0).(*memoryStore) +func TestMemoryTokenResponse(t *testing.T) { + m := NewMemoryStore(&Clock{}, 0, 0).(*memoryStore) + ctx := context.Background() - require.Nil(t, m.GetTokenResponse("s1")) + tr, err := m.GetTokenResponse(ctx, "s1") + require.NoError(t, err) + require.Nil(t, tr) // Create a session and verify it's added and accessed time - tr := &TokenResponse{} - m.SetTokenResponse("s1", &TokenResponse{}) + tr = &TokenResponse{} + require.NoError(t, m.SetTokenResponse(ctx, "s1", &TokenResponse{})) require.Greater(t, m.sessions["s1"].added.Unix(), int64(0)) require.Equal(t, m.sessions["s1"].added, m.sessions["s1"].accessed) // Verify that the right token response is returned and the accessed time is updated - require.Equal(t, tr, m.GetTokenResponse("s1")) + got, err := m.GetTokenResponse(ctx, "s1") + require.NoError(t, err) + require.Equal(t, tr, got) require.True(t, m.sessions["s1"].accessed.After(m.sessions["s1"].added)) lastAccessed := m.sessions["s1"].accessed // Verify that updating the token response also updates the session access timestamp - m.SetTokenResponse("s1", &TokenResponse{}) + require.NoError(t, m.SetTokenResponse(ctx, "s1", &TokenResponse{})) require.True(t, m.sessions["s1"].accessed.After(lastAccessed)) } -func TestAuthorizationState(t *testing.T) { - m := NewMemoryStore(Clock{}, 0, 0).(*memoryStore) +func TestMemoryAuthorizationState(t *testing.T) { + m := NewMemoryStore(&Clock{}, 0, 0).(*memoryStore) + ctx := context.Background() - as := m.GetAuthorizationState("s1") + as, err := m.GetAuthorizationState(ctx, "s1") + require.NoError(t, err) require.Nil(t, as) // Create a session and verify it's added and accessed time as = &AuthorizationState{} - m.SetAuthorizationState("s1", as) + require.NoError(t, m.SetAuthorizationState(ctx, "s1", as)) require.Greater(t, m.sessions["s1"].added.Unix(), int64(0)) require.Equal(t, m.sessions["s1"].added, m.sessions["s1"].accessed) // Verify that the right state is returned and the accessed time is updated - require.Equal(t, as, m.GetAuthorizationState("s1")) + got, err := m.GetAuthorizationState(ctx, "s1") + require.NoError(t, err) + require.Equal(t, as, got) lastAccessed := m.sessions["s1"].accessed require.True(t, lastAccessed.After(m.sessions["s1"].added)) // Verify that updating the authz state also updates the session access timestamp - m.SetAuthorizationState("s1", &AuthorizationState{}) + require.NoError(t, m.SetAuthorizationState(ctx, "s1", &AuthorizationState{})) require.True(t, m.sessions["s1"].accessed.After(lastAccessed)) // Verify that clearing the authz state also updates the session access timestamp - m.ClearAuthorizationState("s1") - require.Nil(t, m.GetAuthorizationState("s1")) + require.NoError(t, m.ClearAuthorizationState(ctx, "s1")) + got, err = m.GetAuthorizationState(ctx, "s1") + require.NoError(t, err) + require.Nil(t, got) require.True(t, m.sessions["s1"].accessed.After(lastAccessed)) } -func TestRemoveResponse(t *testing.T) { - m := NewMemoryStore(Clock{}, 0, 0).(*memoryStore) +func TestMemoryRemoveResponse(t *testing.T) { + m := NewMemoryStore(&Clock{}, 0, 0).(*memoryStore) + ctx := context.Background() - m.SetTokenResponse("s1", &TokenResponse{}) + require.NoError(t, m.SetTokenResponse(ctx, "s1", &TokenResponse{})) require.NotNil(t, m.sessions["s1"]) - m.RemoveSession("s1") + require.NoError(t, m.RemoveSession(ctx, "s1")) require.Nil(t, m.sessions["s1"]) } -func TestRemoveAllExpired(t *testing.T) { - m := NewMemoryStore(Clock{}, 0, 0).(*memoryStore) +func TestMemoryRemoveAllExpired(t *testing.T) { + m := NewMemoryStore(&Clock{}, 0, 0).(*memoryStore) + ctx := context.Background() - m.SetTokenResponse("s1", &TokenResponse{}) - m.SetTokenResponse("s2", &TokenResponse{}) - m.SetTokenResponse("abs-expired", &TokenResponse{}) - m.SetTokenResponse("idle-expired", &TokenResponse{}) + require.NoError(t, m.SetTokenResponse(ctx, "s1", &TokenResponse{})) + require.NoError(t, m.SetTokenResponse(ctx, "s2", &TokenResponse{})) + require.NoError(t, m.SetTokenResponse(ctx, "abs-expired", &TokenResponse{})) + require.NoError(t, m.SetTokenResponse(ctx, "idle-expired", &TokenResponse{})) m.sessions["abs-expired"].added = time.Now().Add(-time.Hour) m.sessions["idle-expired"].accessed = time.Now().Add(-time.Hour) t.Run("no-expiration", func(t *testing.T) { - m.RemoveAllExpired() + require.NoError(t, m.RemoveAllExpired(ctx)) require.Len(t, m.sessions, 4) require.NotNil(t, m.sessions["s1"]) @@ -103,7 +117,7 @@ func TestRemoveAllExpired(t *testing.T) { t.Run("expiration", func(t *testing.T) { m.absoluteSessionTimeout = time.Minute m.idleSessionTimeout = time.Minute - m.RemoveAllExpired() + require.NoError(t, m.RemoveAllExpired(ctx)) require.Len(t, m.sessions, 2) require.NotNil(t, m.sessions["s1"]) diff --git a/internal/oidc/redis.go b/internal/oidc/redis.go index b65de68..68bec15 100644 --- a/internal/oidc/redis.go +++ b/internal/oidc/redis.go @@ -14,12 +14,193 @@ package oidc -var _ SessionStore = (*redisStore)(nil) +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/redis/go-redis/v9" + "github.com/tetratelabs/telemetry" + + "github.com/tetrateio/authservice-go/internal" +) + +var ( + _ SessionStore = (*redisStore)(nil) + + ErrRedis = errors.New("redis error") +) + +const ( + keyIDToken = "id_token" + keyAccessToken = "access_token" + keyAccessTokenExpiry = "access_token_expiry" + keyRefreshToken = "refresh_token" + keyState = "state" + keyNonce = "nonce" + keyRequestedURL = "requested_url" + keyTimeAdded = "time_added" +) + +var ( + tokenResponseKeys = []string{keyIDToken, keyAccessToken, keyRefreshToken, keyAccessTokenExpiry, keyTimeAdded} + // authorizationStateKeys = []string{keyState, keyNonce, keyRequestedURL, keyTimeAdded} +) // redisStore is an in-memory implementation of the SessionStore interface that stores // the session data in a given Redis server. type redisStore struct { - // TODO(nacx): Remove the interface embedding and implement it + //TODO(nacx): Remove this interface embedding when the interface is fully implemented SessionStore - url string + log telemetry.Logger + clock *Clock + client redis.Cmdable + absoluteSessionTimeout time.Duration + idleSessionTimeout time.Duration +} + +// NewRedisStore creates a new SessionStore that stores the session data in a given Redis server. +func NewRedisStore(clock *Clock, client redis.Cmdable, absoluteSessionTimeout, idleSessionTimeout time.Duration) (SessionStore, error) { + if err := client.Ping(context.TODO()).Err(); err != nil { + return nil, err + } + + return &redisStore{ + log: internal.Logger(internal.Session).With("type", "redis"), + clock: clock, + client: client, + absoluteSessionTimeout: absoluteSessionTimeout, + idleSessionTimeout: idleSessionTimeout, + }, nil +} + +func (r *redisStore) SetTokenResponse(ctx context.Context, sessionID string, tokenResponse *TokenResponse) error { + if err := r.client.HSet(ctx, sessionID, keyIDToken, tokenResponse.IDToken).Err(); err != nil { + return err + } + + var keysToDelete []string + + if tokenResponse.AccessToken != "" { + if err := r.client.HSet(ctx, sessionID, keyAccessToken, tokenResponse.AccessToken).Err(); err != nil { + return err + } + } else { + keysToDelete = append(keysToDelete, keyAccessToken) + } + + if !tokenResponse.AccessTokenExpiresAt.IsZero() { + if err := r.client.HSet(ctx, sessionID, keyAccessTokenExpiry, tokenResponse.AccessTokenExpiresAt).Err(); err != nil { + return err + } + } else { + keysToDelete = append(keysToDelete, keyAccessTokenExpiry) + } + + if tokenResponse.RefreshToken != "" { + if err := r.client.HSet(ctx, sessionID, keyRefreshToken, tokenResponse.RefreshToken).Err(); err != nil { + return err + } + } else { + keysToDelete = append(keysToDelete, keyRefreshToken) + } + + if len(keysToDelete) > 0 { + if err := r.client.HDel(ctx, sessionID, keysToDelete...).Err(); err != nil { + return err + } + } + + now := r.clock.Now() + if err := r.client.HSetNX(ctx, sessionID, keyTimeAdded, now).Err(); err != nil { + return err + } + + return r.refreshExpiration(ctx, sessionID, now) +} + +func (r *redisStore) GetTokenResponse(ctx context.Context, sessionID string) (*TokenResponse, error) { + log := r.log.Context(ctx) + + res := r.client.HMGet(ctx, sessionID, tokenResponseKeys...) + if res.Err() != nil { + return nil, res.Err() + } + + var token redisToken + if err := res.Scan(&token); err != nil { + return nil, err + } + + if token.IDToken == "" { + log.Debug("id token not found", "session_id", sessionID) + return nil, nil + } + + tokenResponse := token.TokenResponse() + if _, err := tokenResponse.ParseIDToken(); err != nil { + log.Error("failed to parse id token", err, "session_id", sessionID, "token", token) + return nil, nil + } + + if err := r.refreshExpiration(ctx, sessionID, token.TimeAdded); err != nil { + return nil, err + } + + return &tokenResponse, nil +} + +func (r *redisStore) refreshExpiration(ctx context.Context, sessionID string, timeAdded time.Time) error { + if timeAdded.IsZero() { + timeAdded, _ = r.client.HGet(ctx, sessionID, keyTimeAdded).Time() + } + + if timeAdded.IsZero() { + if err := r.client.Del(ctx, sessionID).Err(); err != nil { + return err + } + return fmt.Errorf("%w: session did not contain creation timestamp", ErrRedis) + } + + if r.absoluteSessionTimeout == 0 && r.idleSessionTimeout == 0 { + return nil + } + + var ( + now = r.clock.Now() + absoluteExpireAt = timeAdded.Add(r.absoluteSessionTimeout) + idleExpireAt = now.Add(r.idleSessionTimeout) + expireAt time.Time + ) + + if r.absoluteSessionTimeout == 0 { + expireAt = idleExpireAt + } else if r.idleSessionTimeout == 0 { + expireAt = absoluteExpireAt + } else { + expireAt = absoluteExpireAt + if idleExpireAt.Before(expireAt) { + expireAt = idleExpireAt + } + } + + return r.client.ExpireAt(ctx, sessionID, expireAt).Err() +} + +type redisToken struct { + IDToken string `redis:"id_token"` + AccessToken string `redis:"access_token"` + AccessTokenExpiresAt time.Time `redis:"access_token_expiry"` + RefreshToken string `redis:"refresh_token"` + TimeAdded time.Time `redis:"time_added"` +} + +func (r redisToken) TokenResponse() TokenResponse { + return TokenResponse{ + IDToken: r.IDToken, + AccessToken: r.AccessToken, + AccessTokenExpiresAt: r.AccessTokenExpiresAt, + RefreshToken: r.RefreshToken, + } } diff --git a/internal/oidc/redis_test.go b/internal/oidc/redis_test.go new file mode 100644 index 0000000..35d2fe1 --- /dev/null +++ b/internal/oidc/redis_test.go @@ -0,0 +1,147 @@ +// Copyright 2024 Tetrate +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package oidc + +import ( + "context" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func TestRedisTokenResponse(t *testing.T) { + mr := miniredis.RunT(t) + client := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + store, err := NewRedisStore(&Clock{}, client, 0, 1*time.Minute) + require.NoError(t, err) + + ctx := context.Background() + + tr, err := store.GetTokenResponse(ctx, "s1") + require.NoError(t, err) + require.Nil(t, tr) + + // Create a session and verify it's added and accessed time + tr = &TokenResponse{ + IDToken: newToken(), + AccessToken: newToken(), + AccessTokenExpiresAt: time.Now().Add(30 * time.Minute), + RefreshToken: newToken(), + } + require.NoError(t, store.SetTokenResponse(ctx, "s1", tr)) + + // Verify we can retrieve the token + got, err := store.GetTokenResponse(ctx, "s1") + require.NoError(t, err) + // The testify library doesn't properly compare times, so we need to do it manually + // then set the times in the returned object so that we can compare the rest of the + // fields normally + require.True(t, tr.AccessTokenExpiresAt.Equal(got.AccessTokenExpiresAt)) + got.AccessTokenExpiresAt = tr.AccessTokenExpiresAt + require.Equal(t, tr, got) + + // Verify that the token TTL has been set + added, _ := client.HGet(ctx, "s1", keyTimeAdded).Time() + ttl := client.TTL(ctx, "s1").Val() + require.Greater(t, added.Unix(), int64(0)) + require.Greater(t, ttl, time.Duration(0)) + + // Check keys are deleted + tr.AccessToken = "" + tr.RefreshToken = "" + tr.AccessTokenExpiresAt = time.Time{} + require.NoError(t, store.SetTokenResponse(ctx, "s1", tr)) + + var rt redisToken + vals := client.HMGet(ctx, "s1", keyAccessToken, keyRefreshToken, keyAccessTokenExpiry) + require.NoError(t, vals.Scan(&rt)) + require.Empty(t, rt.AccessToken) + require.True(t, rt.AccessTokenExpiresAt.IsZero()) + require.Empty(t, rt.RefreshToken) +} + +func TestRedisPingError(t *testing.T) { + mr := miniredis.RunT(t) + client := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + mr.SetError("ping error") + + _, err := NewRedisStore(&Clock{}, client, 0, 1*time.Minute) + require.EqualError(t, err, "ping error") +} + +func TestRefreshExpiration(t *testing.T) { + mr := miniredis.RunT(t) + client := redis.NewClient(&redis.Options{Addr: mr.Addr()}) + store, err := NewRedisStore(&Clock{}, client, 0, 0) + require.NoError(t, err) + rs := store.(*redisStore) + + ctx := context.Background() + + t.Run("delete session if no time added", func(t *testing.T) { + require.NoError(t, client.HSet(ctx, "s1", keyAccessToken, "").Err()) + err := rs.refreshExpiration(ctx, "s1", time.Time{}) + require.ErrorIs(t, err, ErrRedis) + require.Equal(t, redis.Nil, client.Get(ctx, "s1").Err()) + }) + + t.Run("no expiration set if no timeouts", func(t *testing.T) { + require.NoError(t, client.HSet(ctx, "s1", keyTimeAdded, time.Now()).Err()) + require.NoError(t, rs.refreshExpiration(ctx, "s1", time.Time{})) + + res, err := client.TTL(ctx, "s1").Result() + require.NoError(t, err) + require.Equal(t, time.Duration(-1), res) + }) + + t.Run("set idle expiration", func(t *testing.T) { + rs.absoluteSessionTimeout = 0 + rs.idleSessionTimeout = 1 * time.Minute + require.NoError(t, client.HSet(ctx, "s1", keyTimeAdded, time.Now()).Err()) + require.NoError(t, rs.refreshExpiration(ctx, "s1", time.Time{})) + + res, err := client.TTL(ctx, "s1").Result() + require.NoError(t, err) + require.Greater(t, res, time.Duration(0)) + require.LessOrEqual(t, res, rs.idleSessionTimeout) + }) + + t.Run("set absolute expiration", func(t *testing.T) { + rs.absoluteSessionTimeout = 30 * time.Second + rs.idleSessionTimeout = 0 + require.NoError(t, client.HSet(ctx, "s1", keyTimeAdded, time.Now()).Err()) + require.NoError(t, rs.refreshExpiration(ctx, "s1", time.Time{})) + + res, err := client.TTL(ctx, "s1").Result() + require.NoError(t, err) + require.Greater(t, res, time.Duration(0)) + require.LessOrEqual(t, res, rs.absoluteSessionTimeout) + }) + + t.Run("set smallest expiration", func(t *testing.T) { + rs.idleSessionTimeout = 10 * time.Second + rs.absoluteSessionTimeout = 20 * time.Second + require.NoError(t, client.HSet(ctx, "s1", keyTimeAdded, time.Now()).Err()) + require.NoError(t, rs.refreshExpiration(ctx, "s1", time.Time{})) + + res, err := client.TTL(ctx, "s1").Result() + require.NoError(t, err) + require.Greater(t, res, time.Duration(0)) + require.LessOrEqual(t, res, rs.idleSessionTimeout) + }) +} diff --git a/internal/oidc/session.go b/internal/oidc/session.go index 4013fec..160d685 100644 --- a/internal/oidc/session.go +++ b/internal/oidc/session.go @@ -15,8 +15,10 @@ package oidc import ( + "context" "time" + "github.com/redis/go-redis/v9" "github.com/tetratelabs/run" configv1 "github.com/tetrateio/authservice-go/config/gen/go/v1" @@ -25,13 +27,13 @@ import ( // SessionStore is an interface for storing session data. type SessionStore interface { - SetTokenResponse(sessionID string, tokenResponse *TokenResponse) - GetTokenResponse(sessionID string) *TokenResponse - SetAuthorizationState(sessionID string, authorizationState *AuthorizationState) - GetAuthorizationState(sessionID string) *AuthorizationState - ClearAuthorizationState(sessionID string) - RemoveSession(sessionID string) - RemoveAllExpired() + SetTokenResponse(ctx context.Context, sessionID string, tokenResponse *TokenResponse) error + GetTokenResponse(ctx context.Context, sessionID string) (*TokenResponse, error) + SetAuthorizationState(ctx context.Context, sessionID string, authorizationState *AuthorizationState) error + GetAuthorizationState(ctx context.Context, sessionID string) (*AuthorizationState, error) + ClearAuthorizationState(ctx context.Context, sessionID string) error + RemoveSession(ctx context.Context, sessionID string) error + RemoveAllExpired(ctx context.Context) error } var _ run.PreRunner = (*SessionStoreFactory)(nil) @@ -51,6 +53,7 @@ func (s *SessionStoreFactory) Name() string { return "OIDC session store factory // PreRun initializes the stores that are defined in the configuration func (s *SessionStoreFactory) PreRun() error { s.redis = make(map[string]SessionStore) + clock := &Clock{} for _, fc := range s.Config.Chains { for _, f := range fc.Filters { @@ -59,11 +62,19 @@ func (s *SessionStoreFactory) PreRun() error { } if redisServer := f.GetOidc().GetRedisSessionStoreConfig().GetServerUri(); redisServer != "" { - // TODO(nacx): Initialize the Redis store - s.redis[redisServer] = &redisStore{url: redisServer} + // No need to check the errors here as it has already been validated when loading the configuration + opts, _ := redis.ParseURL(redisServer) + client := redis.NewClient(opts) + r, err := NewRedisStore(clock, client, + time.Duration(f.GetOidc().GetAbsoluteSessionTimeout()), + time.Duration(f.GetOidc().GetIdleSessionTimeout()), + ) + if err != nil { + return err + } + s.redis[redisServer] = r } else if s.memory == nil { // Use a shared in-memory store for all OIDC configurations - s.memory = NewMemoryStore( - Clock{}, + s.memory = NewMemoryStore(clock, time.Duration(f.GetOidc().GetAbsoluteSessionTimeout()), time.Duration(f.GetOidc().GetIdleSessionTimeout()), ) diff --git a/internal/oidc/session_test.go b/internal/oidc/session_test.go index 0fee1b3..704a543 100644 --- a/internal/oidc/session_test.go +++ b/internal/oidc/session_test.go @@ -17,6 +17,8 @@ package oidc import ( "testing" + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" "github.com/stretchr/testify/require" "github.com/tetratelabs/run" "github.com/tetratelabs/telemetry" @@ -27,6 +29,9 @@ import ( ) func TestSessionStoreFactory(t *testing.T) { + redis1 := miniredis.RunT(t) + redis2 := miniredis.RunT(t) + config := &configv1.Config{ ListenAddress: "0.0.0.0", ListenPort: 8080, @@ -52,7 +57,7 @@ func TestSessionStoreFactory(t *testing.T) { { Type: &configv1.Filter_Oidc{ Oidc: &oidcv1.OIDCConfig{ - RedisSessionStoreConfig: &oidcv1.RedisConfig{ServerUri: "http://redis1:6379"}, + RedisSessionStoreConfig: &oidcv1.RedisConfig{ServerUri: "redis://" + redis1.Addr()}, }, }, }, @@ -64,7 +69,7 @@ func TestSessionStoreFactory(t *testing.T) { { Type: &configv1.Filter_Oidc{ Oidc: &oidcv1.OIDCConfig{ - RedisSessionStoreConfig: &oidcv1.RedisConfig{ServerUri: "http://redis2:6379"}, + RedisSessionStoreConfig: &oidcv1.RedisConfig{ServerUri: "redis://" + redis2.Addr()}, }, }, }, @@ -85,6 +90,37 @@ func TestSessionStoreFactory(t *testing.T) { require.IsType(t, &memoryStore{}, store.Get(&oidcv1.OIDCConfig{})) require.IsType(t, &memoryStore{}, store.Get(config.Chains[0].Filters[1].GetOidc())) require.IsType(t, &memoryStore{}, store.Get(config.Chains[1].Filters[0].GetOidc())) - require.Equal(t, "http://redis1:6379", store.Get(config.Chains[2].Filters[0].GetOidc()).(*redisStore).url) - require.Equal(t, "http://redis2:6379", store.Get(config.Chains[3].Filters[0].GetOidc()).(*redisStore).url) + require.Equal(t, redis1.Addr(), store.Get(config.Chains[2].Filters[0].GetOidc()).(*redisStore).client.(*redis.Client).Options().Addr) + require.Equal(t, redis2.Addr(), store.Get(config.Chains[3].Filters[0].GetOidc()).(*redisStore).client.(*redis.Client).Options().Addr) +} + +func TestSessionStoreFactoryRedisFails(t *testing.T) { + mr := miniredis.RunT(t) + config := &configv1.Config{ + ListenAddress: "0.0.0.0", + ListenPort: 8080, + LogLevel: "debug", + Threads: 1, + Chains: []*configv1.FilterChain{ + { + Name: "redis", + Filters: []*configv1.Filter{ + { + Type: &configv1.Filter_Oidc{ + Oidc: &oidcv1.OIDCConfig{ + RedisSessionStoreConfig: &oidcv1.RedisConfig{ServerUri: "redis://" + mr.Addr()}, + }, + }, + }, + }, + }, + }, + } + + store := SessionStoreFactory{Config: config} + g := run.Group{Logger: telemetry.NoopLogger()} + g.Register(&store) + + mr.SetError("server error") + require.ErrorContains(t, g.Run(), "server error") } diff --git a/internal/oidc/token.go b/internal/oidc/token.go index 858ccf1..71a7871 100644 --- a/internal/oidc/token.go +++ b/internal/oidc/token.go @@ -14,12 +14,22 @@ package oidc -import "time" +import ( + "time" + + "github.com/lestrrat-go/jwx/jwt" +) // TokenResponse contains information about the tokens returned by the Identity Provider. type TokenResponse struct { - IDToken string - AccessToken string - RefreshToken string - AccessTokenExpiry time.Duration + IDToken string + AccessToken string + AccessTokenExpiresAt time.Time + RefreshToken string +} + +func (t *TokenResponse) ParseIDToken() (jwt.Token, error) { return parse(t.IDToken) } + +func parse(token string) (jwt.Token, error) { + return jwt.Parse([]byte(token), jwt.WithValidate(false)) } diff --git a/internal/oidc/token_test.go b/internal/oidc/token_test.go new file mode 100644 index 0000000..ad4d304 --- /dev/null +++ b/internal/oidc/token_test.go @@ -0,0 +1,52 @@ +// Copyright 2024 Tetrate +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package oidc + +import ( + "testing" + "time" + + "github.com/lestrrat-go/jwx/jwa" + "github.com/lestrrat-go/jwx/jwt" + "github.com/stretchr/testify/require" +) + +func TestParseIDToken(t *testing.T) { + t.Run("valid", func(t *testing.T) { + tr := &TokenResponse{ + IDToken: newToken(), + } + + it, err := tr.ParseIDToken() + require.NoError(t, err) + require.Equal(t, "authservice", it.Issuer()) + }) + + t.Run("invalid", func(t *testing.T) { + tr := &TokenResponse{} + _, err := tr.ParseIDToken() + require.Error(t, err) + }) +} + +func newToken() string { + token, _ := jwt.NewBuilder(). + Issuer("authservice"). + Subject("user"). + Expiration(time.Now().Add(time.Hour)). + Build() + signed, _ := jwt.Sign(token, jwa.HS256, []byte("key")) + return string(signed) +} diff --git a/internal/server/logging.go b/internal/server/logging.go index ff71be7..de2e3fb 100644 --- a/internal/server/logging.go +++ b/internal/server/logging.go @@ -47,9 +47,9 @@ func (l LogMiddleware) UnaryServerInterceptor( ) (interface{}, error) { log := l.log.Context(ctx) - log.Debug("request", "side", "server", "method", info.FullMethod, "data", toJSON(req)) + log.Debug("request", "method", info.FullMethod, "data", toJSON(req)) resp, err := handler(ctx, req) - log.Debug("response", "side", "server", "method", info.FullMethod, "data", toJSON(req), "error", err) + log.Debug("response", "method", info.FullMethod, "data", toJSON(req), "error", err) return resp, err } diff --git a/internal/testdata/invalid-redis.json b/internal/testdata/invalid-redis.json new file mode 100644 index 0000000..14ce9bd --- /dev/null +++ b/internal/testdata/invalid-redis.json @@ -0,0 +1,30 @@ +{ + "listen_address": "0.0.0.0", + "listen_port": 8080, + "log_level": "debug", + "chains": [ + { + "name": "oidc", + "filters": [ + { + "oidc": { + "authorization_uri": "http://fake", + "token_uri": "http://fake", + "callback_uri": "http://fake", + "proxy_uri": "http://fake", + "jwks": "fake-jwks", + "client_id": "fake-client-id", + "client_secret": "fake-client-secret", + "id_token": { + "preamble": "Bearer", + "header": "authorization" + }, + "redis_session_store_config": { + "server_uri": "http://fake" + } + } + } + ] + } + ] +} diff --git a/internal/testdata/oidc-override.json b/internal/testdata/oidc-override.json index fec178c..cab1f05 100644 --- a/internal/testdata/oidc-override.json +++ b/internal/testdata/oidc-override.json @@ -22,6 +22,9 @@ "id_token": { "preamble": "Bearer", "header": "authorization" + }, + "redis_session_store_config": { + "server_uri": "redis://localhost:6379/0" } } } diff --git a/internal/testdata/oidc.json b/internal/testdata/oidc.json index 1f0c682..29918b8 100644 --- a/internal/testdata/oidc.json +++ b/internal/testdata/oidc.json @@ -18,6 +18,9 @@ "id_token": { "preamble": "Bearer", "header": "authorization" + }, + "redis_session_store_config": { + "server_uri": "redis://localhost:6379/0" } } }