diff --git a/internal/authz/oidc.go b/internal/authz/oidc.go index d7fe287..f175806 100644 --- a/internal/authz/oidc.go +++ b/internal/authz/oidc.go @@ -30,6 +30,7 @@ import ( corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3" typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3" + "github.com/lestrrat-go/jwx/jws" "github.com/tetratelabs/telemetry" "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/codes" @@ -375,37 +376,44 @@ func (o *oidcHandler) retrieveTokens(ctx context.Context, log telemetry.Logger, return } - if oidcNonce, ok := idToken.Get("nonce"); ok { - if oidcNonce.(string) != stateFromStore.Nonce { - log.Info("id token nonce does not match", "nonce-from-id-token", oidcNonce, "nonce-from-store", stateFromStore.Nonce) - setDenyResponse(resp, newDenyResponse(), codes.InvalidArgument) - return - } + oidcNonce, ok := idToken.Get("nonce") + if !ok { + log.Info("id token does not have nonce claim") + setDenyResponse(resp, newDenyResponse(), codes.InvalidArgument) + return } - var ( - audMatches bool - oidcAud interface{} - ok bool - ) - if oidcAud, ok = idToken.Get("aud"); ok { - switch aud := oidcAud.(type) { - case string: - audMatches = aud == o.config.GetClientId() - case []string: - for _, a := range aud { - if a == o.config.GetClientId() { - audMatches = true - break - } - } + if oidcNonce.(string) != stateFromStore.Nonce { + log.Info("id token nonce does not match", "nonce-from-id-token", oidcNonce, "nonce-from-store", stateFromStore.Nonce) + setDenyResponse(resp, newDenyResponse(), codes.InvalidArgument) + return + } + + var audMatches bool + for _, a := range idToken.Audience() { + if a == o.config.GetClientId() { + audMatches = true + break } } if !audMatches { - log.Info("id token audience does not match", "aud-from-id-token", oidcAud, "aud-from-config", o.config.GetClientId()) + log.Info("id token audience does not match", "aud-from-id-token", idToken.Audience(), "aud-from-config", o.config.GetClientId()) setDenyResponse(resp, newDenyResponse(), codes.InvalidArgument) return } + jwtSet, err := o.jwks.Get(ctx, o.config) + if err != nil { + log.Error("error fetching jwks", err) + setDenyResponse(resp, newDenyResponse(), codes.Internal) + return + } + + if _, err := jws.VerifySet([]byte(bodyTokens.IDToken), jwtSet); err != nil { + log.Error("error verifying id token with fetched jwks", err) + setDenyResponse(resp, newDenyResponse(), codes.Internal) + return + } + // https://openid.net/specs/openid-connect-core-1_0.html#TokenResponse // token_type must be Bearer if bodyTokens.TokenType != "Bearer" { diff --git a/internal/authz/oidc_test.go b/internal/authz/oidc_test.go index 664c05f..0718295 100644 --- a/internal/authz/oidc_test.go +++ b/internal/authz/oidc_test.go @@ -16,6 +16,8 @@ package authz import ( "context" + "crypto/rand" + "crypto/rsa" "encoding/base64" "encoding/json" "errors" @@ -29,6 +31,7 @@ import ( envoy "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3" typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3" "github.com/lestrrat-go/jwx/jwa" + "github.com/lestrrat-go/jwx/jwk" "github.com/lestrrat-go/jwx/jwt" "github.com/stretchr/testify/require" "github.com/tetratelabs/telemetry" @@ -150,6 +153,14 @@ func TestOIDCProcess(t *testing.T) { wantRedirectParams.Add("nonce", newNonce) wantRedirectBaseURI := "http://idp-test-server/auth" + unknownJWKPriv, _ := newKeyPair(t) + jwkPriv, jwkPub := newKeyPair(t) + bytes, err := json.Marshal(newKeySet(jwkPub)) + require.NoError(t, err) + basicOIDCConfig.JwksConfig = &oidcv1.OIDCConfig_Jwks{ + Jwks: string(bytes), + } + clock := oidc.Clock{} sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&clock, time.Hour, time.Hour)} store := sessions.Get(basicOIDCConfig) @@ -204,7 +215,7 @@ func TestOIDCProcess(t *testing.T) { name: "request with an existing sessionID expired", req: withSessionHeader, storedTokenResponse: &oidc.TokenResponse{ - IDToken: newJWT(t, jwt.NewBuilder().Expiration(yesterday)), + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Expiration(yesterday)), AccessToken: "access-token", AccessTokenExpiresAt: yesterday, }, @@ -223,14 +234,14 @@ func TestOIDCProcess(t *testing.T) { name: "request with an existing sessionID not expired", req: withSessionHeader, storedTokenResponse: &oidc.TokenResponse{ - IDToken: newJWT(t, jwt.NewBuilder().Expiration(tomorrow)), + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Expiration(tomorrow)), AccessToken: "access-token", AccessTokenExpiresAt: tomorrow, }, responseVerify: func(t *testing.T, resp *envoy.CheckResponse) { require.Equal(t, int32(codes.OK), resp.GetStatus().GetCode()) require.NotNil(t, resp.GetOkResponse()) - requireTokensInResponse(t, resp.GetOkResponse(), basicOIDCConfig, newJWT(t, jwt.NewBuilder().Expiration(tomorrow)), "access-token") + requireTokensInResponse(t, resp.GetOkResponse(), basicOIDCConfig, newJWT(t, jwkPriv, jwt.NewBuilder().Expiration(tomorrow)), "access-token") // The sessionID should not have been changed requireStoredTokens(t, store, sessionID, true) requireStoredState(t, store, newSessionID, false) @@ -274,7 +285,7 @@ func TestOIDCProcess(t *testing.T) { req: callbackRequest, storedAuthState: validAuthState, mockTokensResponse: &tokensResponse{ - IDToken: newJWT(t, jwt.NewBuilder().Audience([]string{"test-client-id"}).Claim("nonce", newNonce)), + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Audience([]string{"test-client-id"}).Claim("nonce", newNonce)), AccessToken: "access-token", TokenType: "Bearer", }, @@ -383,6 +394,19 @@ func TestOIDCProcess(t *testing.T) { requireStoredTokens(t, store, sessionID, false) }, }, + { + name: "idp server returns JWT signed with unknown key", + req: callbackRequest, + storedAuthState: validAuthState, + mockTokensResponse: &tokensResponse{ + IDToken: newJWT(t, unknownJWKPriv, jwt.NewBuilder().Audience([]string{"test-client-id"}).Claim("nonce", newNonce)), + }, + responseVerify: func(t *testing.T, response *envoy.CheckResponse) { + require.Equal(t, int32(codes.Internal), response.GetStatus().GetCode()) + requireStandardResponseHeaders(t, response) + requireStoredTokens(t, store, sessionID, false) + }, + }, { name: "session nonce stored does idp returned nonce", req: callbackRequest, @@ -392,7 +416,7 @@ func TestOIDCProcess(t *testing.T) { RequestedURL: requestedAppURL, }, mockTokensResponse: &tokensResponse{ - IDToken: newJWT(t, jwt.NewBuilder().Claim("nonce", "non-matching-nonce")), + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", "non-matching-nonce")), }, responseVerify: func(t *testing.T, response *envoy.CheckResponse) { require.Equal(t, int32(codes.InvalidArgument), response.GetStatus().GetCode()) @@ -405,7 +429,7 @@ func TestOIDCProcess(t *testing.T) { req: callbackRequest, storedAuthState: validAuthState, mockTokensResponse: &tokensResponse{ - IDToken: newJWT(t, jwt.NewBuilder().Audience(nil)), + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", newNonce)), }, responseVerify: func(t *testing.T, response *envoy.CheckResponse) { require.Equal(t, int32(codes.InvalidArgument), response.GetStatus().GetCode()) @@ -418,7 +442,7 @@ func TestOIDCProcess(t *testing.T) { req: callbackRequest, storedAuthState: validAuthState, mockTokensResponse: &tokensResponse{ - IDToken: newJWT(t, jwt.NewBuilder().Audience([]string{"non-matching-audience"})), + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", newNonce).Audience([]string{"non-matching-audience"})), }, responseVerify: func(t *testing.T, response *envoy.CheckResponse) { require.Equal(t, int32(codes.InvalidArgument), response.GetStatus().GetCode()) @@ -431,7 +455,7 @@ func TestOIDCProcess(t *testing.T) { req: callbackRequest, storedAuthState: validAuthState, mockTokensResponse: &tokensResponse{ - IDToken: newJWT(t, jwt.NewBuilder().Audience([]string{"test-client-id"})), + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", newNonce).Audience([]string{"test-client-id"})), TokenType: "not-bearer", }, responseVerify: func(t *testing.T, response *envoy.CheckResponse) { @@ -445,7 +469,7 @@ func TestOIDCProcess(t *testing.T) { req: callbackRequest, storedAuthState: validAuthState, mockTokensResponse: &tokensResponse{ - IDToken: newJWT(t, jwt.NewBuilder().Audience([]string{"test-client-id"})), + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", newNonce).Audience([]string{"test-client-id"})), TokenType: "Bearer", ExpiresIn: -1, }, @@ -460,7 +484,7 @@ func TestOIDCProcess(t *testing.T) { req: callbackRequest, storedAuthState: validAuthState, mockTokensResponse: &tokensResponse{ - IDToken: newJWT(t, jwt.NewBuilder().Audience([]string{"test-client-id"})), + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Claim("nonce", newNonce).Audience([]string{"test-client-id"})), TokenType: "Bearer", ExpiresIn: 3600, }, @@ -502,6 +526,13 @@ func TestOIDCProcessWithFailingSessionStore(t *testing.T) { store := &storeMock{delegate: oidc.NewMemoryStore(&oidc.Clock{}, time.Hour, time.Hour)} sessions := &mockSessionStoreFactory{store: store} + jwkPriv, jwkPub := newKeyPair(t) + bytes, err := json.Marshal(newKeySet(jwkPub)) + require.NoError(t, err) + basicOIDCConfig.JwksConfig = &oidcv1.OIDCConfig_Jwks{ + Jwks: string(bytes), + } + h, err := NewOIDCHandler(basicOIDCConfig, oidc.NewJWKSProvider(), sessions, oidc.Clock{}, oidc.NewStaticGenerator(newSessionID, newNonce, newState)) require.NoError(t, err) @@ -540,7 +571,7 @@ func TestOIDCProcessWithFailingSessionStore(t *testing.T) { idpServer := newServer() idpServer.statusCode = http.StatusOK idpServer.tokensResponse = &tokensResponse{ - IDToken: newJWT(t, jwt.NewBuilder().Audience([]string{"test-client-id"}).Claim("nonce", newNonce)), + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Audience([]string{"test-client-id"}).Claim("nonce", newNonce)), AccessToken: "access-token", TokenType: "Bearer", } @@ -583,6 +614,49 @@ func TestOIDCProcessWithFailingSessionStore(t *testing.T) { } } +func TestOIDCProcessWithFailingJWKSProvider(t *testing.T) { + funcJWKSProvider := jwksProviderFunc(func() (jwk.Set, error) { + return nil, errors.New("test jwks provider error") + }) + + jwkPriv, _ := newKeyPair(t) + + clock := oidc.Clock{} + sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&clock, time.Hour, time.Hour)} + store := sessions.Get(basicOIDCConfig) + h, err := NewOIDCHandler(basicOIDCConfig, funcJWKSProvider, sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState)) + require.NoError(t, err) + + idpServer := newServer() + h.(*oidcHandler).httpClient = idpServer.newHTTPClient() + + ctx := context.Background() + + idpServer.Start() + t.Cleanup(func() { + idpServer.Stop() + require.NoError(t, store.RemoveSession(ctx, sessionID)) + }) + + idpServer.tokensResponse = &tokensResponse{ + IDToken: newJWT(t, jwkPriv, jwt.NewBuilder().Audience([]string{"test-client-id"}).Claim("nonce", newNonce)), + AccessToken: "access-token", + TokenType: "Bearer", + } + idpServer.statusCode = http.StatusOK + + // Set the authorization state in the store, so it can be found by the handler + require.NoError(t, store.SetAuthorizationState(ctx, sessionID, validAuthState)) + + resp := &envoy.CheckResponse{} + err = h.Process(ctx, callbackRequest, resp) + require.NoError(t, err) + + require.Equal(t, int32(codes.Internal), resp.GetStatus().GetCode()) + requireStandardResponseHeaders(t, resp) + requireStoredTokens(t, store, sessionID, false) +} + func TestMatchesCallbackPath(t *testing.T) { tests := []struct { callback string @@ -708,6 +782,8 @@ func TestEncodeTokensToHeaders(t *testing.T) { } func TestAreTokensExpired(t *testing.T) { + jwkPriv, _ := newKeyPair(t) + tests := []struct { name string config *oidcv1.OIDCConfig @@ -718,40 +794,40 @@ func TestAreTokensExpired(t *testing.T) { { name: "no expiration - only id token", config: &oidcv1.OIDCConfig{}, - idToken: newJWT(t, jwt.NewBuilder().Expiration(tomorrow)), + idToken: newJWT(t, jwkPriv, jwt.NewBuilder().Expiration(tomorrow)), want: false, }, { name: "no expiration - id token and access token", config: &oidcv1.OIDCConfig{AccessToken: &oidcv1.TokenConfig{}}, - idToken: newJWT(t, jwt.NewBuilder().Expiration(tomorrow)), + idToken: newJWT(t, jwkPriv, jwt.NewBuilder().Expiration(tomorrow)), accessTokenExpiration: tomorrow, want: false, }, { name: "expired - only id token", config: &oidcv1.OIDCConfig{}, - idToken: newJWT(t, jwt.NewBuilder().Expiration(yesterday)), + idToken: newJWT(t, jwkPriv, jwt.NewBuilder().Expiration(yesterday)), want: true, }, { name: "expired - id token and access token", config: &oidcv1.OIDCConfig{AccessToken: &oidcv1.TokenConfig{}}, - idToken: newJWT(t, jwt.NewBuilder().Expiration(yesterday)), + idToken: newJWT(t, jwkPriv, jwt.NewBuilder().Expiration(yesterday)), accessTokenExpiration: yesterday, want: true, }, { name: "id token not expired, access token expired", config: &oidcv1.OIDCConfig{AccessToken: &oidcv1.TokenConfig{}}, - idToken: newJWT(t, jwt.NewBuilder().Expiration(tomorrow)), + idToken: newJWT(t, jwkPriv, jwt.NewBuilder().Expiration(tomorrow)), accessTokenExpiration: yesterday, want: true, }, { name: "id token not expired, access token expired - but access token not in config", config: &oidcv1.OIDCConfig{}, - idToken: newJWT(t, jwt.NewBuilder().Expiration(tomorrow)), + idToken: newJWT(t, jwkPriv, jwt.NewBuilder().Expiration(tomorrow)), accessTokenExpiration: yesterday, want: false, }, @@ -856,10 +932,41 @@ func modifyCallbackRequestPath(path string) *envoy.CheckRequest { } } -func newJWT(t *testing.T, builder *jwt.Builder) string { +const ( + keyID = "test" + keyAlg = jwa.RS256 +) + +func newKeySet(keys ...jwk.Key) jwk.Set { + jwks := jwk.NewSet() + for _, k := range keys { + jwks.Add(k) + } + return jwks +} + +func newKeyPair(t *testing.T) (jwk.Key, jwk.Key) { + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + priv, err := jwk.New(rsaKey) + require.NoError(t, err) + + pub, err := jwk.New(rsaKey.PublicKey) + require.NoError(t, err) + + err = pub.Set(jwk.KeyIDKey, keyID) + require.NoError(t, err) + err = pub.Set(jwk.AlgorithmKey, keyAlg) + require.NoError(t, err) + + return priv, pub +} + +func newJWT(t *testing.T, key jwk.Key, builder *jwt.Builder) string { token, err := builder.Build() require.NoError(t, err) - signed, err := jwt.Sign(token, jwa.HS256, []byte("key")) + signed, err := jwt.Sign(token, keyAlg, key) require.NoError(t, err) return string(signed) } @@ -1103,3 +1210,11 @@ type mockSessionStoreFactory struct { func (m mockSessionStoreFactory) Get(_ *oidcv1.OIDCConfig) oidc.SessionStore { return m.store } + +var _ oidc.JWKSProvider = jwksProviderFunc(nil) + +type jwksProviderFunc func() (jwk.Set, error) + +func (j jwksProviderFunc) Get(context.Context, *oidcv1.OIDCConfig) (jwk.Set, error) { + return j() +}