Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [#631] JWT Encryption support for client authentication and ID Token generation #764

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ type OpenIDConnectClient interface {
GetTokenEndpointAuthSigningAlgorithm() string
}

// ClientWithAllowedVerificationKeys adds a security control to the client configuration to only allow
// specific verification keys. This ensures that a key that is valid for client X can't be used for client Y
// unless allowed. This becomes especially important for cases where the clients are controlled by third-parties
// and are issued specific keys from a central organization, which may be the OP's org or a central regulatory authority,
// and the security controls of the clients cannot be guaranteed.
type ClientWithAllowedVerificationKeys interface {
// AllowedVerificationKeys provides a list of key IDs that can be used in the JWT
// header for private_key_jwt authentication and for JWT bearer grant flow
AllowedVerificationKeys() []string
}

// ResponseModeClient represents a client capable of handling response_mode
type ResponseModeClient interface {
// GetResponseMode returns the response modes that client is allowed to send
Expand Down
186 changes: 57 additions & 129 deletions client_authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,13 @@ package fosite

import (
"context"
"crypto/ecdsa"
"crypto/rsa"
"encoding/json"
"fmt"
"net/http"
"net/url"
"time"

"github.com/ory/x/errorsx"

"github.com/go-jose/go-jose/v3"
"github.com/pkg/errors"

"github.com/ory/fosite/token/jwt"
)

Expand All @@ -28,33 +22,15 @@ type ClientAuthenticationStrategy func(context.Context, *http.Request, url.Value
const clientAssertionJWTBearerType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"

func (f *Fosite) findClientPublicJWK(ctx context.Context, oidcClient OpenIDConnectClient, t *jwt.Token, expectsRSAKey bool) (interface{}, error) {
if set := oidcClient.GetJSONWebKeys(); set != nil {
return findPublicKey(t, set, expectsRSAKey)
}

if location := oidcClient.GetJSONWebKeysURI(); len(location) > 0 {
keys, err := f.Config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, location, false)
if err != nil {
return nil, err
}

if key, err := findPublicKey(t, keys, expectsRSAKey); err == nil {
return key, nil
}

keys, err = f.Config.GetJWKSFetcherStrategy(ctx).Resolve(ctx, location, true)
if err != nil {
return nil, err
}

return findPublicKey(t, keys, expectsRSAKey)
if oidcClient.GetJSONWebKeys() == nil && oidcClient.GetJSONWebKeysURI() == "" {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The OAuth 2.0 Client has no JSON Web Keys set registered, but they are needed to complete the request."))
}

return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The OAuth 2.0 Client has no JSON Web Keys set registered, but they are needed to complete the request."))
return findPublicJWK(ctx, f.Config, t, oidcClient.GetJSONWebKeysURI(), oidcClient.GetJSONWebKeys(), expectsRSAKey, ErrInvalidClient)
}

// AuthenticateClient authenticates client requests using the configured strategy
// `Fosite.ClientAuthenticationStrategy`, if nil it uses `Fosite.DefaultClientAuthenticationStrategy`
// `ClientAuthenticationStrategy`, if nil it uses `DefaultClientAuthenticationStrategy`
func (f *Fosite) AuthenticateClient(ctx context.Context, r *http.Request, form url.Values) (Client, error) {
if s := f.Config.GetClientAuthenticationStrategy(ctx); s != nil {
return s(ctx, r, form)
Expand All @@ -71,81 +47,70 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The client_assertion request parameter must be set when using client_assertion_type of '%s'.", clientAssertionJWTBearerType))
}

var clientID string
var client Client

token, err := jwt.ParseWithClaims(assertion, jwt.MapClaims{}, func(t *jwt.Token) (interface{}, error) {
var err error
clientID, _, err = clientCredentialsFromRequestBody(form, false)
if err != nil {
return nil, err
}
// Parse the assertion
token, parsedToken, isJWE, err := newToken(assertion, "client_assertion", ErrInvalidClient)
if err != nil {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to parse the client_assertion").WithWrap(err).WithDebug(err.Error()))
}

if clientID == "" {
claims := t.Claims
if sub, ok := claims["sub"].(string); !ok {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The claim 'sub' from the client_assertion JSON Web Token is undefined."))
} else {
clientID = sub
}
}
claims := token.Claims

client, err = f.Store.GetClient(ctx, clientID)
if err != nil {
return nil, errorsx.WithStack(ErrInvalidClient.WithWrap(err).WithDebug(err.Error()))
}
// Validate client
clientID, _, err := clientCredentialsFromRequestBody(form, false)
if err != nil {
return nil, err
}

oidcClient, ok := client.(OpenIDConnectClient)
if !ok {
return nil, errorsx.WithStack(ErrInvalidRequest.WithHint("The server configuration does not support OpenID Connect specific authentication methods."))
if clientID == "" {
if isJWE {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The 'client_id' must be part of the request when encrypted client_assertion is used."))
}

switch oidcClient.GetTokenEndpointAuthMethod() {
case "private_key_jwt":
break
case "none":
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("This requested OAuth 2.0 client does not support client authentication, however 'client_assertion' was provided in the request."))
case "client_secret_post":
fallthrough
case "client_secret_basic":
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("This requested OAuth 2.0 client only supports client authentication method '%s', however 'client_assertion' was provided in the request.", oidcClient.GetTokenEndpointAuthMethod()))
case "client_secret_jwt":
fallthrough
default:
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("This requested OAuth 2.0 client only supports client authentication method '%s', however that method is not supported by this server.", oidcClient.GetTokenEndpointAuthMethod()))
if sub, ok := claims["sub"].(string); !ok {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The claim 'sub' from the client_assertion JSON Web Token is undefined."))
} else {
clientID = sub
}
}

if oidcClient.GetTokenEndpointAuthSigningAlgorithm() != fmt.Sprintf("%s", t.Header["alg"]) {
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The 'client_assertion' uses signing algorithm '%s' but the requested OAuth 2.0 Client enforces signing algorithm '%s'.", t.Header["alg"], oidcClient.GetTokenEndpointAuthSigningAlgorithm()))
}
switch t.Method {
case jose.RS256, jose.RS384, jose.RS512:
return f.findClientPublicJWK(ctx, oidcClient, t, true)
case jose.ES256, jose.ES384, jose.ES512:
return f.findClientPublicJWK(ctx, oidcClient, t, false)
case jose.PS256, jose.PS384, jose.PS512:
return f.findClientPublicJWK(ctx, oidcClient, t, true)
case jose.HS256, jose.HS384, jose.HS512:
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("This authorization server does not support client authentication method 'client_secret_jwt'."))
default:
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The 'client_assertion' request parameter uses unsupported signing algorithm '%s'.", t.Header["alg"]))
}
})
client, err := f.Store.GetClient(ctx, clientID)
if err != nil {
// Do not re-process already enhanced errors
var e *jwt.ValidationError
if errors.As(err, &e) {
if e.Inner != nil {
return nil, e.Inner
}
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to verify the integrity of the 'client_assertion' value.").WithWrap(err).WithDebug(err.Error()))
}
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The requested OAuth 2.0 Client could not be authenticated.").WithWrap(err).WithDebug(err.Error()))
}

oidcClient, ok := client.(OpenIDConnectClient)
if !ok {
return nil, errorsx.WithStack(ErrInvalidRequest.WithHint("The server configuration does not support OpenID Connect specific authentication methods."))
}

switch oidcClient.GetTokenEndpointAuthMethod() {
case "private_key_jwt":
break
case "none":
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("This requested OAuth 2.0 client does not support client authentication, however 'client_assertion' was provided in the request."))
case "client_secret_post":
fallthrough
case "client_secret_basic":
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("This requested OAuth 2.0 client only supports client authentication method '%s', however 'client_assertion' was provided in the request.", oidcClient.GetTokenEndpointAuthMethod()))
case "client_secret_jwt":
fallthrough
default:
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("This requested OAuth 2.0 client only supports client authentication method '%s', however that method is not supported by this server.", oidcClient.GetTokenEndpointAuthMethod()))
}

