Skip to content

Commit

Permalink
[#14] fix: catch duplicate environment variable
Browse files Browse the repository at this point in the history
           settings on a command
  • Loading branch information
majohn-r committed Jun 16, 2024
1 parent 822582c commit 5a89885
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 32 deletions.
80 changes: 53 additions & 27 deletions build.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ import (
var (
fileSystem = afero.NewOsFs()
workingDir = ""
// function vars make it easy for tests to stub out functionality
exit = os.Exit
executor = cmd.Exec
printLine = fmt.Println
aggressive = flag.Bool("aggressive", false, "set to make dependency updates more aggressive")
// function vars make it easy for tests to stub out functionality
executor = cmd.Exec
exit = os.Exit
printLine = fmt.Println
setenv = os.Setenv
unsetenv = os.Unsetenv
)

// Clean deletes the named files, which must be located in, or in a subdirectory
Expand Down Expand Up @@ -219,17 +221,17 @@ func UpdateDependencies(a *goyek.A) bool {
return false
}
getCommand := directedCommand{command: "go get -u ./..."}
if *aggressive {
getCommand.envVars = append(getCommand.envVars, envVar{
name: "GOPROXY",
value: "direct",
unset: false,
})
}
tidyCommand := directedCommand{command: "go mod tidy"}
for _, dir := range dirs {
path := filepath.Join(WorkingDir(), dir)
getCommand.dir = path
if *aggressive {
getCommand.envVars = append(getCommand.envVars, envVar{
name: "GOPROXY",
value: "direct",
unset: false,
})
}
tidyCommand.dir = path
fmt.Printf("%q: updating dependencies\n", path)
if !getCommand.execute(a) {
Expand All @@ -250,22 +252,46 @@ func (dC directedCommand) execute(a *goyek.A) bool {
options[0] = cmd.Dir(dC.dir)
options[1] = cmd.Stderr(outputBuffer)
options[2] = cmd.Stdout(outputBuffer)
savedEnvVars := setupEnvVars(dC.envVars)
defer func() {
for _, v := range savedEnvVars {
if v.unset {
printLine("unsetting", v.name)
os.Unsetenv(v.name)
} else {
printLine("resetting", v.name, "to", v.value)
os.Setenv(v.name, v.value)
}
savedEnvVars, envVarsOK := setupEnvVars(dC.envVars)
state := envVarsOK
if state {
defer restoreEnvVars(savedEnvVars)
state = executor(a, dC.command, options...)
}
return state
}

func restoreEnvVars(saved []envVar) {
for _, v := range saved {
if v.unset {
printLine("restoring", v.name, "(unsetting)")
unsetenv(v.name)
} else {
printLine("restoring", v.name, "(resetting to", v.value+")")
setenv(v.name, v.value)
}
}()
return executor(a, dC.command, options...)
}
}

func checkEnvVars(input []envVar) bool {
if len(input) == 0 {
return true
}
distinctVar := map[string]bool{}
for _, v := range input {
if distinctVar[v.name] {
printLine("code error: detected attempt to set environment variable", v.name, "twice")
return false
}
distinctVar[v.name] = true
}
return true
}

func setupEnvVars(input []envVar) []envVar {
func setupEnvVars(input []envVar) ([]envVar, bool) {
if !checkEnvVars(input) {
return nil, false
}
savedEnvVars := []envVar{}
for _, envVariable := range input {
oldValue, defined := os.LookupEnv(envVariable.name)
Expand All @@ -277,13 +303,13 @@ func setupEnvVars(input []envVar) []envVar {
})
if envVariable.unset {
printLine("unsetting", envVariable.name)
os.Unsetenv(envVariable.name)
unsetenv(envVariable.name)
} else {
printLine("setting", envVariable.name, "to", envVariable.value)
os.Setenv(envVariable.name, envVariable.value)
setenv(envVariable.name, envVariable.value)
}
}
return savedEnvVars
return savedEnvVars, true
}

// VulnerabilityCheck runs the govulncheck tool, which checks for unresolved
Expand Down
107 changes: 102 additions & 5 deletions build_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1038,8 +1038,9 @@ func TestUpdateDependencies(t *testing.T) {
func Test_setupEnvVars(t *testing.T) {
var1 := "VAR1"
var2 := "VAR2"
vars := []string{var1, var2}
originalVars := make([]envVar, 2)
var3 := "VAR3"
vars := []string{var1, var2, var3}
originalVars := make([]envVar, 3)
for k, s := range vars {
val, defined := os.LookupEnv(s)
originalVars[k] = envVar{
Expand All @@ -1061,9 +1062,26 @@ func Test_setupEnvVars(t *testing.T) {
os.Setenv(var1, val)
os.Unsetenv(var2)
tests := map[string]struct {
input []envVar
want []envVar
input []envVar
want []envVar
wantOk bool
}{
"error case": {
input: []envVar{
{
name: var3,
value: "foo",
unset: false,
},
{
name: var3,
value: "bar",
unset: false,
},
},
want: nil,
wantOk: false,
},
"thorough": {
input: []envVar{
{
Expand All @@ -1089,13 +1107,92 @@ func Test_setupEnvVars(t *testing.T) {
unset: true,
},
},
wantOk: true,
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
if got := setupEnvVars(tt.input); !reflect.DeepEqual(got, tt.want) {
got, gotOk := setupEnvVars(tt.input)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("setupEnvVars() = %v, want %v", got, tt.want)
}
if gotOk != tt.wantOk {
t.Errorf("setupEnvVars() = %t, want %t", gotOk, tt.wantOk)
}
})
}
}

func Test_restoreEnvVars(t *testing.T) {
originalSetenv := setenv
originalUnsetenv := unsetenv
defer func() {
setenv = originalSetenv
unsetenv = originalUnsetenv
}()
var sets int
var unsets int
setenv = func(_, _ string) error {
sets++
return nil
}
unsetenv = func(_ string) error {
unsets++
return nil
}
tests := map[string]struct {
saved []envVar
wantSet int
wantUnset int
}{
"mix": {
saved: []envVar{
{
name: "v1",
value: "val1",
unset: false,
},
{
name: "v2",
value: "",
unset: true,
},
{
name: "v3",
value: "val3",
unset: false,
},
{
name: "v4",
value: "",
unset: true,
},
{
name: "v5",
value: "",
unset: true,
},
},
wantSet: 2,
wantUnset: 3,
},
"empty": {
saved: nil,
wantSet: 0,
wantUnset: 0,
},
}
for name, tt := range tests {
t.Run(name, func(t *testing.T) {
sets = 0
unsets = 0
restoreEnvVars(tt.saved)
if sets != tt.wantSet {
t.Errorf("restoreEnvVars set %d, want %d", sets, tt.wantSet)
}
if unsets != tt.wantUnset {
t.Errorf("restoreEnvVars unset %d, want %d", unsets, tt.wantUnset)
}
})
}
}

0 comments on commit 5a89885

Please sign in to comment.