Skip to content

Commit

Permalink
fix: update audit command to string (#90)
Browse files Browse the repository at this point in the history
* fix: update audit command to string

Signed-off-by: chenk <[email protected]>

* fix: update audit command to string

Signed-off-by: chenk <[email protected]>

* fix: update audit command to string

Signed-off-by: chenk <[email protected]>

---------

Signed-off-by: chenk <[email protected]>
  • Loading branch information
chen-keinan authored Jan 3, 2024
1 parent 682ebbc commit e1554f2
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 11 deletions.
2 changes: 1 addition & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ var RootCmd = &cobra.Command{
glog.V(2).Info("Returning a PowerShell (Auditer) \n")
return ps
})
return runChecks(b)
return runChecks(b, ps.OsType)
},
}

Expand Down
22 changes: 20 additions & 2 deletions cmd/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ import (
"fmt"
"path/filepath"

"github.com/aquasecurity/bench-common/check"
commonCheck "github.com/aquasecurity/bench-common/check"
"github.com/aquasecurity/bench-common/util"
"github.com/golang/glog"
)

func runChecks(b commonCheck.Bench) error {
func runChecks(b commonCheck.Bench, serverType string) error {
var version string
var err error

Expand Down Expand Up @@ -51,17 +52,34 @@ func runChecks(b commonCheck.Bench) error {

summary := runControls(controls, checkList)

controls = updateControlCheck(controls, serverType)

return outputResults(controls, summary)
}

func updateControlCheck(controls *check.Controls, osType string) *check.Controls {
// `runControls` can detect some items without correct `cmd`, and the state will be set `SKIP`
// We should remove skipped controls, because there is no way to print them.
for _, group := range controls.Groups {
for i := len(group.Checks) - 1; i >= 0; i-- {
if group.Checks[i].State == commonCheck.SKIP {
group.Checks = append(group.Checks[:i], group.Checks[i+1:]...)
}
group.Checks[i].Audit = getOsTypeAuditCommand(group.Checks[i].Audit, osType)
}
}
return outputResults(controls, summary)
return controls
}

func getOsTypeAuditCommand(audit interface{}, serverType string) string {
if a, ok := audit.(map[interface{}]interface{}); ok {
if cmd, ok := a["cmd"].(map[interface{}]interface{}); ok {
if val, ok := cmd[serverType].(string); ok {
return val
}
}
}
return fmt.Sprintf("%v", audit)
}

// loadConfig finds the correct config dir based on the kubernetes version,
Expand Down
34 changes: 33 additions & 1 deletion cmd/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import (
"strings"
"testing"

"github.com/aquasecurity/bench-common/check"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
)

func TestLoadConfig(t *testing.T) {
Expand Down Expand Up @@ -53,11 +55,41 @@ func TestLoadConfig(t *testing.T) {

func TestRunChecks(t *testing.T) {
b := getMockBench()
err := runChecks(b)
err := runChecks(b, "Server")
var write bytes.Buffer
outputWriter = &write
if err != nil {
t.Errorf("unexpected error: %s\n", err)
}
assert.NoError(t, err)
}

func TestUpdateControl(t *testing.T) {
here, _ := os.Getwd()
// cfgDir is defined in root.go
type TestCase struct {
version string
cfgPath string
want string
}

testCases := []TestCase{
{
version: "2.0.0",
cfgPath: fmt.Sprintf("%s/../cfg", here),
want: "cfg/2.0.0/definitions.yaml",
},
}
for _, tc := range testCases {
cfgDir = tc.cfgPath
config, err := loadConfig(tc.version)
assert.NoError(t, err)
f, err := os.ReadFile(config)
assert.NoError(t, err)
var c check.Controls
err = yaml.Unmarshal(f, &c)
assert.NoError(t, err)
got := updateControlCheck(&c, "DomainController")
assert.Equal(t, got.Groups[0].Checks[0].Audit.(string), "Get-ADDefaultDomainPasswordPolicy -Current LocalComputer | Select -ExpandProperty PasswordHistoryCount")
}
}
10 changes: 5 additions & 5 deletions shell/powershell.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ var memberServerRoles = []string{
type PowerShell struct {
Cmd map[string]string
sh ps.Shell
osType string
OsType string
}

type shellStarter interface {
Expand Down Expand Up @@ -80,9 +80,9 @@ func NewPowerShell() (*PowerShell, error) {
if err != nil {
return nil, fmt.Errorf("Failed to get operating system type: %w", err)
}
p.osType = osType
p.OsType = osType
if osType == "Server" {
p.osType, err = getServerType(p)
p.OsType, err = getServerType(p)
if err != nil {
return nil, fmt.Errorf("failed to get server type: %w", err)
}
Expand Down Expand Up @@ -156,9 +156,9 @@ func (p *PowerShell) executeCommand() (string, error) {
}

func (p *PowerShell) commandForRuntimeOS() (string, error) {
cmd, found := p.Cmd[p.osType]
cmd, found := p.Cmd[p.OsType]
if !found {
return "", errors.Wrap(errWrongOSType, fmt.Sprintf("Unable to find matching command for OS Type: %q", p.osType))
return "", errors.Wrap(errWrongOSType, fmt.Sprintf("Unable to find matching command for OS Type: %q", p.OsType))
}
return cmd, nil
}
Expand Down
4 changes: 2 additions & 2 deletions shell/powershell_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func TestExecute(t *testing.T) {
osTypeCmd: testSpace + testPShellCommand + testSpace, // surrounded by spaces
},
sh: &mockShell{},
osType: osTypeCmd,
OsType: osTypeCmd,
},
expectedResult: testPShellCommand,
fail: false,
Expand All @@ -157,7 +157,7 @@ func TestExecute(t *testing.T) {
osTypeCmd: testSpace + testPShellCommand + testNewLine, // starts with space end with new lines
},
sh: &mockShell{},
osType: osTypeCmd,
OsType: osTypeCmd,
},
expectedResult: testPShellCommand,
fail: false,
Expand Down

0 comments on commit e1554f2

Please sign in to comment.