// Validate signature
if !isJWE && oidcClient.GetTokenEndpointAuthSigningAlgorithm() != "" && oidcClient.GetTokenEndpointAuthSigningAlgorithm() != parsedToken.Headers[0].Algorithm {
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The client_assertion uses signing algorithm '%s', but the requested OAuth 2.0 Client enforces signing algorithm '%s'.", parsedToken.Headers[0].Algorithm, oidcClient.GetTokenEndpointAuthSigningAlgorithm()))
}

ctx = context.WithValue(ctx, AssertionTypeContextKey, "client_assertion")
ctx = context.WithValue(ctx, BaseErrorContextKey, ErrInvalidClient)
if token, parsedToken, err = ValidateParsedAssertionWithClient(ctx, f.Config, assertion, token, parsedToken, oidcClient, false); err != nil {
return nil, err
} else if err := token.Claims.Valid(); err != nil {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to verify the request object because its claims could not be validated, check if the expiry time is set correctly.").WithWrap(err).WithDebug(err.Error()))
}

claims := token.Claims
claims = token.Claims

var jti string
if !claims.VerifyIssuer(clientID, true) {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client."))
Expand Down Expand Up @@ -255,43 +220,6 @@ func (f *Fosite) checkClientSecret(ctx context.Context, client Client, clientSec
return err
}

func findPublicKey(t *jwt.Token, set *jose.JSONWebKeySet, expectsRSAKey bool) (interface{}, error) {
keys := set.Keys
if len(keys) == 0 {
return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The retrieved JSON Web Key Set does not contain any key."))
}

kid, ok := t.Header["kid"].(string)
if ok {
keys = set.Key(kid)
}

if len(keys) == 0 {
return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("The JSON Web Token uses signing key with kid '%s', which could not be found.", kid))
}

for _, key := range keys {
if key.Use != "sig" {
continue
}
if expectsRSAKey {
if k, ok := key.Key.(*rsa.PublicKey); ok {
return k, nil
}
} else {
if k, ok := key.Key.(*ecdsa.PublicKey); ok {
return k, nil
}
}
}

if expectsRSAKey {
return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unable to find RSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid))
} else {
return nil, errorsx.WithStack(ErrInvalidRequest.WithHintf("Unable to find ECDSA public key with use='sig' for kid '%s' in JSON Web Key Set.", kid))
}
}

