Skip to content

Commit

Permalink
Fix key pruning & rotation
Browse files Browse the repository at this point in the history
Correctly sets `MinDecryptionVersion` on policy to ensure expired keys are deleted

Also
* Removese `clock` from backend, cannot inject into Policy functions and thus makes it _hard_ to use for testing things like rotation and pruning.
* Logs the mount (not key name) during rotation; required adding “mount” to chain of calling functions
  • Loading branch information
kdubb committed Nov 16, 2023
1 parent ec30272 commit ceaf10e
Show file tree
Hide file tree
Showing 11 changed files with 300 additions and 86 deletions.
31 changes: 17 additions & 14 deletions plugin/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ type backend struct {
cachedConfig *Config
cachedConfigLock *sync.RWMutex
idGen uniqueIdGenerator
clock clock
}

// Factory returns a new backend as logical.Backend.
Expand All @@ -72,7 +71,6 @@ func createBackend(conf *logical.BackendConfig) (*backend, error) {

b.id = conf.BackendUUID
b.cachedConfigLock = new(sync.RWMutex)
b.clock = realClock{}
b.idGen = friendlyIdGenerator{}

b.Backend = &framework.Backend{
Expand Down Expand Up @@ -118,7 +116,7 @@ func (b *backend) periodic(ctx context.Context, req *logical.Request) error {
return err
}

policy, err := b.getPolicy(ctx, req.Storage, config)
policy, err := b.getPolicy(ctx, req.Storage, config, req.MountPoint)
if err != nil {
return err
}
Expand All @@ -145,7 +143,7 @@ func (b *backend) clean(_ context.Context) {
// Nothing to do
}

func (b *backend) getPolicy(ctx context.Context, stg logical.Storage, config *Config) (*keysutil.Policy, error) {
func (b *backend) getPolicy(ctx context.Context, stg logical.Storage, config *Config, mount string) (*keysutil.Policy, error) {

polReq := keysutil.PolicyRequest{
Upsert: true,
Expand Down Expand Up @@ -190,14 +188,14 @@ func (b *backend) getPolicy(ctx context.Context, stg logical.Storage, config *Co
return nil, err
}

if err := b.rotateIfNecessary(ctx, stg, policy, config); err != nil {
if err := b.rotateIfNecessary(ctx, stg, policy, config, mount); err != nil {
return nil, err
}

return policy, nil
}

func (b *backend) rotateIfNecessary(ctx context.Context, stg logical.Storage, policy *keysutil.Policy, config *Config) error {
func (b *backend) rotateIfNecessary(ctx context.Context, stg logical.Storage, policy *keysutil.Policy, config *Config, mount string) error {
policy.Lock(true)
defer policy.Unlock()

Expand All @@ -206,7 +204,7 @@ func (b *backend) rotateIfNecessary(ctx context.Context, stg logical.Storage, po
return nil
}

if latestKey.CreationTime.Add(config.KeyRotationPeriod).After(b.clock.now()) {
if latestKey.CreationTime.Add(config.KeyRotationPeriod).After(time.Now()) {
return nil
}

Expand All @@ -217,7 +215,7 @@ func (b *backend) rotateIfNecessary(ctx context.Context, stg logical.Storage, po

b.lockManager.InvalidatePolicy(policy.Name)

b.Logger().Info(fmt.Sprintf("Key Rotated: name=%s", policy.Name))
b.Logger().Info(fmt.Sprintf("Key Rotated: mount=%s", mount))

return nil
}
Expand Down Expand Up @@ -254,7 +252,7 @@ func (b *backend) pruneKeyVersions(ctx context.Context, stg logical.Storage, pol
)
}

if keyExpiresAt.After(b.clock.now()) {
if keyExpiresAt.After(time.Now()) {
break
}
}
Expand All @@ -273,21 +271,26 @@ func (b *backend) pruneKeyVersions(ctx context.Context, stg logical.Storage, pol
return nil
}

previousMinAvailableVersion := policy.MinDecryptionVersion

// Ensure that cache doesn't get corrupted in error cases
previousMinAvailableVersion := policy.MinAvailableVersion
previousMinDecryptionVersion := policy.MinDecryptionVersion

policy.MinAvailableVersion = unexpiredVersion
policy.MinDecryptionVersion = unexpiredVersion

if err := policy.Persist(ctx, stg); err != nil {
policy.MinDecryptionVersion = previousMinAvailableVersion
policy.MinAvailableVersion = previousMinAvailableVersion
policy.MinDecryptionVersion = previousMinDecryptionVersion
return err
}

logger.Info(
fmt.Sprintf(
"Key Trimmed: mount=%s, latest=%d, min-available=%d",
"Key Trimmed: mount=%s, latest=%d, min-available=%d, min-decryption=%d",
mount,
policy.MinAvailableVersion,
policy.LatestVersion,
policy.MinAvailableVersion,
policy.MinDecryptionVersion,
),
)

Expand Down
204 changes: 203 additions & 1 deletion plugin/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package jwtsecrets

import (
"context"
"github.com/go-test/deep"
"github.com/google/uuid"
"github.com/hashicorp/vault/sdk/logical"
"testing"
Expand All @@ -38,10 +39,211 @@ func getTestBackend(t *testing.T) (*backend, *logical.Storage) {
t.Fatalf("unable to create backend: %v", err)
}

b.clock = &fakeClock{time.Unix(0, 0)}
b.idGen = &fakeIDGenerator{0}

_ = b.clearConfig(context.Background(), config.StorageView)

return b, &config.StorageView
}

func TestRotate(t *testing.T) {
b, storage := getTestBackend(t)

_, err := writeConfig(b, storage, map[string]interface{}{
keyRotationDuration: "2s",
keyTokenTTL: "1s",
})
if err != nil {
t.Fatalf("%s\n", err)
}

err = writeRole(b, storage, "tester", "tester.example.com", map[string]interface{}{}, map[string]interface{}{})
if err != nil {
t.Fatalf("%s\n", err)
}

config, err := b.getConfig(context.Background(), *storage)
if err != nil {
t.Fatalf("%s\n", err)
}

policy, err := b.getPolicy(context.Background(), *storage, config, "test")
if err != nil {
t.Fatalf("%s\n", err)
}

// Pre-rotate checks
if diff := deep.Equal(policy.LatestVersion, 1); diff != nil {
t.Error("policy latest version", diff)
}
if diff := deep.Equal(policy.MinAvailableVersion, 0); diff != nil {
t.Error("policy min-available version", diff)
}
if diff := deep.Equal(policy.MinDecryptionVersion, 1); diff != nil {
t.Error("policy min-decryption version", diff)
}
if diff := deep.Equal(policy.ArchiveVersion, 1); diff != nil {
t.Error("policy archive version", diff)
}
if diff := deep.Equal(policy.ArchiveMinVersion, 0); diff != nil {
t.Error("policy archive-min version", diff)
}

time.Sleep(config.KeyRotationPeriod + 1)

// Post-rotate #1 checks
policy, err = b.getPolicy(context.Background(), *storage, config, "test")
if err != nil {
t.Fatalf("%s\n", err)
}

if diff := deep.Equal(policy.LatestVersion, 2); diff != nil {
t.Error("policy latest version", diff)
}
if diff := deep.Equal(policy.MinAvailableVersion, 0); diff != nil {
t.Error("policy min-available version", diff)
}
if diff := deep.Equal(policy.MinDecryptionVersion, 1); diff != nil {
t.Error("policy min-decryption version", diff)
}
if diff := deep.Equal(policy.ArchiveVersion, 2); diff != nil {
t.Error("policy archive version", diff)
}
if diff := deep.Equal(policy.ArchiveMinVersion, 0); diff != nil {
t.Error("policy archive-min version", diff)
}

policy, err = b.getPolicy(context.Background(), *storage, config, "test")
if err != nil {
t.Fatalf("%s\n", err)
}

// Should not have rotated yet
if diff := deep.Equal(policy.LatestVersion, 2); diff != nil {
t.Error("policy latest version", diff)
}
if diff := deep.Equal(policy.MinAvailableVersion, 0); diff != nil {
t.Error("policy min-available version", diff)
}
if diff := deep.Equal(policy.MinDecryptionVersion, 1); diff != nil {
t.Error("policy min-decryption version", diff)
}
if diff := deep.Equal(policy.ArchiveVersion, 2); diff != nil {
t.Error("policy archive version", diff)
}
if diff := deep.Equal(policy.ArchiveMinVersion, 0); diff != nil {
t.Error("policy archive-min version", diff)
}

time.Sleep(config.KeyRotationPeriod + 1)

policy, err = b.getPolicy(context.Background(), *storage, config, "test")
if err != nil {
t.Fatalf("%s\n", err)
}

// Post-rotate #2 checks
if diff := deep.Equal(policy.LatestVersion, 3); diff != nil {
t.Error("policy latest version", diff)
}
if diff := deep.Equal(policy.MinAvailableVersion, 0); diff != nil {
t.Error("policy min-available version", diff)
}
if diff := deep.Equal(policy.MinDecryptionVersion, 1); diff != nil {
t.Error("policy min-decryption version", diff)
}
if diff := deep.Equal(policy.ArchiveVersion, 3); diff != nil {
t.Error("policy archive version", diff)
}
if diff := deep.Equal(policy.ArchiveMinVersion, 0); diff != nil {
t.Error("policy archive-min version", diff)
}
}

func TestPrune(t *testing.T) {
b, storage := getTestBackend(t)

_, err := writeConfig(b, storage, map[string]interface{}{
keyRotationDuration: "2s",
keyTokenTTL: "1s",
})
if err != nil {
t.Fatalf("%s\n", err)
}

err = writeRole(b, storage, "tester", "tester.example.com", map[string]interface{}{}, map[string]interface{}{})
if err != nil {
t.Fatalf("%s\n", err)
}

config, err := b.getConfig(context.Background(), *storage)
if err != nil {
t.Fatalf("%s\n", err)
}

policy, err := b.getPolicy(context.Background(), *storage, config, "test")
if err != nil {
t.Fatalf("%s\n", err)
}
if diff := deep.Equal(policy.LatestVersion, 1); diff != nil {
t.Error("policy latest version", diff)
}

time.Sleep(config.KeyRotationPeriod + 1)

policy, err = b.getPolicy(context.Background(), *storage, config, "test")
if err != nil {
t.Fatalf("%s\n", err)
}
if diff := deep.Equal(policy.LatestVersion, 2); diff != nil {
t.Error("policy latest version", diff)
}

time.Sleep(config.KeyRotationPeriod + 1)

policy, err = b.getPolicy(context.Background(), *storage, config, "test")
if err != nil {
t.Fatalf("%s\n", err)
}
if diff := deep.Equal(policy.LatestVersion, 3); diff != nil {
t.Error("policy latest version", diff)
}

time.Sleep(config.KeyRotationPeriod + config.TokenTTL + 1)

err = b.pruneKeyVersions(context.Background(), *storage, policy, config, "test")
if err != nil {
t.Fatalf("%s\n", err)
}

// Post-prune checks
if diff := deep.Equal(policy.LatestVersion, 3); diff != nil {
t.Error("policy latest version", diff)
}
if diff := deep.Equal(policy.MinAvailableVersion, 3); diff != nil {
t.Error("policy min-available version", diff)
}
if diff := deep.Equal(policy.MinDecryptionVersion, 3); diff != nil {
t.Error("policy min-decryption version", diff)
}
if diff := deep.Equal(policy.ArchiveVersion, 3); diff != nil {
t.Error("policy archive version", diff)
}
if diff := deep.Equal(policy.ArchiveMinVersion, 3); diff != nil {
t.Error("policy archive-min version", diff)
}

time.Sleep(config.KeyRotationPeriod)

// Check that JWKS set contains the correct key versions.
// Should be 2 keys because pruning should have reduced it to 1 version
// and fetching will rotate again, leaving two keys.
jwks, err := FetchJWKS(b, storage)
if err != nil {
t.Fatalf("%s\n", err)
}

if diff := deep.Equal(len(jwks.Keys), 2); diff != nil {
t.Error("jwks key count", diff)
}
}
4 changes: 2 additions & 2 deletions plugin/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func (c *Config) copy() *Config {
return &cc
}

func (b *backend) saveConfig(ctx context.Context, stg logical.Storage, config *Config) error {
func (b *backend) saveConfig(ctx context.Context, stg logical.Storage, config *Config, mount string) error {
b.cachedConfigLock.Lock()
defer b.cachedConfigLock.Unlock()

Expand All @@ -166,7 +166,7 @@ func (b *backend) saveConfig(ctx context.Context, stg logical.Storage, config *C

b.Logger().Info("Key Format Rotation")

policy, err := b.getPolicy(ctx, stg, config)
policy, err := b.getPolicy(ctx, stg, config, mount)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion plugin/path_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func (b *backend) pathConfigWrite(ctx context.Context, req *logical.Request, d *
return logical.ErrorResponse("'%s' is greater that the max lease ttl", keyTokenTTL), logical.ErrInvalidRequest
}

if err := b.saveConfig(ctx, req.Storage, config); err != nil {
if err := b.saveConfig(ctx, req.Storage, config, req.MountPoint); err != nil {
return nil, err
}

Expand Down
16 changes: 9 additions & 7 deletions plugin/path_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ const (
func writeConfig(b *backend, storage *logical.Storage, config map[string]interface{}) (*logical.Response, error) {

req := &logical.Request{
Operation: logical.UpdateOperation,
Path: "config",
Storage: *storage,
Data: config,
Operation: logical.UpdateOperation,
Path: "config",
Storage: *storage,
Data: config,
MountPoint: "test",
}

resp, err := b.HandleRequest(context.Background(), req)
Expand All @@ -55,9 +56,10 @@ func TestDefaultConfig(t *testing.T) {
b, storage := getTestBackend(t)

req := &logical.Request{
Operation: logical.ReadOperation,
Path: "config",
Storage: *storage,
Operation: logical.ReadOperation,
Path: "config",
Storage: *storage,
MountPoint: "test",
}

resp, err := b.HandleRequest(context.Background(), req)
Expand Down
Loading

0 comments on commit ceaf10e

Please sign in to comment.