diff --git a/client.go b/client.go index 9b4ce02cb..427fae438 100644 --- a/client.go +++ b/client.go @@ -71,6 +71,17 @@ type OpenIDConnectClient interface { GetTokenEndpointAuthSigningAlgorithm() string } +// ClientWithAllowedVerificationKeys adds a security control to the client configuration to only allow +// specific verification keys. This ensures that a key that is valid for client X can't be used for client Y +// unless allowed. This becomes especially important for cases where the clients are controlled by third-parties +// and are issued specific keys from a central organization, which may be the OP's org or a central regulatory authority, +// and the security controls of the clients cannot be guaranteed. +type ClientWithAllowedVerificationKeys interface { + // AllowedVerificationKeys provides a list of key IDs that can be used in the JWT + // header for private_key_jwt authentication and for JWT bearer grant flow + AllowedVerificationKeys() []string +} + // ResponseModeClient represents a client capable of handling response_mode type ResponseModeClient interface { // GetResponseMode returns the response modes that client is allowed to send diff --git a/client_authentication.go b/client_authentication.go index 685e0311d..975b5b633 100644 --- a/client_authentication.go +++ b/client_authentication.go @@ -5,19 +5,13 @@ package fosite import ( "context" - "crypto/ecdsa" - "crypto/rsa" "encoding/json" - "fmt" "net/http" "net/url" "time" "github.com/ory/x/errorsx" - "github.com/go-jose/go-jose/v3" - "github.com/pkg/errors" - "github.com/ory/fosite/token/jwt" ) @@ -28,33 +22,15 @@ type ClientAuthenticationStrategy func(context.Context, *http.Request, url.Value const clientAssertionJWTBearerType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" func (f *Fosite) findClientPublicJWK(ctx context.Context, oidcClient OpenIDConnectClient, t *jwt.Token, expectsRSAKey bool) (interface{}, error) { - if set := oidcClient.GetJSONWebKeys(); set != nil { - return findPublicKey(t, set, expectsRSAKey) - } - - if location := oidcClient.GetJSONWebKeysURI(); len(location) > 0 { - keys, err := f.Config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, location, false) - if err != nil { - return nil, err - } - - if key, err := findPublicKey(t, keys, expectsRSAKey); err == nil { - return key, nil - } - - keys, err = f.Config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, location, true) - if err != nil { - return nil, err - } - - return findPublicKey(t, keys, expectsRSAKey) + if oidcClient.GetJSONWebKeys() == nil && oidcClient.GetJSONWebKeysURI() == "" { + return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The OAuth 2.0 Client has no JSON Web Keys set registered, but they are needed to complete the request.")) } - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The OAuth 2.0 Client has no JSON Web Keys set registered, but they are needed to complete the request.")) + return findPublicJWK(ctx, f.Config, t, oidcClient.GetJSONWebKeysURI(), oidcClient.GetJSONWebKeys(), expectsRSAKey, ErrInvalidClient) } // AuthenticateClient authenticates client requests using the configured strategy -// `Fosite.ClientAuthenticationStrategy`, if nil it uses `Fosite.DefaultClientAuthenticationStrategy` +// `ClientAuthenticationStrategy`, if nil it uses `DefaultClientAuthenticationStrategy` func (f *Fosite) AuthenticateClient(ctx context.Context, r *http.Request, form url.Values) (Client, error) { if s := f.Config.GetClientAuthenticationStrategy(ctx); s != nil { return s(ctx, r, form) @@ -71,81 +47,70 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The client_assertion request parameter must be set when using client_assertion_type of '%s'.", clientAssertionJWTBearerType)) } - var clientID string - var client Client - - token, err := jwt.ParseWithClaims(assertion, jwt.MapClaims{}, func(t *jwt.Token) (interface{}, error) { - var err error - clientID, _, err = clientCredentialsFromRequestBody(form, false) - if err != nil { - return nil, err - } + // Parse the assertion + token, parsedToken, isJWE, err := newToken(assertion, "client_assertion", ErrInvalidClient) + if err != nil { + return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to parse the client_assertion").WithWrap(err).WithDebug(err.Error())) + } - if clientID == "" { - claims := t.Claims - if sub, ok := claims["sub"].(string); !ok { - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The claim 'sub' from the client_assertion JSON Web Token is undefined.")) - } else { - clientID = sub - } - } + claims := token.Claims - client, err = f.Store.GetClient(ctx, clientID) - if err != nil { - return nil, errorsx.WithStack(ErrInvalidClient.WithWrap(err).WithDebug(err.Error())) - } + // Validate client + clientID, _, err := clientCredentialsFromRequestBody(form, false) + if err != nil { + return nil, err + } - oidcClient, ok := client.(OpenIDConnectClient) - if !ok { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHint("The server configuration does not support OpenID Connect specific authentication methods.")) + if clientID == "" { + if isJWE { + return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The 'client_id' must be part of the request when encrypted client_assertion is used.")) } - switch oidcClient.GetTokenEndpointAuthMethod() { - case "private_key_jwt": - break - case "none": - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("This requested OAuth 2.0 client does not support client authentication, however 'client_assertion' was provided in the request.")) - case "client_secret_post": - fallthrough - case "client_secret_basic": - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("This requested OAuth 2.0 client only supports client authentication method '%s', however 'client_assertion' was provided in the request.", oidcClient.GetTokenEndpointAuthMethod())) - case "client_secret_jwt": - fallthrough - default: - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("This requested OAuth 2.0 client only supports client authentication method '%s', however that method is not supported by this server.", oidcClient.GetTokenEndpointAuthMethod())) + if sub, ok := claims["sub"].(string); !ok { + return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The claim 'sub' from the client_assertion JSON Web Token is undefined.")) + } else { + clientID = sub } + } - if oidcClient.GetTokenEndpointAuthSigningAlgorithm() != fmt.Sprintf("%s", t.Header["alg"]) { - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The 'client_assertion' uses signing algorithm '%s' but the requested OAuth 2.0 Client enforces signing algorithm '%s'.", t.Header["alg"], oidcClient.GetTokenEndpointAuthSigningAlgorithm())) - } - switch t.Method { - case jose.RS256, jose.RS384, jose.RS512: - return f.findClientPublicJWK(ctx, oidcClient, t, true) - case jose.ES256, jose.ES384, jose.ES512: - return f.findClientPublicJWK(ctx, oidcClient, t, false) - case jose.PS256, jose.PS384, jose.PS512: - return f.findClientPublicJWK(ctx, oidcClient, t, true) - case jose.HS256, jose.HS384, jose.HS512: - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("This authorization server does not support client authentication method 'client_secret_jwt'.")) - default: - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The 'client_assertion' request parameter uses unsupported signing algorithm '%s'.", t.Header["alg"])) - } - }) + client, err := f.Store.GetClient(ctx, clientID) if err != nil { - // Do not re-process already enhanced errors - var e *jwt.ValidationError - if errors.As(err, &e) { - if e.Inner != nil { - return nil, e.Inner - } - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to verify the integrity of the 'client_assertion' value.").WithWrap(err).WithDebug(err.Error())) - } + return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The requested OAuth 2.0 Client could not be authenticated.").WithWrap(err).WithDebug(err.Error())) + } + + oidcClient, ok := client.(OpenIDConnectClient) + if !ok { + return nil, errorsx.WithStack(ErrInvalidRequest.WithHint("The server configuration does not support OpenID Connect specific authentication methods.")) + } + + switch oidcClient.GetTokenEndpointAuthMethod() { + case "private_key_jwt": + break + case "none": + return nil, errorsx.WithStack(ErrInvalidClient.WithHint("This requested OAuth 2.0 client does not support client authentication, however 'client_assertion' was provided in the request.")) + case "client_secret_post": + fallthrough + case "client_secret_basic": + return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("This requested OAuth 2.0 client only supports client authentication method '%s', however 'client_assertion' was provided in the request.", oidcClient.GetTokenEndpointAuthMethod())) + case "client_secret_jwt": + fallthrough + default: + return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("This requested OAuth 2.0 client only supports client authentication method '%s', however that method is not supported by this server.", oidcClient.GetTokenEndpointAuthMethod())) + } + + // Validate signature + if !isJWE && oidcClient.GetTokenEndpointAuthSigningAlgorithm() != "" && oidcClient.GetTokenEndpointAuthSigningAlgorithm() != parsedToken.Headers[0].Algorithm { + return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The client_assertion uses signing algorithm '%s', but the requested OAuth 2.0 Client enforces signing algorithm '%s'.", parsedToken.Headers[0].Algorithm, oidcClient.GetTokenEndpointAuthSigningAlgorithm())) + } + + ctx = context.WithValue(ctx, AssertionTypeContextKey, "client_assertion") + ctx = context.WithValue(ctx, BaseErrorContextKey, ErrInvalidClient) + if token, parsedToken, err = ValidateParsedAssertionWithClient(ctx, f.Config, assertion, token, parsedToken, oidcClient, false); err != nil { return nil, err - } else if err := token.Claims.Valid(); err != nil { - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to verify the request object because its claims could not be validated, check if the expiry time is set correctly.").WithWrap(err).WithDebug(err.Error())) } - claims := token.Claims + claims = token.Claims + var jti string if !claims.VerifyIssuer(clientID, true) { return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.")) @@ -255,43 +220,6 @@ func (f *Fosite) checkClientSecret(ctx context.Context, client Client, clientSec return err } -func findPublicKey(t *jwt.Token, set *jose.JSONWebKeySet, expectsRSAKey bool) (interface{}, error) { - keys := set.Keys - if len(keys) == 0 { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The retrieved JSON Web Key Set does not contain any key.")) - } - - kid, ok := t.Header["kid"].(string) - if ok { - keys = set.Key(kid) - } - - if len(keys) == 0 { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The JSON Web Token uses signing key with kid '%s', which could not be found.", kid)) - } - - for _, key := range keys { - if key.Use != "sig" { - continue - } - if expectsRSAKey { - if k, ok := key.Key.(*rsa.PublicKey); ok { - return k, nil - } - } else { - if k, ok := key.Key.(*ecdsa.PublicKey); ok { - return k, nil - } - } - } - - if expectsRSAKey { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unable to find RSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid)) - } else { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unable to find ECDSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid)) - } -} - func clientCredentialsFromRequest(r *http.Request, form url.Values) (clientID, clientSecret string, err error) { if id, secret, ok := r.BasicAuth(); !ok { return clientCredentialsFromRequestBody(form, true) diff --git a/client_authentication_test.go b/client_authentication_test.go index c93073ddc..8dabd83e9 100644 --- a/client_authentication_test.go +++ b/client_authentication_test.go @@ -31,6 +31,31 @@ import ( "github.com/ory/fosite/storage" ) +func encryptAssertionWithRSAKey(t *testing.T, token string, pubKey *rsa.PublicKey) string { + eo := &jose.EncrypterOptions{} + eo = eo.WithContentType("JWT").WithType("JWT") + enc, err := jose.NewEncrypter( + jose.ContentEncryption("A256GCM"), + jose.Recipient{ + Algorithm: jose.KeyAlgorithm("RSA-OAEP"), + Key: pubKey, + KeyID: "enc_key", + }, + eo) + + require.NoError(t, err, "unable to build encrypter; err=%v", err) + + // Encrypt the token + o, err := enc.Encrypt([]byte(token)) + require.NoError(t, err, "encrypting the token failed. err=%v", err) + + // Serialize the encrypted token + token, err = o.CompactSerialize() + require.NoError(t, err, "serializing the encrypted token failed. err=%v", err) + + return token +} + func mustGenerateRSAAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string { token := jwt.NewWithClaims(jose.RS256, claims) token.Header["kid"] = kid @@ -75,14 +100,22 @@ func TestAuthenticateClient(t *testing.T) { const at = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" hasher := &BCrypt{Config: &Config{HashCost: 6}} + encKey := gen.MustRSAKey() + + config := &Config{ + JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), + ClientSecretsHasher: hasher, + TokenURL: "token-url", + HTTPClient: retryablehttp.NewClient(), + JWTStrategy: jwt.NewDefaultStrategy( + func(ctx context.Context, context *jwt.KeyContext) (interface{}, error) { + return encKey, nil + }), + } + f := &Fosite{ - Store: storage.NewMemoryStore(), - Config: &Config{ - JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(), - ClientSecretsHasher: hasher, - TokenURL: "token-url", - HTTPClient: retryablehttp.NewClient(), - }, + Store: storage.NewMemoryStore(), + Config: config, } barSecret, err := hasher.Hash(context.TODO(), []byte("bar")) @@ -300,6 +333,26 @@ func TestAuthenticateClient(t *testing.T) { }, rsaKey, "kid-foo")}, "client_assertion_type": []string{at}}, r: new(http.Request), }, + { + d: "should pass with proper encrypted RSA assertion when JWKs are set within the client and client_id is set in the request", + client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: rsaJwks, TokenEndpointAuthMethod: "private_key_jwt"}, + form: url.Values{ + "client_id": []string{"bar"}, + "client_assertion": { + encryptAssertionWithRSAKey(t, + mustGenerateRSAAssertion(t, jwt.MapClaims{ + "sub": "bar", + "exp": time.Now().Add(time.Hour).Unix(), + "iss": "bar", + "jti": "12345", + "aud": "token-url", + }, rsaKey, "kid-foo"), + &encKey.PublicKey), + }, + "client_assertion_type": []string{at}, + }, + r: new(http.Request), + }, { d: "should pass with proper ECDSA assertion when JWKs are set within the client and client_id is not set in the request", client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: ecdsaJwks, TokenEndpointAuthMethod: "private_key_jwt", TokenEndpointAuthSigningAlgorithm: "ES256"}, diff --git a/config.go b/config.go index 1b50eb70c..aed2374a1 100644 --- a/config.go +++ b/config.go @@ -300,3 +300,9 @@ type PushedAuthorizeRequestConfigProvider interface { // must contain the PAR request_uri. EnforcePushedAuthorize(ctx context.Context) bool } + +// JWTStrategyProvider returns the provider for configuring the JWT strategy. +type JWTStrategyProvider interface { + // GetJWTStrategy returns the JWT strategy. + GetJWTStrategy(ctx context.Context) jwt.Strategy +} diff --git a/config_default.go b/config_default.go index 7f2e2487e..8a3bd7aa1 100644 --- a/config_default.go +++ b/config_default.go @@ -62,6 +62,7 @@ var ( _ RevocationHandlersProvider = (*Config)(nil) _ PushedAuthorizeRequestHandlersProvider = (*Config)(nil) _ PushedAuthorizeRequestConfigProvider = (*Config)(nil) + _ JWTStrategyProvider = (*Config)(nil) ) type Config struct { @@ -212,6 +213,9 @@ type Config struct { // IsPushedAuthorizeEnforced enforces pushed authorization request for /authorize IsPushedAuthorizeEnforced bool + + // JWTStrategy is used to provide additional JWT encrypt/decrypt/sign/verify capabilities + JWTStrategy jwt.Strategy } func (c *Config) GetGlobalSecret(ctx context.Context) ([]byte, error) { @@ -488,3 +492,8 @@ func (c *Config) GetPushedAuthorizeContextLifespan(ctx context.Context) time.Dur func (c *Config) EnforcePushedAuthorize(ctx context.Context) bool { return c.IsPushedAuthorizeEnforced } + +// GetJWTStrategy returns the JWT strategy. +func (c *Config) GetJWTStrategy(ctx context.Context) jwt.Strategy { + return c.JWTStrategy +} diff --git a/context.go b/context.go index d8b2bc3fd..3d72a3a73 100644 --- a/context.go +++ b/context.go @@ -19,4 +19,7 @@ const ( AuthorizeResponseContextKey = ContextKey("authorizeResponse") // PushedAuthorizeResponseContextKey is the response context PushedAuthorizeResponseContextKey = ContextKey("pushedAuthorizeResponse") + + AssertionTypeContextKey = ContextKey("assertionType") + BaseErrorContextKey = ContextKey("baseError") ) diff --git a/fosite_jwt.go b/fosite_jwt.go new file mode 100644 index 000000000..cd0526cf3 --- /dev/null +++ b/fosite_jwt.go @@ -0,0 +1,355 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package fosite + +import ( + "context" + "crypto/ecdsa" + "crypto/rsa" + "encoding/json" + "fmt" + "reflect" + "strings" + + "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3/jwt" + fjwt "github.com/ory/fosite/token/jwt" + "github.com/ory/x/errorsx" +) + +// JWTValidationConfig provides configuration to validate JWTs +type JWTValidationConfig struct { + // AllowedSigningKeys are the key IDs that are allowed to verify a JWT + AllowedSigningKeys []string `json:"kids"` + + // AllowedSigningAlgs are the algorithms allowed for a signed JWT + AllowedSigningAlgs []string `json:"algs"` + + // JSONWebKeysURI is the remote URI from which the JWKS is fetched + JSONWebKeysURI string `json:"jwks_uri"` + + // JSONWebKeys in place of JSONWebKeyURI + JSONWebKeys *jose.JSONWebKeySet `json:"jwks"` + + // NoneAlgAllowed indicates if the signing algorithm can be "none" + NoneAlgAllowed bool `json:"none"` +} + +// ValidateParsedAssertionWithClient validates the parsed assertion based on the jwks_uri, jwks etc. configured on the client +func ValidateParsedAssertionWithClient(ctx context.Context, config Configurator, assertion string, token *fjwt.Token, parsedToken *jwt.JSONWebToken, oidcClient OpenIDConnectClient, isNoneAlgAllowed bool) ( + *fjwt.Token, *jwt.JSONWebToken, error) { + + jwksURI := oidcClient.GetJSONWebKeysURI() + jwks := oidcClient.GetJSONWebKeys() + allowedKeys := []string{} + if c, ok := oidcClient.(ClientWithAllowedVerificationKeys); ok { + allowedKeys = c.AllowedVerificationKeys() + } + + return ValidateParsedAssertion(ctx, config, assertion, token, parsedToken, &JWTValidationConfig{ + AllowedSigningKeys: allowedKeys, + AllowedSigningAlgs: []string{oidcClient.GetTokenEndpointAuthSigningAlgorithm()}, + JSONWebKeysURI: jwksURI, + JSONWebKeys: jwks, + NoneAlgAllowed: isNoneAlgAllowed, + }) +} + +// ValidateParsedAssertion validates the parsed assertion based on the jwks_uri, jwks etc. that is passed in +func ValidateParsedAssertion(ctx context.Context, config Configurator, assertion string, token *fjwt.Token, parsedToken *jwt.JSONWebToken, verificationConfig *JWTValidationConfig) ( + *fjwt.Token, *jwt.JSONWebToken, error) { + + var err error + baseError := getBaseError(ctx) + assertionType := getAssertionType(ctx) + + var jwtStrategy fjwt.Strategy + if c, ok := config.(JWTStrategyProvider); ok { + jwtStrategy = c.GetJWTStrategy(ctx) + } + + if jwtStrategy != nil && len(token.Method) == 0 { // JWE + alg, _ := token.Header["alg"].(string) + enc, _ := token.Header["enc"].(string) + assertion, err = jwtStrategy.DecryptWithSettings(ctx, + &fjwt.KeyContext{ + EncryptionKeyID: parsedToken.Headers[0].KeyID, + EncryptionAlgorithm: alg, + EncryptionContentAlgorithm: enc, + }, + assertion) + if err != nil { + return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + + var mapClaims fjwt.MapClaims = fjwt.MapClaims{} + + if cty, ok := token.Header["cty"].(string); ok && strings.ToUpper(cty) == "JWT" { // Nested JWT + + parsedToken, err = jwt.ParseSigned(assertion) + if err != nil { + return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + + if err := parsedToken.UnsafeClaimsWithoutVerification(&mapClaims); err != nil { + return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + token.Claims = mapClaims + token.Method = jose.SignatureAlgorithm(parsedToken.Headers[0].Algorithm) + token.Header["kid"] = parsedToken.Headers[0].KeyID // When using jwks, the `kid` is read from token object + + } else { // Only encrypted, not signed + if err := json.Unmarshal([]byte(assertion), &mapClaims); err != nil { + return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + token.Claims = mapClaims + err = validateJWTClaims(ctx, mapClaims, assertionType, baseError) + if err != nil { + return nil, nil, err + } + + return token, parsedToken, nil + } + } + + if token.Method == fjwt.SigningMethodNone { + if !verificationConfig.NoneAlgAllowed { + return nil, nil, errorsx.WithStack(baseError.WithHintf("'none' is disallowed as a signing method of the '%s'.", assertionType)) + } + + return token, parsedToken, nil + } + + claims := token.Claims + signingAlg := parsedToken.Headers[0].Algorithm + if len(verificationConfig.AllowedSigningAlgs) > 0 && !Arguments(verificationConfig.AllowedSigningAlgs).Has(signingAlg) { + return nil, nil, errorsx.WithStack(baseError.WithHintf("The 'alg' used in the '%s' is not allowed.", assertionType)) + } + + signingKey := parsedToken.Headers[0].KeyID + if len(verificationConfig.AllowedSigningKeys) > 0 && !Arguments(verificationConfig.AllowedSigningKeys).Has(signingKey) { + return nil, nil, errorsx.WithStack(baseError.WithHintf("The 'kid' used in the '%s' is not allowed.", assertionType)) + } + + // Validate signature + if verificationConfig.JSONWebKeysURI == "" && verificationConfig.JSONWebKeys == nil { + if jwtStrategy == nil { + return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + + _, err := jwtStrategy.ValidateWithSettings(ctx, + &fjwt.KeyContext{ + SigningKeyID: parsedToken.Headers[0].KeyID, + SigningAlgorithm: parsedToken.Headers[0].Algorithm, + }, + assertion) + if err != nil { + return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + } else { + var key interface{} + var err error + switch token.Method { + case jose.RS256, jose.RS384, jose.RS512: + key, err = findPublicJWK(ctx, config, token, verificationConfig.JSONWebKeysURI, verificationConfig.JSONWebKeys, true, baseError) + if err != nil { + return nil, nil, wrapSigningKeyFailure( + baseError.WithHint("Unable to retrieve RSA signing key from the JSON Web Key Set."), err) + } + case jose.ES256, jose.ES384, jose.ES512: + key, err = findPublicJWK(ctx, config, token, verificationConfig.JSONWebKeysURI, verificationConfig.JSONWebKeys, false, baseError) + if err != nil { + return nil, nil, wrapSigningKeyFailure( + baseError.WithHint("Unable to retrieve ECDSA signing key from the JSON Web Key Set."), err) + } + case jose.PS256, jose.PS384, jose.PS512: + key, err = findPublicJWK(ctx, config, token, verificationConfig.JSONWebKeysURI, verificationConfig.JSONWebKeys, true, baseError) + if err != nil { + return nil, nil, wrapSigningKeyFailure( + baseError.WithHint("Unable to retrieve RSA signing key from the JSON Web Key Set."), err) + } + default: + return nil, nil, errorsx.WithStack(baseError.WithHintf("The '%s' uses unsupported signing algorithm '%s'.", assertionType, token.Method)) + } + + // To verify signature go-jose requires a pointer to + // public key instead of the public key value. + // The pointer values provides that pointer. + // E.g. transform rsa.PublicKey -> *rsa.PublicKey + key = pointer(key) + + // verify signature with returned key + if err := parsedToken.Claims(key, &claims); err != nil { + return nil, nil, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + } + + err = validateJWTClaims(ctx, claims, assertionType, baseError) + if err != nil { + return nil, nil, err + } + + return token, parsedToken, nil +} + +func validateJWTClaims(ctx context.Context, claims fjwt.MapClaims, assertionType string, baseError *RFC6749Error) error { + // Validate claims + // This validation is performed to be backwards compatible + // with jwt-go library behavior + if err := claims.Valid(); err != nil { + if e, ok := err.(*fjwt.ValidationError); ok { + // return a more precise error + if e.Has(fjwt.ValidationErrorExpired) { + return errorsx.WithStack(baseError.WithHintf("The '%s' has expired.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + + if e.Has(fjwt.ValidationErrorIssuedAt) { + return errorsx.WithStack(baseError.WithHintf("The 'iat' claim in '%s' is in the future.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + + if e.Has(fjwt.ValidationErrorNotValidYet) { + return errorsx.WithStack(baseError.WithHintf("The '%s' is not valid yet.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + } + + return errorsx.WithStack(baseError.WithHintf("Invalid claims in the '%s'.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + + return nil +} + +func newToken(assertion string, assertionType string, baseError *RFC6749Error) (*fjwt.Token, *jwt.JSONWebToken, bool, error) { + var err error + var parsedToken *jwt.JSONWebToken + + isJWE := false // assume it's signed + parsedToken, err = jwt.ParseSigned(assertion) + if err != nil { + parsedToken, err = jwt.ParseEncrypted(assertion) // probably it's encrypted + if err != nil { + return nil, nil, false, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + + isJWE = true + } + + token := &fjwt.Token{ + Header: map[string]interface{}{}, + Method: "", + } + + if !isJWE { + var claims fjwt.MapClaims = fjwt.MapClaims{} + if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil { + return nil, nil, false, errorsx.WithStack(baseError.WithHintf("Unable to verify the integrity of the '%s' value.", assertionType).WithWrap(err).WithDebug(err.Error())) + } + token.Claims = claims + } + + if len(parsedToken.Headers) != 1 { + return nil, nil, false, errorsx.WithStack(baseError.WithHintf("The '%s' value is expected to contain only one header.", assertionType)) + } + + // copy headers + h := parsedToken.Headers[0] + token.Header["alg"] = h.Algorithm + if h.KeyID != "" { + token.Header["kid"] = h.KeyID + } + for k, v := range h.ExtraHeaders { + token.Header[string(k)] = v + } + + if !isJWE { + token.Method = jose.SignatureAlgorithm(h.Algorithm) + } + + return token, parsedToken, isJWE, nil +} + +func findPublicKey(t *fjwt.Token, set *jose.JSONWebKeySet, expectsRSAKey bool, baseError *RFC6749Error) (interface{}, error) { + keys := set.Keys + if len(keys) == 0 { + return nil, errorsx.WithStack(baseError.WithHint("The retrieved JSON Web Key Set does not contain any keys")) + } + + kid, ok := t.Header["kid"].(string) + if ok { + keys = set.Key(kid) + } + + if len(keys) == 0 { + return nil, errorsx.WithStack(baseError.WithHintf("The JSON Web Token uses signing key with kid '%s', which could not be found.", kid)) + } + + for _, key := range keys { + if key.Use != "sig" { + continue + } + if expectsRSAKey { + if k, ok := key.Key.(*rsa.PublicKey); ok { + return k, nil + } + } else { + if k, ok := key.Key.(*ecdsa.PublicKey); ok { + return k, nil + } + } + } + + if expectsRSAKey { + return nil, errorsx.WithStack(baseError.WithHintf("Unable to find RSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid)) + } + + return nil, errorsx.WithStack(baseError.WithHintf("Unable to find ECDSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid)) +} + +func findPublicJWK(ctx context.Context, config Configurator, t *fjwt.Token, jwksURI string, jwks *jose.JSONWebKeySet, expectsRSAKey bool, baseError *RFC6749Error) (interface{}, error) { + if jwks != nil { + return findPublicKey(t, jwks, expectsRSAKey, baseError) + } + + keys, err := config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, jwksURI, false) + if err != nil { + return nil, err + } + + if key, err := findPublicKey(t, keys, expectsRSAKey, baseError); err == nil { + return key, nil + } + + keys, err = config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, jwksURI, true) + if err != nil { + return nil, errorsx.WithStack(baseError.WithHintf(fmt.Sprintf("%s", err))) + } + + return findPublicKey(t, keys, expectsRSAKey, baseError) +} + +// if underline value of v is not a pointer +// it creates a pointer of it and returns it +func pointer(v interface{}) interface{} { + if reflect.ValueOf(v).Kind() != reflect.Ptr { + value := reflect.New(reflect.ValueOf(v).Type()) + value.Elem().Set(reflect.ValueOf(v)) + return value.Interface() + } + return v +} + +func getBaseError(ctx context.Context) *RFC6749Error { + if e, ok := ctx.Value(BaseErrorContextKey).(*RFC6749Error); ok { + return e + } + + return ErrInvalidClient +} + +func getAssertionType(ctx context.Context) string { + if at, ok := ctx.Value(AssertionTypeContextKey).(string); ok { + return at + } + + return "assertion" +} diff --git a/token/jwt/strategy_jwt.go b/token/jwt/strategy_jwt.go new file mode 100644 index 000000000..070131abb --- /dev/null +++ b/token/jwt/strategy_jwt.go @@ -0,0 +1,180 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package jwt + +import ( + "context" + "crypto/ecdsa" + "crypto/rsa" + "fmt" + + "github.com/go-jose/go-jose/v3" +) + +// KeyContext contains context that is used to sign, validation, encrypt and decrypt tokens. +// It is populated in different ways depending on the operation. For example - +// +// 1. Validate : the SigningKeyID and SigningAlgorithm is based on the JWT header of the incoming token +// 2. Decrypt : the EncryptionKeyID, EncryptionAlgorithm and EncryptionContentAlgorithm is based on the JWT header of the incoming token +// 3. Generate : all the properties may be populated. The JWT strategy implementation may sign the token, then optionally encrypt it +type KeyContext struct { + SigningKeyID string + SigningAlgorithm string + EncryptionKeyID string + EncryptionAlgorithm string + EncryptionContentAlgorithm string + Extra map[string]interface{} +} + +// Strategy provides the overall strategy interface to sign (generate), encrypt (part of generate), decrypt and validate JWTs. +type Strategy interface { + Signer + + // GenerateWithSettings signs and optionally encrypts the token based on the context provided + GenerateWithSettings(ctx context.Context, settings *KeyContext, claims MapClaims, header Mapper) (string, string, error) + + // DecryptWithSettings decrypts the token provided. If the token is not encrypted, the function should return an error. + DecryptWithSettings(ctx context.Context, settings *KeyContext, token string) (string, error) + + // ValidateWithSettings validates the signed token. If the token is not signed, the function should return an error. + ValidateWithSettings(ctx context.Context, settings *KeyContext, token string) (string, error) +} + +type GetPrivateKeyWithContextFunc func(ctx context.Context, context *KeyContext) (interface{}, error) + +// DefaultStrategy is responsible for generating (signing and optionally encrypting), decrypting and validating JWT challenges +type DefaultStrategy struct { + *DefaultSigner + GetPrivateKey GetPrivateKeyWithContextFunc +} + +func NewDefaultStrategy(GetPrivateKey GetPrivateKeyWithContextFunc) Strategy { + return &DefaultStrategy{ + DefaultSigner: &DefaultSigner{ + GetPrivateKey: func(ctx context.Context) (interface{}, error) { + return GetPrivateKey(ctx, nil) + }, + }, + GetPrivateKey: GetPrivateKey, + } +} + +// GenerateWithSettings signs and optionally encrypts the token based on the context provided +func (s *DefaultStrategy) GenerateWithSettings(ctx context.Context, settings *KeyContext, claims MapClaims, header Mapper) (string, string, error) { + // ignoring the signing alg and kid for this implementation and just using the DefaultSigner implementation + rawToken, sig, err := s.DefaultSigner.Generate(ctx, claims, header) + if err != nil { + return "", "", err + } + + if settings.EncryptionAlgorithm == "" { + return rawToken, sig, err + } + + key, err := s.GetPrivateKey(ctx, settings) + if err != nil { + return "", "", err + } + + if t, ok := key.(*jose.JSONWebKey); ok { + key = t.Key + } + + var pubKey interface{} + switch t := key.(type) { + case *rsa.PrivateKey: + pubKey = &t.PublicKey + case *ecdsa.PrivateKey: + pubKey = &t.PublicKey + case jose.OpaqueSigner: + pubKey = t.Public() + default: + return "", "", fmt.Errorf("unable to decode token. Invalid PrivateKey type %T", key) + } + + eo := &jose.EncrypterOptions{} + eo = eo.WithContentType("JWT").WithType("JWT") + enc, err := jose.NewEncrypter( + jose.ContentEncryption(settings.EncryptionContentAlgorithm), + jose.Recipient{ + Algorithm: jose.KeyAlgorithm(settings.EncryptionAlgorithm), + Key: pubKey, + KeyID: settings.EncryptionKeyID, + }, + eo) + + if err != nil { + return "", "", fmt.Errorf("unable to build encrypter; err=%v", err) + } + + // Encrypt the token + o, err := enc.Encrypt([]byte(rawToken)) + if err != nil { + return "", "", fmt.Errorf("encrypting the token failed. err=%v", err) + } + + // Serialize the encrypted token + rawToken, err = o.CompactSerialize() + if err != nil { + return "", "", fmt.Errorf("serializing the encrypted token failed. err=%v", err) + } + + return rawToken, sig, err +} + +// DecryptWithSettings decrypts the token provided. If the token is not encrypted, the function should return an error. +func (s *DefaultStrategy) DecryptWithSettings(ctx context.Context, settings *KeyContext, token string) (string, error) { + + parsedToken, err := jose.ParseEncrypted(token) + if err != nil { + return "", fmt.Errorf("unable to parse the token") + } + + if settings == nil { + h := parsedToken.Header + enc, _ := h.ExtraHeaders[jose.HeaderKey("enc")].(string) + settings = &KeyContext{ + EncryptionKeyID: h.KeyID, + EncryptionAlgorithm: h.Algorithm, + EncryptionContentAlgorithm: enc, + } + } + + key, err := s.GetPrivateKey(ctx, settings) + var privateKey interface{} + switch t := key.(type) { + case *jose.JSONWebKey: + privateKey = t.Key + case jose.JSONWebKey: + privateKey = t.Key + case *rsa.PrivateKey: + privateKey = t + case *ecdsa.PrivateKey: + privateKey = t + case jose.OpaqueSigner: + switch tt := t.Public().Key.(type) { + case *rsa.PrivateKey: + privateKey = t + case *ecdsa.PrivateKey: + privateKey = t + default: + return "", fmt.Errorf("unsupported private / public key pairs: %T, %T", t, tt) + } + default: + return "", fmt.Errorf("unsupported private key type: %T", t) + } + + decrypted, err := parsedToken.Decrypt(privateKey) + if err != nil { + return "", err + } + + return string(decrypted), nil +} + +// ValidateWithSettings validates the signed token. If the token is not signed, the function should return an error. +func (s *DefaultStrategy) ValidateWithSettings(ctx context.Context, settings *KeyContext, token string) (string, error) { + // ignoring the signing alg and kid for this implementation and just using the DefaultSigner implementation + return s.DefaultSigner.Validate(ctx, token) +} diff --git a/token/jwt/strategy_jwt_test.go b/token/jwt/strategy_jwt_test.go new file mode 100644 index 000000000..395fd3678 --- /dev/null +++ b/token/jwt/strategy_jwt_test.go @@ -0,0 +1,87 @@ +// Copyright © 2023 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package jwt + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/ory/fosite/internal/gen" + "github.com/stretchr/testify/require" +) + +func TestEncryptJWT(t *testing.T) { + key := gen.MustRSAKey() + encryptKey := gen.MustRSAKey() + for k, tc := range []struct { + d string + keyContext *KeyContext + strategy Strategy + resetKey func(strategy Strategy) + }{ + { + d: "SameKeyStrategy", + keyContext: &KeyContext{ + EncryptionAlgorithm: "RSA-OAEP", + EncryptionContentAlgorithm: "A256GCM", + EncryptionKeyID: "samekey", + }, + strategy: NewDefaultStrategy(func(_ context.Context, context *KeyContext) (interface{}, error) { + return key, nil + }), + resetKey: func(strategy Strategy) { + key = gen.MustRSAKey() + }, + }, + { + d: "EncryptionKeyStrategy", + keyContext: &KeyContext{ + EncryptionAlgorithm: "RSA-OAEP", + EncryptionContentAlgorithm: "A256GCM", + EncryptionKeyID: "enc_key", + }, + strategy: NewDefaultStrategy(func(_ context.Context, context *KeyContext) (interface{}, error) { + if context == nil { + return key, nil + } + + if context.EncryptionKeyID == "enc_key" { + return encryptKey, nil + } + + return key, nil + }), + resetKey: func(strategy Strategy) { + key = gen.MustRSAKey() + encryptKey = gen.MustRSAKey() + }, + }, + } { + t.Run(fmt.Sprintf("case=%d/strategy=%s", k, tc.d), func(t *testing.T) { + ctx := context.Background() + + // Reset private key + tc.resetKey(tc.strategy) + + claims := &JWTClaims{ + ExpiresAt: time.Now().UTC().Add(time.Hour), + } + + token, sig, err := tc.strategy.GenerateWithSettings(ctx, tc.keyContext, claims.ToMapClaims(), header) + require.NoError(t, err) + require.NotNil(t, token, "Token could not be generated") + + signedToken, err := tc.strategy.DecryptWithSettings(ctx, tc.keyContext, token) + require.NoError(t, err) + require.NotNil(t, signedToken, "Token could not be decrypted; token=%s", token) + + derivedSig, err := tc.strategy.Validate(ctx, signedToken) + require.NoError(t, err) + + require.EqualValues(t, sig, derivedSig, "Signature does not match") + }) + } +}