Skip to content

Commit

Permalink
Resolve path with tag resolver (#11)
Browse files Browse the repository at this point in the history
* resolve path using go tags instead of viper-decode-hook
  • Loading branch information
Argelbargel authored Sep 18, 2023
1 parent 5ddc359 commit 284513f
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 43 deletions.
35 changes: 6 additions & 29 deletions internal/app/vault_raft_snapshot_agent/config/rattlesnake.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package config
import (
"fmt"
"path/filepath"
"reflect"
"strings"

"github.com/creasty/defaults"
Expand All @@ -14,8 +13,6 @@ import (
"github.com/spf13/viper"
)

type Path string

// a rattlesnake is a viper adapted to our needs ;-)
type rattlesnake struct {
v *viper.Viper
Expand Down Expand Up @@ -72,20 +69,19 @@ func (r rattlesnake) Unmarshal(config interface{}) error {
return fmt.Errorf("could not bind env vars for configuration: %s", err)
}

decodeHook := mapstructure.ComposeDecodeHookFunc(
mapstructure.StringToTimeDurationHookFunc(),
mapstructure.StringToSliceHookFunc(","),
newPathResolverHook(filepath.Dir(r.ConfigFileUsed())),
)

if err := r.v.Unmarshal(config, viper.DecodeHook(decodeHook)); err != nil {
if err := r.v.Unmarshal(config); err != nil {
return err
}

if err := defaults.Set(config); err != nil {
return fmt.Errorf("could not set configuration's default-values: %s", err)
}

pathResolver := newPathResolver(filepath.Dir(r.ConfigFileUsed()))
if err := pathResolver.Resolve(config); err != nil {
return fmt.Errorf("could not resolve relative paths in configuration: %s", err)
}

validate := validator.New()
if err := validate.Struct(config); err != nil {
return err
Expand All @@ -106,25 +102,6 @@ func (r rattlesnake) IsConfigurationNotFoundError(err error) bool {
return notfound
}

func newPathResolverHook(workdir string) mapstructure.DecodeHookFuncType {
return func(dataType reflect.Type, targetType reflect.Type, data interface{}) (interface{}, error) {
if dataType.Kind() != reflect.String {
return data, nil
}

if targetType != reflect.TypeOf(Path("")) {
return data, nil
}

path := data.(string)
if !filepath.IsAbs(path) {
path = filepath.Join(workdir, path)
}

return Path(filepath.Clean(path)), nil
}
}

// implements automatic unmarshalling from environment variables
// see https://github.com/spf13/viper/pull/1429
// can be removed if that pr is merged
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

type rattlesnakeConfigStub struct {
Path Path `default:"/test/file"`
Path string `default:"/test/file" resolve-path:""`
Url string `validate:"omitempty,http_url"`
}

Expand All @@ -30,7 +30,7 @@ func TestUnmarshalResolvesRelativePaths(t *testing.T) {
err = rattlesnake.Unmarshal(&config)

assert.NoError(t, err, "Unmarshal failed unexpectedly")
assert.Equal(t, Path(filepath.Clean(fmt.Sprintf("%s/file.ext", wd))), config.Path)
assert.Equal(t, filepath.Clean(fmt.Sprintf("%s/file.ext", wd)), config.Path)
}

func TestUnmarshalSetsDefaultValues(t *testing.T) {
Expand All @@ -40,7 +40,7 @@ func TestUnmarshalSetsDefaultValues(t *testing.T) {
err := rattlesnake.Unmarshal(&config)

assert.NoError(t, err, "Unmarshal failed unexpectedly")
assert.Equal(t, Path("/test/file"), config.Path)
assert.Equal(t, "/test/file", config.Path)
}

func TestUnmarshalValidatesValues(t *testing.T) {
Expand Down
86 changes: 86 additions & 0 deletions internal/app/vault_raft_snapshot_agent/config/resolve-path-tag.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package config

import (
"errors"
"path/filepath"
"reflect"
"strings"
)

const (
tagFieldName = "resolve-path"
)

var (
errorInvalidType error = errors.New("subject must be a struct passed by pointer")
)

type pathResolver struct {
baseDir string
}

func newPathResolver(baseDir string) pathResolver {
return pathResolver{baseDir}
}

func (r pathResolver) Resolve(subject interface{}) error {
if reflect.TypeOf(subject).Kind() != reflect.Ptr {
return errorInvalidType
}

s := reflect.ValueOf(subject).Elem()

return r.resolve(s)
}

func (r pathResolver) resolve(value reflect.Value) error {
t := value.Type()

if t.Kind() != reflect.Struct {
return errorInvalidType
}

for i := 0; i < t.NumField(); i++ {
f := value.Field(i)

if !f.CanSet() {
continue
}

if f.Kind() == reflect.Ptr {
f = f.Elem()
}

if f.Kind() == reflect.Struct {
if err := r.resolve(f); err != nil {
return err
}
}

if f.Kind() != reflect.String || f.String() == "" {
continue
}

if baseDir, present := t.Field(i).Tag.Lookup(tagFieldName); present {
if err := r.resolvePath(f, baseDir); err != nil {
return err
}
}
}

return nil
}

func (r pathResolver) resolvePath(field reflect.Value, baseDir string) error {
path := field.String()
if baseDir == "" {
baseDir = r.baseDir
}

if !filepath.IsAbs(path) && !strings.HasPrefix(path, "/") {
path = filepath.Join(baseDir, path)
field.Set(reflect.ValueOf(path).Convert(field.Type()))
}

return nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package config

import (
"fmt"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
)

func TestResolvesRelativePaths(t *testing.T) {
var test struct {
Path string `resolve-path:""`
FixedPath string `resolve-path:"/tmp/"`
Other string
AbsolutePath string `resolve-path:""`
}
test.Path = "./relative"
test.FixedPath = "./fixed"
test.Other = "./other"
test.AbsolutePath = "/test/abs"

dir := t.TempDir()
resolver := newPathResolver(dir)
err := resolver.Resolve(&test)

assert.NoError(t, err, "resolver.resolve failed unexepectedly")

assert.Equal(t, filepath.Clean(fmt.Sprintf("%s/relative", dir)), test.Path)
assert.Equal(t, filepath.Clean("/tmp/fixed"), test.FixedPath)
assert.Equal(t, "/test/abs", test.AbsolutePath)
assert.Equal(t, "./other", test.Other)
}

func TestResolvesRecursively(t *testing.T) {
type inner struct {
Path string `resolve-path:""`
}

innerPtr := inner{"./innerPtr"}

var outer struct {
Inner inner
InnerPtr *inner
}
outer.Inner.Path = "./inner"
outer.InnerPtr = &innerPtr

dir := t.TempDir()
resolver := newPathResolver(dir)
err := resolver.Resolve(&outer)

assert.NoError(t, err, "resolver.resolve failed unexepectedly")

assert.Equal(t, filepath.Clean(fmt.Sprintf("%s/inner", dir)), outer.Inner.Path)
assert.Equal(t, filepath.Clean(fmt.Sprintf("%s/innerPtr", dir)), innerPtr.Path)

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func defaultJwtPath(def string) string {
return "/var/run/secrets/kubernetes.io/serviceaccount/token"
}

func relativeTo(configFile string, file string) config.Path {
func relativeTo(configFile string, file string) string {
if !filepath.IsAbs(file) && !strings.HasPrefix(file, "/") {
file = filepath.Join(filepath.Dir(configFile), file)
}
Expand All @@ -40,7 +40,7 @@ func relativeTo(configFile string, file string) config.Path {
file = filepath.Clean(file)
}

return config.Path(file)
return file
}

func TestReadCompleteConfig(t *testing.T) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
package auth

import (
"github.com/Argelbargel/vault-raft-snapshot-agent/internal/app/vault_raft_snapshot_agent/config"
"github.com/hashicorp/vault/api/auth/kubernetes"
)

type KubernetesAuthConfig struct {
Path string `default:"kubernetes"`
Role string `validate:"required_if=Empty false"`
JWTPath config.Path `default:"/var/run/secrets/kubernetes.io/serviceaccount/token" validate:"omitempty,file,required_if=Empty false"`
Path string `default:"kubernetes"`
Role string `validate:"required_if=Empty false"`
JWTPath string `default:"/var/run/secrets/kubernetes.io/serviceaccount/token" resolve-path:"" validate:"omitempty,file,required_if=Empty false"`
Empty bool
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"
"testing"

"github.com/Argelbargel/vault-raft-snapshot-agent/internal/app/vault_raft_snapshot_agent/config"
"github.com/Argelbargel/vault-raft-snapshot-agent/internal/app/vault_raft_snapshot_agent/test"

"github.com/hashicorp/vault/api/auth/kubernetes"
Expand All @@ -15,9 +14,9 @@ import (
func TestCreateKubernetesAuth(t *testing.T) {
jwtPath := fmt.Sprintf("%s/jwt", t.TempDir())
config := KubernetesAuthConfig{
Role: "test-role",
JWTPath: config.Path(jwtPath),
Path: "test-path",
Role: "test-role",
JWTPath: jwtPath,
Path: "test-path",
}

err := test.WriteFile(t, jwtPath, "test")
Expand All @@ -26,7 +25,7 @@ func TestCreateKubernetesAuth(t *testing.T) {
expectedAuthMethod, err := kubernetes.NewKubernetesAuth(
config.Role,
kubernetes.WithMountPath(config.Path),
kubernetes.WithServiceAccountTokenPath(string(config.JWTPath)),
kubernetes.WithServiceAccountTokenPath(config.JWTPath),
)
assert.NoError(t, err, "NewKubernetesAuth failed unexpectedly")

Expand Down

0 comments on commit 284513f

Please sign in to comment.