Skip to content

Commit

Permalink
Implement singleton pattern for NewRootCmd to prevent multiple creati…
Browse files Browse the repository at this point in the history
…ons (#840)

* Only allow creating the NewRootCmd once with singleton pattern
* Fix unit tests by using NewRootCmdForTest

Signed-off-by: Anuj Chaudhari <[email protected]>
  • Loading branch information
anujc25 committed Jan 14, 2025
1 parent 18b6706 commit 2149273
Show file tree
Hide file tree
Showing 16 changed files with 67 additions and 44 deletions.
2 changes: 1 addition & 1 deletion pkg/command/ceip_participation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func TestCompletionCeip(t *testing.T) {
t.Run(spec.test, func(t *testing.T) {
assert := assert.New(t)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)

var out bytes.Buffer
Expand Down
2 changes: 1 addition & 1 deletion pkg/command/cert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ func TestCompletionCert(t *testing.T) {
t.Run(spec.test, func(t *testing.T) {
assert := assert.New(t)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)

var out bytes.Buffer
Expand Down
2 changes: 1 addition & 1 deletion pkg/command/completion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func TestCompletionCompletion(t *testing.T) {
t.Run(spec.test, func(t *testing.T) {
assert := assert.New(t)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)

var out bytes.Buffer
Expand Down
2 changes: 1 addition & 1 deletion pkg/command/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func TestCompletionConfig(t *testing.T) {
t.Run(spec.test, func(t *testing.T) {
assert := assert.New(t)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)

var out bytes.Buffer
Expand Down
4 changes: 2 additions & 2 deletions pkg/command/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1667,7 +1667,7 @@ func TestCompletionContext(t *testing.T) {
t.Run(spec.test, func(t *testing.T) {
assert := assert.New(t)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)

var out bytes.Buffer
Expand Down Expand Up @@ -1942,7 +1942,7 @@ func TestContextCurrentCmd(t *testing.T) {
_ = config.SetContext(spec.activeContexts[i], true)
}

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)

var out bytes.Buffer
Expand Down
10 changes: 5 additions & 5 deletions pkg/command/discovery_source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func Test_createAndListDiscoverySources(t *testing.T) {
testSource1 := "harbor-repo.vmware.com/tanzu_cli_stage/plugins/plugin-inventory:latest"
os.Setenv(constants.ConfigVariableAdditionalDiscoveryForTesting, testSource1)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)
rootCmd.SetArgs([]string{"plugin", "source", "list"})
b := bytes.NewBufferString("")
Expand Down Expand Up @@ -221,7 +221,7 @@ func Test_initDiscoverySources(t *testing.T) {
}})
assert.Nil(err)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)
rootCmd.SetArgs(spec.args)
b := bytes.NewBufferString("")
Expand Down Expand Up @@ -336,7 +336,7 @@ func Test_updateDiscoverySources(t *testing.T) {
}})
assert.Nil(err)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)
rootCmd.SetArgs(spec.args)
b := bytes.NewBufferString("")
Expand Down Expand Up @@ -449,7 +449,7 @@ func Test_deleteDiscoverySource(t *testing.T) {
}})
assert.Nil(err)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)
rootCmd.SetArgs(spec.args)
b := bytes.NewBufferString("")
Expand Down Expand Up @@ -576,7 +576,7 @@ func TestCompletionPluginSource(t *testing.T) {
t.Run(spec.test, func(t *testing.T) {
assert := assert.New(t)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)

var out bytes.Buffer
Expand Down
4 changes: 2 additions & 2 deletions pkg/command/doc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestGenDocs(t *testing.T) {
assert.Nil(err)
defer os.RemoveAll(docsDir)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)
rootCmd.SetArgs([]string{"generate-all-docs", "--docs-dir", docsDir})
err = rootCmd.Execute()
Expand Down Expand Up @@ -124,7 +124,7 @@ func TestCompletionGenerateDocs(t *testing.T) {
t.Run(spec.test, func(t *testing.T) {
assert := assert.New(t)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)

var out bytes.Buffer
Expand Down
8 changes: 4 additions & 4 deletions pkg/command/eula_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ var _ = Describe("EULA command tests", func() {
Context("When invoking an arbitrary command", func() {
It("should invoke the eula prompt if EULA has not been accepted", func() {
os.Setenv("TANZU_CLI_EULA_PROMPT_ANSWER", "No")
cmd, err := NewRootCmd()
cmd, err := NewRootCmdForTest()
Expect(err).To(BeNil())
cmd.SetArgs([]string{"context", "list"})
err = cmd.Execute()
Expand All @@ -113,7 +113,7 @@ var _ = Describe("EULA command tests", func() {

It("should not invoke the eula prompt if EULA has been accepted", func() {
os.Setenv("TANZU_CLI_EULA_PROMPT_ANSWER", "yes")
cmd, err := NewRootCmd()
cmd, err := NewRootCmdForTest()
Expect(err).To(BeNil())
cmd.SetArgs([]string{"context", "list"})
err = cmd.Execute()
Expand Down Expand Up @@ -158,7 +158,7 @@ var _ = Describe("EULA version checking tests", func() {
err = config.SetEULAStatus(config.EULAStatusAccepted)
Expect(err).To(BeNil())

cmd, err := NewRootCmd()
cmd, err := NewRootCmdForTest()
Expect(err).To(BeNil())
cmd.SetArgs([]string{"context", "list"})
checkForPromptOnExecute(cmd, expectToPrompt)
Expand Down Expand Up @@ -214,7 +214,7 @@ func TestCompletionEULA(t *testing.T) {
t.Run(spec.test, func(t *testing.T) {
assert := assert.New(t)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)

var out bytes.Buffer
Expand Down
2 changes: 1 addition & 1 deletion pkg/command/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestCompletionInit(t *testing.T) {
t.Run(spec.test, func(t *testing.T) {
assert := assert.New(t)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)

var out bytes.Buffer
Expand Down
2 changes: 1 addition & 1 deletion pkg/command/plugin_bundle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ func TestCompletionPluginBundle(t *testing.T) {

downloadImageCalled = false

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)

var out bytes.Buffer
Expand Down
6 changes: 3 additions & 3 deletions pkg/command/plugin_group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func TestPluginGroupSearch(t *testing.T) {
t.Run(spec.test, func(t *testing.T) {
assert := assert.New(t)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)

var out bytes.Buffer
Expand Down Expand Up @@ -191,7 +191,7 @@ func TestPluginGroupGet(t *testing.T) {
t.Run(spec.test, func(t *testing.T) {
assert := assert.New(t)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)

var out bytes.Buffer
Expand Down Expand Up @@ -306,7 +306,7 @@ func TestCompletionPluginGroup(t *testing.T) {
t.Run(spec.test, func(t *testing.T) {
assert := assert.New(t)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)

var out bytes.Buffer
Expand Down
4 changes: 2 additions & 2 deletions pkg/command/plugin_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func TestPluginSearch(t *testing.T) {

for _, spec := range tests {
t.Run(spec.test, func(t *testing.T) {
rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)
rootCmd.SetArgs(spec.args)

Expand Down Expand Up @@ -160,7 +160,7 @@ func TestCompletionPluginSearch(t *testing.T) {
t.Run(spec.test, func(t *testing.T) {
assert := assert.New(t)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)

var out bytes.Buffer
Expand Down
10 changes: 5 additions & 5 deletions pkg/command/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func TestPluginList(t *testing.T) {
}
cc.Unlock()

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)
rootCmd.SetArgs(spec.args)
b := bytes.NewBufferString("")
Expand Down Expand Up @@ -280,7 +280,7 @@ func TestDeletePlugin(t *testing.T) {
}
cupdater.Unlock()

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)
rootCmd.SetArgs(spec.args)

Expand Down Expand Up @@ -393,7 +393,7 @@ func TestInstallPlugin(t *testing.T) {

for _, spec := range tests {
t.Run(spec.test, func(t *testing.T) {
rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)
rootCmd.SetArgs(spec.args)

Expand Down Expand Up @@ -444,7 +444,7 @@ func TestUpgradePlugin(t *testing.T) {

for _, spec := range tests {
t.Run(spec.test, func(t *testing.T) {
rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)
rootCmd.SetArgs(spec.args)

Expand Down Expand Up @@ -947,7 +947,7 @@ func TestCompletionPlugin(t *testing.T) {
t.Run(spec.test, func(t *testing.T) {
assert := assert.New(t)

rootCmd, err := NewRootCmd()
rootCmd, err := NewRootCmdForTest()
assert.Nil(err)

var out bytes.Buffer
Expand Down
31 changes: 27 additions & 4 deletions pkg/command/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"

Expand Down Expand Up @@ -80,8 +81,9 @@ func convertInvokedAs(plugins []cli.PluginInfo) {
}
}

// NewRootCmd creates a root command.
func NewRootCmd() (*cobra.Command, error) { //nolint: gocyclo
// createRootCmd creates a root command.
// NOTE: Do not use this function directly as it will bypass the locking. Use `NewRootCmd` instead
func createRootCmd() (*cobra.Command, error) { //nolint: gocyclo
go interruptHandle()
var rootCmd = newRootCmd()
uFunc := cli.NewMainUsage().UsageFunc()
Expand Down Expand Up @@ -816,13 +818,28 @@ func shouldSkipGlobalInit(cmd *cobra.Command) bool {
return isSkipCommand(skipGlobalInitCommands, cmd.CommandPath())
}

var globalRootCmd *cobra.Command
var globalRootCmdLock = &sync.Mutex{}

// NewRootCmd create a new root command instance if it was not created before
func NewRootCmd() (*cobra.Command, error) {
globalRootCmdLock.Lock()
defer globalRootCmdLock.Unlock()

var err error
if globalRootCmd == nil {
globalRootCmd, err = createRootCmd()
}
return globalRootCmd, err
}

// Execute executes the CLI.
func Execute() error {
root, err := NewRootCmd()
rootCmd, err := NewRootCmd()
if err != nil {
return err
}
executionErr := root.Execute()
executionErr := rootCmd.Execute()
exitCode := 0
if executionErr != nil {
exitCode = 1
Expand Down Expand Up @@ -909,3 +926,9 @@ func printShortDescOfCmdInActiveHelp(cmd *cobra.Command, args []string) {
}
}
}

// NewRootCmdForTest creates a new instance of the root command for unit test purpose
// Note: This must not be used as part of the production code and only used for unit tests
func NewRootCmdForTest() (*cobra.Command, error) {
return createRootCmd()
}
Loading

0 comments on commit 2149273

Please sign in to comment.