Skip to content

Commit

Permalink
Improve testability of main/run function
Browse files Browse the repository at this point in the history
  • Loading branch information
dvob committed Jul 27, 2024
1 parent 23e9d0f commit 19bae74
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 19 deletions.
32 changes: 18 additions & 14 deletions cmd/pcert/create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,32 @@ import (
"github.com/dvob/pcert"
)

func runCmd(args []string, env map[string]string) (io.WriteCloser, *bytes.Buffer, *bytes.Buffer, error) {
func runCmd(args []string, stdin io.Reader, env map[string]string) (*bytes.Buffer, *bytes.Buffer, error) {

stdout := &bytes.Buffer{}
stderr := &bytes.Buffer{}
stdinReader, stdinWriter := io.Pipe()
cmd := newRootCmd()
cmd.SetArgs(args)
cmd.SetIn(stdinReader)
cmd.SetOut(stdout)
cmd.SetErr(stderr)
cmd = WithEnv(cmd, args, func(name string) (string, bool) {

if stdin == nil {
stdin = bytes.NewReader(nil)
}

code := run(args, stdin, stdout, stderr, func(key string) (string, bool) {
if env == nil {
return "", false
}
val, ok := env[name]
val, ok := env[key]
return val, ok
})

return stdinWriter, stdout, stderr, cmd.Execute()
if code != 0 {
return stdout, stderr, fmt.Errorf("execution failed. stderr='%s'", stderr.String())
}

return stdout, stderr, nil
}

func runAndLoad(args []string, env map[string]string) (*x509.Certificate, error) {
_, stdout, stderr, err := runCmd(args, env)
stdout, stderr, err := runCmd(args, nil, env)
if err != nil {
return nil, err
}
Expand All @@ -45,7 +49,7 @@ func runAndLoad(args []string, env map[string]string) (*x509.Certificate, error)

cert, err := pcert.Parse(stdout.Bytes())
if err != nil {
return nil, fmt.Errorf("could not read certificate from standard output: %s", err)
return nil, fmt.Errorf("could not read certificate from standard output: %s. stdout='%s'", err, stdout.String())
}

return cert, err
Expand Down Expand Up @@ -215,10 +219,10 @@ func Test_create_not_before_with_expiry(t *testing.T) {
func Test_create_output_parameter(t *testing.T) {
defer os.Remove("tls.crt")
defer os.Remove("tls.key")
_, _, _, err := runCmd([]string{
_, _, err := runCmd([]string{
"create",
"tls.crt",
}, nil)
}, nil, nil)
if err != nil {
t.Fatal(err)
return
Expand Down
23 changes: 20 additions & 3 deletions cmd/pcert/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"fmt"
"io"
"os"

"github.com/spf13/cobra"
Expand All @@ -20,11 +21,27 @@ var (
)

func main() {
err := WithEnv(newRootCmd(), os.Args[1:], os.LookupEnv).Execute()
code := run(os.Args[1:], os.Stdin, os.Stdout, os.Stderr, os.LookupEnv)
os.Exit(code)
}

func run(args []string, stdin io.Reader, stdout io.Writer, stderr io.Writer, getEnv func(string) (string, bool)) int {
rootCmd := newRootCmd()

rootCmd.SetOut(stdout)
rootCmd.SetErr(stderr)
rootCmd.SetIn(stdin)

rootCmd = WithEnv(rootCmd, args, getEnv)
rootCmd.SetArgs(args)

err := rootCmd.Execute()
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
fmt.Fprintln(stderr, err)
return 1
}

return 0
}

func newRootCmd() *cobra.Command {
Expand Down
4 changes: 2 additions & 2 deletions cmd/pcert/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ import (

func Test_request(t *testing.T) {
name := "foo"
_, stdout, stderr, err := runCmd([]string{
stdout, stderr, err := runCmd([]string{
"request",
"--subject",
"/CN=" + name,
}, nil)
}, nil, nil)
if err != nil {
t.Fatal(err)
return
Expand Down

0 comments on commit 19bae74

Please sign in to comment.