Skip to content
This repository has been archived by the owner on Apr 22, 2024. It is now read-only.

Commit

Permalink
Complete Redis session store implementation (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
nacx authored Feb 15, 2024
1 parent 1e9ecdf commit 2bcc7ba
Show file tree
Hide file tree
Showing 6 changed files with 310 additions and 31 deletions.
71 changes: 70 additions & 1 deletion e2e/redis/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestRedisTokenResponse(t *testing.T) {
require.NoError(t, err)
require.Nil(t, tr)

// Create a session and verify it's added and accessed time
// Create a session and verify it's added and accessed time is set
tr = &oidc.TokenResponse{
IDToken: newToken(),
AccessToken: newToken(),
Expand All @@ -66,6 +66,75 @@ func TestRedisTokenResponse(t *testing.T) {
require.Greater(t, ttl, time.Duration(0))
}

func TestRedisAuthorizationState(t *testing.T) {
opts, err := redis.ParseURL(redisURL)
require.NoError(t, err)
client := redis.NewClient(opts)

store, err := oidc.NewRedisStore(&oidc.Clock{}, client, 0, 1*time.Minute)
require.NoError(t, err)

ctx := context.Background()

as, err := store.GetAuthorizationState(ctx, "s1")
require.NoError(t, err)
require.Nil(t, as)

// Create a session and verify it's added and accessed time is set
as = &oidc.AuthorizationState{
State: "state",
Nonce: "nonce",
RequestedURL: "https://example.com",
}
require.NoError(t, store.SetAuthorizationState(ctx, "s1", as))

// Verify that the right state is returned
got, err := store.GetAuthorizationState(ctx, "s1")
require.NoError(t, err)
require.Equal(t, as, got)

// Verify that the token TTL has been set
ttl := client.TTL(ctx, "s1").Val()
require.Greater(t, ttl, time.Duration(0))
}

func TestSessionExpiration(t *testing.T) {
opts, err := redis.ParseURL(redisURL)
require.NoError(t, err)
client := redis.NewClient(opts)

store, err := oidc.NewRedisStore(&oidc.Clock{}, client, 2*time.Second, 0)
require.NoError(t, err)

ctx := context.Background()

t.Run("expire-token", func(t *testing.T) {
tr := &oidc.TokenResponse{
IDToken: newToken(),
AccessToken: newToken(),
AccessTokenExpiresAt: time.Now().Add(30 * time.Minute),
}
require.NoError(t, store.SetTokenResponse(ctx, "s1", tr))
require.Eventually(t, func() bool {
got, err := store.GetTokenResponse(ctx, "s1")
return got == nil && err == nil
}, 3*time.Second, 1*time.Second)
})

t.Run("expire-state", func(t *testing.T) {
as := &oidc.AuthorizationState{
State: "state",
Nonce: "nonce",
RequestedURL: "https://example.com",
}
require.NoError(t, store.SetAuthorizationState(ctx, "s1", as))
require.Eventually(t, func() bool {
got, err := store.GetAuthorizationState(ctx, "s1")
return got == nil && err == nil
}, 3*time.Second, 1*time.Second)
})
}

func newToken() string {
token, _ := jwt.NewBuilder().
Issuer("authservice").
Expand Down
10 changes: 9 additions & 1 deletion internal/oidc/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ func (j *DefaultJWKSProvider) Get(ctx context.Context, config *oidcv1.OIDCConfig
// the cache. Otherwise, the JWKS will be fetched from the URI and the cache will be configured to periodically
// refresh the JWKS.
func (j *DefaultJWKSProvider) fetchDynamic(ctx context.Context, config *oidcv1.OIDCConfig_JwksFetcherConfig) (jwk.Set, error) {
log := j.log.Context(ctx)

if !j.cache.IsRegistered(config.JwksUri) {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: config.SkipVerifyPeerCert}
Expand All @@ -112,12 +114,16 @@ func (j *DefaultJWKSProvider) fetchDynamic(ctx context.Context, config *oidcv1.O
refreshInterval = DefaultFetchInterval
}

log.Info("configuring JWKS auto refresh", "jwks", config.JwksUri, "interval", refreshInterval, "skip_verify", config.SkipVerifyPeerCert)

j.cache.Configure(config.JwksUri,
jwk.WithHTTPClient(client),
jwk.WithRefreshInterval(refreshInterval),
)
}

log.Debug("fetching JWKS", "jwks", config.JwksUri)

jwks, err := j.cache.Fetch(ctx, config.JwksUri)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrJWKSFetch, err)
Expand All @@ -126,7 +132,9 @@ func (j *DefaultJWKSProvider) fetchDynamic(ctx context.Context, config *oidcv1.O
}

// fetchStatic parses the given raw JWKS document.
func (*DefaultJWKSProvider) fetchStatic(raw string) (jwk.Set, error) {
func (j *DefaultJWKSProvider) fetchStatic(raw string) (jwk.Set, error) {
j.log.Debug("parsing static JWKS", "jwks", raw)

jwks, err := jwk.Parse([]byte(raw))
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrJWKSParse, err)
Expand Down
46 changes: 36 additions & 10 deletions internal/oidc/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,20 @@ func NewMemoryStore(clock *Clock, absoluteSessionTimeout, idleSessionTimeout tim
}
}

func (m *memoryStore) SetTokenResponse(_ context.Context, sessionID string, tokenResponse *TokenResponse) error {
m.set(sessionID, func(s *session) {
func (m *memoryStore) SetTokenResponse(ctx context.Context, sessionID string, tokenResponse *TokenResponse) error {
log := m.log.Context(ctx).With("session-id", sessionID)
log.Debug("setting token response", "token_response", tokenResponse)

m.set(ctx, sessionID, func(s *session) {
s.tokenResponse = tokenResponse
})
return nil
}

func (m *memoryStore) GetTokenResponse(_ context.Context, sessionID string) (*TokenResponse, error) {
func (m *memoryStore) GetTokenResponse(ctx context.Context, sessionID string) (*TokenResponse, error) {
log := m.log.Context(ctx).With("session-id", sessionID)
log.Debug("getting token response")

m.mu.Lock()
defer m.mu.Unlock()

Expand All @@ -68,14 +74,20 @@ func (m *memoryStore) GetTokenResponse(_ context.Context, sessionID string) (*To
return s.tokenResponse, nil
}

func (m *memoryStore) SetAuthorizationState(_ context.Context, sessionID string, authorizationState *AuthorizationState) error {
m.set(sessionID, func(s *session) {
func (m *memoryStore) SetAuthorizationState(ctx context.Context, sessionID string, authorizationState *AuthorizationState) error {
log := m.log.Context(ctx).With("session-id", sessionID)
log.Debug("setting authorization state", "state", authorizationState)

m.set(ctx, sessionID, func(s *session) {
s.authorizationState = authorizationState
})
return nil
}

func (m *memoryStore) GetAuthorizationState(_ context.Context, sessionID string) (*AuthorizationState, error) {
func (m *memoryStore) GetAuthorizationState(ctx context.Context, sessionID string) (*AuthorizationState, error) {
log := m.log.Context(ctx).With("session-id", sessionID)
log.Debug("getting authorization state")

m.mu.Lock()
defer m.mu.Unlock()

Expand All @@ -88,7 +100,10 @@ func (m *memoryStore) GetAuthorizationState(_ context.Context, sessionID string)
return s.authorizationState, nil
}

func (m *memoryStore) ClearAuthorizationState(_ context.Context, sessionID string) error {
func (m *memoryStore) ClearAuthorizationState(ctx context.Context, sessionID string) error {
log := m.log.Context(ctx).With("session-id", sessionID)
log.Debug("clearing authorization state")

m.mu.Lock()
defer m.mu.Unlock()

Expand All @@ -100,7 +115,10 @@ func (m *memoryStore) ClearAuthorizationState(_ context.Context, sessionID strin
return nil
}

func (m *memoryStore) RemoveSession(_ context.Context, sessionID string) error {
func (m *memoryStore) RemoveSession(ctx context.Context, sessionID string) error {
log := m.log.Context(ctx).With("session-id", sessionID)
log.Debug("removing session")

m.mu.Lock()
defer m.mu.Unlock()

Expand All @@ -109,7 +127,10 @@ func (m *memoryStore) RemoveSession(_ context.Context, sessionID string) error {
return nil
}

func (m *memoryStore) RemoveAllExpired(context.Context) error {
func (m *memoryStore) RemoveAllExpired(ctx context.Context) error {
log := m.log.Context(ctx)
log.Debug("removing expired sessions")

var (
earliestTimeAddedToKeep = m.clock.Now().Add(-m.absoluteSessionTimeout)
earliestTimeIdleToKeep = m.clock.Now().Add(-m.idleSessionTimeout)
Expand All @@ -125,6 +146,7 @@ func (m *memoryStore) RemoveAllExpired(context.Context) error {
expiredBasedOnIdleTime := shouldCheckIdleTimeout && s.accessed.Before(earliestTimeIdleToKeep)

if expiredBasedOnTimeAdded || expiredBasedOnIdleTime {
log.Debug("removing expired session", "session-id", sessionID)
delete(m.sessions, sessionID)
}
}
Expand All @@ -133,7 +155,9 @@ func (m *memoryStore) RemoveAllExpired(context.Context) error {
}

// set the given session with the given setter function and record the access time.
func (m *memoryStore) set(sessionID string, setter func(s *session)) {
func (m *memoryStore) set(ctx context.Context, sessionID string, setter func(s *session)) {
log := m.log.Context(ctx).With("session-id", sessionID)

m.mu.Lock()
defer m.mu.Unlock()

Expand All @@ -146,6 +170,8 @@ func (m *memoryStore) set(sessionID string, setter func(s *session)) {
setter(s)
m.sessions[sessionID] = s
}

log.Debug("updating last access", "accessed", s.accessed)
}

// session holds the data of a session stored in the in-memory cache
Expand Down
Loading

0 comments on commit 2bcc7ba

Please sign in to comment.