func clientCredentialsFromRequest(r *http.Request, form url.Values) (clientID, clientSecret string, err error) {
if id, secret, ok := r.BasicAuth(); !ok {
return clientCredentialsFromRequestBody(form, true)
Expand Down
67 changes: 60 additions & 7 deletions client_authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,31 @@ import (
"github.com/ory/fosite/storage"
)

func encryptAssertionWithRSAKey(t *testing.T, token string, pubKey *rsa.PublicKey) string {
eo := &jose.EncrypterOptions{}
eo = eo.WithContentType("JWT").WithType("JWT")
enc, err := jose.NewEncrypter(
jose.ContentEncryption("A256GCM"),
jose.Recipient{
Algorithm: jose.KeyAlgorithm("RSA-OAEP"),
Key: pubKey,
KeyID: "enc_key",
},
eo)

require.NoError(t, err, "unable to build encrypter; err=%v", err)

// Encrypt the token
o, err := enc.Encrypt([]byte(token))
require.NoError(t, err, "encrypting the token failed. err=%v", err)

// Serialize the encrypted token
token, err = o.CompactSerialize()
require.NoError(t, err, "serializing the encrypted token failed. err=%v", err)

return token
}

func mustGenerateRSAAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string {
token := jwt.NewWithClaims(jose.RS256, claims)
token.Header["kid"] = kid
Expand Down Expand Up @@ -75,14 +100,22 @@ func TestAuthenticateClient(t *testing.T) {
const at = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"

hasher := &BCrypt{Config: &Config{HashCost: 6}}
encKey := gen.MustRSAKey()

config := &Config{
JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(),
ClientSecretsHasher: hasher,
TokenURL: "token-url",
HTTPClient: retryablehttp.NewClient(),
JWTStrategy: jwt.NewDefaultStrategy(
func(ctx context.Context, context *jwt.KeyContext) (interface{}, error) {
return encKey, nil
}),
}

f := &Fosite{
Store: storage.NewMemoryStore(),
Config: &Config{
JWKSFetcherStrategy: NewDefaultJWKSFetcherStrategy(),
ClientSecretsHasher: hasher,
TokenURL: "token-url",
HTTPClient: retryablehttp.NewClient(),
},
Store: storage.NewMemoryStore(),
Config: config,
}

barSecret, err := hasher.Hash(context.TODO(), []byte("bar"))
Expand Down Expand Up @@ -300,6 +333,26 @@ func TestAuthenticateClient(t *testing.T) {
}, rsaKey, "kid-foo")}, "client_assertion_type": []string{at}},
r: new(http.Request),
},
{
d: "should pass with proper encrypted RSA assertion when JWKs are set within the client and client_id is set in the request",
client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: rsaJwks, TokenEndpointAuthMethod: "private_key_jwt"},
form: url.Values{
"client_id": []string{"bar"},
"client_assertion": {
encryptAssertionWithRSAKey(t,
mustGenerateRSAAssertion(t, jwt.MapClaims{
"sub": "bar",
"exp": time.Now().Add(time.Hour).Unix(),
"iss": "bar",
"jti": "12345",
"aud": "token-url",
}, rsaKey, "kid-foo"),
&encKey.PublicKey),
},
"client_assertion_type": []string{at},
},
r: new(http.Request),
},
{
d: "should pass with proper ECDSA assertion when JWKs are set within the client and client_id is not set in the request",
client: &DefaultOpenIDConnectClient{DefaultClient: &DefaultClient{ID: "bar", Secret: barSecret}, JSONWebKeys: ecdsaJwks, TokenEndpointAuthMethod: "private_key_jwt", TokenEndpointAuthSigningAlgorithm: "ES256"},
Expand Down
6 changes: 6 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,9 @@ type PushedAuthorizeRequestConfigProvider interface {
// must contain the PAR request_uri.
EnforcePushedAuthorize(ctx context.Context) bool
}

// JWTStrategyProvider returns the provider for configuring the JWT strategy.
type JWTStrategyProvider interface {
// GetJWTStrategy returns the JWT strategy.
GetJWTStrategy(ctx context.Context) jwt.Strategy
}
9 changes: 9 additions & 0 deletions config_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ var (
_ RevocationHandlersProvider = (*Config)(nil)
_ PushedAuthorizeRequestHandlersProvider = (*Config)(nil)
_ PushedAuthorizeRequestConfigProvider = (*Config)(nil)
_ JWTStrategyProvider = (*Config)(nil)
)

type Config struct {
Expand Down Expand Up @@ -212,6 +213,9 @@ type Config struct {

// IsPushedAuthorizeEnforced enforces pushed authorization request for /authorize
IsPushedAuthorizeEnforced bool

// JWTStrategy is used to provide additional JWT encrypt/decrypt/sign/verify capabilities
JWTStrategy jwt.Strategy
}

func (c *Config) GetGlobalSecret(ctx context.Context) ([]byte, error) {
Expand Down Expand Up @@ -488,3 +492,8 @@ func (c *Config) GetPushedAuthorizeContextLifespan(ctx context.Context) time.Dur
func (c *Config) EnforcePushedAuthorize(ctx context.Context) bool {
return c.IsPushedAuthorizeEnforced
}

// GetJWTStrategy returns the JWT strategy.
func (c *Config) GetJWTStrategy(ctx context.Context) jwt.Strategy {
return c.JWTStrategy
}
3 changes: 3 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@ const (
AuthorizeResponseContextKey = ContextKey("authorizeResponse")
// PushedAuthorizeResponseContextKey is the response context
PushedAuthorizeResponseContextKey = ContextKey("pushedAuthorizeResponse")

AssertionTypeContextKey = ContextKey("assertionType")
BaseErrorContextKey = ContextKey("baseError")
)
Loading