Skip to content
This repository has been archived by the owner on Apr 22, 2024. It is now read-only.

Commit

Permalink
Proper lifecycle for the JWKS provider
Browse files Browse the repository at this point in the history
  • Loading branch information
nacx committed Feb 13, 2024
1 parent 85e88e9 commit 8c7ac7c
Show file tree
Hide file tree
Showing 14 changed files with 244 additions and 174 deletions.
5 changes: 4 additions & 1 deletion cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,16 @@ import (
"github.com/tetratelabs/telemetry"

"github.com/tetrateio/authservice-go/internal"
"github.com/tetrateio/authservice-go/internal/oidc"
"github.com/tetrateio/authservice-go/internal/server"
)

func main() {
var (
configFile = &internal.LocalConfigFile{}
logging = internal.NewLogSystem(log.New(), &configFile.Config)
envoyAuthz = server.NewExtAuthZFilter(&configFile.Config)
jwks = oidc.NewJWKSProvider()
envoyAuthz = server.NewExtAuthZFilter(&configFile.Config, jwks)
authzServer = server.New(&configFile.Config, envoyAuthz.Register)
)

Expand All @@ -49,6 +51,7 @@ func main() {
configFile, // load the configuration
logging, // set up the logging system
configLog, // log the configuration
jwks, // start the JWKS provider
authzServer, // start the server
&signal.Handler{}, // handle graceful termination
)
Expand Down
21 changes: 3 additions & 18 deletions internal/authz/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (

oidcv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/oidc"
"github.com/tetrateio/authservice-go/internal"
"github.com/tetrateio/authservice-go/internal/authz/oidc"
"github.com/tetrateio/authservice-go/internal/oidc"
)

var _ Handler = (*oidcHandler)(nil)
Expand All @@ -38,30 +38,15 @@ type oidcHandler struct {
}

// NewOIDCHandler creates a new OIDC implementation of the Handler interface.
func NewOIDCHandler(cfg *oidcv1.OIDCConfig) (Handler, error) {
func NewOIDCHandler(cfg *oidcv1.OIDCConfig, jwks oidc.JWKSProvider) (Handler, error) {
// TODO(nacx): Read the redis store config to configure the redi store
// TODO(nacx): Properly lifecycle the session store
store := oidc.NewMemoryStore(
oidc.Clock{},
time.Duration(cfg.AbsoluteSessionTimeout),
time.Duration(cfg.IdleSessionTimeout),
)

var (
jwks oidc.JWKSProvider
err error
)
if cfg.GetJwksFetcher() != nil {
jwks = oidc.NewDynamicJWKSProvider(
context.TODO(),
cfg.GetJwksFetcher().GetJwksUri(),
time.Duration(cfg.GetJwksFetcher().GetPeriodicFetchIntervalSec())*time.Second,
)
} else {
if jwks, err = oidc.NewStaticJWKSProvider(cfg.GetJwks()); err != nil {
return nil, err
}
}

return &oidcHandler{
log: internal.Logger(internal.Authz).With("type", "oidc"),
config: cfg,
Expand Down
114 changes: 0 additions & 114 deletions internal/authz/oidc/jwks.go

This file was deleted.

129 changes: 129 additions & 0 deletions internal/oidc/jwks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Copyright 2024 Tetrate
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package oidc

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/http"
"time"

"github.com/lestrrat-go/jwx/jwk"
"github.com/tetratelabs/run"
"github.com/tetratelabs/telemetry"

oidcv1 "github.com/tetrateio/authservice-go/config/gen/go/v1/oidc"
"github.com/tetrateio/authservice-go/internal"
)

var (
// ErrJWKSParse is returned when the JWKS document cannot be parsed.
ErrJWKSParse = errors.New("error parsing JWKS document")
// ErrJWKSFetch is returned when the JWKS document cannot be fetched.
ErrJWKSFetch = errors.New("error fetching JWKS document")

_ run.Service = (*DefaultJWKSProvider)(nil)
)

// JWKSProvider provides a JWKS set for a given OIDC configuration.
type JWKSProvider interface {
// Get the JWKS for the given OIDC configuration
Get(context.Context, *oidcv1.OIDCConfig) (jwk.Set, error)
}

// DefaultJWKSProvider provides a JWKS set
type DefaultJWKSProvider struct {
log telemetry.Logger
cache *jwk.AutoRefresh
shutdown context.CancelFunc
}

// NewJWKSProvider returns a new JWKSProvider.
func NewJWKSProvider() *DefaultJWKSProvider {
return &DefaultJWKSProvider{
log: internal.Logger(internal.JWKS),
}
}

// Name of the JWKSProvider run.Unit
func (j *DefaultJWKSProvider) Name() string { return "JWKS" }

// Serve implements run.Service
func (j *DefaultJWKSProvider) Serve() error {
ctx, cancel := context.WithCancel(context.Background())
j.shutdown = cancel

ch := make(chan jwk.AutoRefreshError)
j.cache = jwk.NewAutoRefresh(ctx)
j.cache.ErrorSink(ch)

for {
select {
case err := <-ch:
j.log.Debug("jwks auto refresh error", "error", err)
case <-ctx.Done():
return nil
}
}
}

// GracefulStop implements run.Service
func (j *DefaultJWKSProvider) GracefulStop() {
if j.shutdown != nil {
j.shutdown()
}
}

// Get the JWKS for the given OIDC configuration
func (j *DefaultJWKSProvider) Get(ctx context.Context, config *oidcv1.OIDCConfig) (jwk.Set, error) {
if config.GetJwksFetcher() != nil {
return j.fetchDynamic(ctx, config.GetJwksFetcher())
}
return j.fetchStatic(config.GetJwks())
}

// fetchDynamic fetches the JWKS from the given URI. If the JWKS URI is already know, the JWKS will be returned from
// the cache. Otherwise, the JWKS will be fetched from the URI and the cache will be configured to periodically
// refresh the JWKS.
func (j *DefaultJWKSProvider) fetchDynamic(ctx context.Context, config *oidcv1.OIDCConfig_JwksFetcherConfig) (jwk.Set, error) {
if !j.cache.IsRegistered(config.JwksUri) {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: config.SkipVerifyPeerCert}
client := &http.Client{Transport: transport}
refreshInterval := time.Duration(config.PeriodicFetchIntervalSec) * time.Second

j.cache.Configure(config.JwksUri,
jwk.WithHTTPClient(client),
jwk.WithRefreshInterval(refreshInterval),
)
}

jwks, err := j.cache.Fetch(ctx, config.JwksUri)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrJWKSFetch, err)
}
return jwks, nil
}

// fetchStatic parses the given raw JWKS document.
func (*DefaultJWKSProvider) fetchStatic(raw string) (jwk.Set, error) {
jwks, err := jwk.Parse([]byte(raw))
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrJWKSParse, err)
}
return jwks, nil
}
Loading

0 comments on commit 8c7ac7c

Please sign in to comment.