From 21492737ddfb9fb6597201fb9c5b235975c8f49e Mon Sep 17 00:00:00 2001 From: Anuj Chaudhari Date: Tue, 14 Jan 2025 10:55:23 -0800 Subject: [PATCH] Implement singleton pattern for NewRootCmd to prevent multiple creations (#840) * Only allow creating the NewRootCmd once with singleton pattern * Fix unit tests by using NewRootCmdForTest Signed-off-by: Anuj Chaudhari --- pkg/command/ceip_participation_test.go | 2 +- pkg/command/cert_test.go | 2 +- pkg/command/completion_test.go | 2 +- pkg/command/config_test.go | 2 +- pkg/command/context_test.go | 4 ++-- pkg/command/discovery_source_test.go | 10 ++++----- pkg/command/doc_test.go | 4 ++-- pkg/command/eula_test.go | 8 +++---- pkg/command/init_test.go | 2 +- pkg/command/plugin_bundle_test.go | 2 +- pkg/command/plugin_group_test.go | 6 ++--- pkg/command/plugin_search_test.go | 4 ++-- pkg/command/plugin_test.go | 10 ++++----- pkg/command/root.go | 31 ++++++++++++++++++++++---- pkg/command/root_test.go | 20 ++++++++--------- pkg/command/version_test.go | 2 +- 16 files changed, 67 insertions(+), 44 deletions(-) diff --git a/pkg/command/ceip_participation_test.go b/pkg/command/ceip_participation_test.go index 0e9b29e24..34547fd20 100644 --- a/pkg/command/ceip_participation_test.go +++ b/pkg/command/ceip_participation_test.go @@ -158,7 +158,7 @@ func TestCompletionCeip(t *testing.T) { t.Run(spec.test, func(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer diff --git a/pkg/command/cert_test.go b/pkg/command/cert_test.go index 9cbf01a1a..d09422c27 100644 --- a/pkg/command/cert_test.go +++ b/pkg/command/cert_test.go @@ -464,7 +464,7 @@ func TestCompletionCert(t *testing.T) { t.Run(spec.test, func(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer diff --git a/pkg/command/completion_test.go b/pkg/command/completion_test.go index bb6638b7b..00a353652 100644 --- a/pkg/command/completion_test.go +++ b/pkg/command/completion_test.go @@ -155,7 +155,7 @@ func TestCompletionCompletion(t *testing.T) { t.Run(spec.test, func(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer diff --git a/pkg/command/config_test.go b/pkg/command/config_test.go index 5d17a31e2..ec08615e6 100644 --- a/pkg/command/config_test.go +++ b/pkg/command/config_test.go @@ -220,7 +220,7 @@ func TestCompletionConfig(t *testing.T) { t.Run(spec.test, func(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer diff --git a/pkg/command/context_test.go b/pkg/command/context_test.go index 9bb30a419..58cc51cec 100644 --- a/pkg/command/context_test.go +++ b/pkg/command/context_test.go @@ -1667,7 +1667,7 @@ func TestCompletionContext(t *testing.T) { t.Run(spec.test, func(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer @@ -1942,7 +1942,7 @@ func TestContextCurrentCmd(t *testing.T) { _ = config.SetContext(spec.activeContexts[i], true) } - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer diff --git a/pkg/command/discovery_source_test.go b/pkg/command/discovery_source_test.go index 38db5955b..37440ee40 100644 --- a/pkg/command/discovery_source_test.go +++ b/pkg/command/discovery_source_test.go @@ -124,7 +124,7 @@ func Test_createAndListDiscoverySources(t *testing.T) { testSource1 := "harbor-repo.vmware.com/tanzu_cli_stage/plugins/plugin-inventory:latest" os.Setenv(constants.ConfigVariableAdditionalDiscoveryForTesting, testSource1) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) rootCmd.SetArgs([]string{"plugin", "source", "list"}) b := bytes.NewBufferString("") @@ -221,7 +221,7 @@ func Test_initDiscoverySources(t *testing.T) { }}) assert.Nil(err) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) rootCmd.SetArgs(spec.args) b := bytes.NewBufferString("") @@ -336,7 +336,7 @@ func Test_updateDiscoverySources(t *testing.T) { }}) assert.Nil(err) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) rootCmd.SetArgs(spec.args) b := bytes.NewBufferString("") @@ -449,7 +449,7 @@ func Test_deleteDiscoverySource(t *testing.T) { }}) assert.Nil(err) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) rootCmd.SetArgs(spec.args) b := bytes.NewBufferString("") @@ -576,7 +576,7 @@ func TestCompletionPluginSource(t *testing.T) { t.Run(spec.test, func(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer diff --git a/pkg/command/doc_test.go b/pkg/command/doc_test.go index 3162775d2..7e1585d85 100644 --- a/pkg/command/doc_test.go +++ b/pkg/command/doc_test.go @@ -67,7 +67,7 @@ func TestGenDocs(t *testing.T) { assert.Nil(err) defer os.RemoveAll(docsDir) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) rootCmd.SetArgs([]string{"generate-all-docs", "--docs-dir", docsDir}) err = rootCmd.Execute() @@ -124,7 +124,7 @@ func TestCompletionGenerateDocs(t *testing.T) { t.Run(spec.test, func(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer diff --git a/pkg/command/eula_test.go b/pkg/command/eula_test.go index 04a1a648c..6a4e66911 100644 --- a/pkg/command/eula_test.go +++ b/pkg/command/eula_test.go @@ -102,7 +102,7 @@ var _ = Describe("EULA command tests", func() { Context("When invoking an arbitrary command", func() { It("should invoke the eula prompt if EULA has not been accepted", func() { os.Setenv("TANZU_CLI_EULA_PROMPT_ANSWER", "No") - cmd, err := NewRootCmd() + cmd, err := NewRootCmdForTest() Expect(err).To(BeNil()) cmd.SetArgs([]string{"context", "list"}) err = cmd.Execute() @@ -113,7 +113,7 @@ var _ = Describe("EULA command tests", func() { It("should not invoke the eula prompt if EULA has been accepted", func() { os.Setenv("TANZU_CLI_EULA_PROMPT_ANSWER", "yes") - cmd, err := NewRootCmd() + cmd, err := NewRootCmdForTest() Expect(err).To(BeNil()) cmd.SetArgs([]string{"context", "list"}) err = cmd.Execute() @@ -158,7 +158,7 @@ var _ = Describe("EULA version checking tests", func() { err = config.SetEULAStatus(config.EULAStatusAccepted) Expect(err).To(BeNil()) - cmd, err := NewRootCmd() + cmd, err := NewRootCmdForTest() Expect(err).To(BeNil()) cmd.SetArgs([]string{"context", "list"}) checkForPromptOnExecute(cmd, expectToPrompt) @@ -214,7 +214,7 @@ func TestCompletionEULA(t *testing.T) { t.Run(spec.test, func(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer diff --git a/pkg/command/init_test.go b/pkg/command/init_test.go index 1afcf80ee..117dd3923 100644 --- a/pkg/command/init_test.go +++ b/pkg/command/init_test.go @@ -32,7 +32,7 @@ func TestCompletionInit(t *testing.T) { t.Run(spec.test, func(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer diff --git a/pkg/command/plugin_bundle_test.go b/pkg/command/plugin_bundle_test.go index 15c3a314d..0d9e5ae45 100644 --- a/pkg/command/plugin_bundle_test.go +++ b/pkg/command/plugin_bundle_test.go @@ -231,7 +231,7 @@ func TestCompletionPluginBundle(t *testing.T) { downloadImageCalled = false - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer diff --git a/pkg/command/plugin_group_test.go b/pkg/command/plugin_group_test.go index fb0a9e9b0..4220c56f9 100644 --- a/pkg/command/plugin_group_test.go +++ b/pkg/command/plugin_group_test.go @@ -87,7 +87,7 @@ func TestPluginGroupSearch(t *testing.T) { t.Run(spec.test, func(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer @@ -191,7 +191,7 @@ func TestPluginGroupGet(t *testing.T) { t.Run(spec.test, func(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer @@ -306,7 +306,7 @@ func TestCompletionPluginGroup(t *testing.T) { t.Run(spec.test, func(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer diff --git a/pkg/command/plugin_search_test.go b/pkg/command/plugin_search_test.go index 8b3a9c354..d50f7bc52 100644 --- a/pkg/command/plugin_search_test.go +++ b/pkg/command/plugin_search_test.go @@ -69,7 +69,7 @@ func TestPluginSearch(t *testing.T) { for _, spec := range tests { t.Run(spec.test, func(t *testing.T) { - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) rootCmd.SetArgs(spec.args) @@ -160,7 +160,7 @@ func TestCompletionPluginSearch(t *testing.T) { t.Run(spec.test, func(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer diff --git a/pkg/command/plugin_test.go b/pkg/command/plugin_test.go index 205905445..d87b1c7de 100644 --- a/pkg/command/plugin_test.go +++ b/pkg/command/plugin_test.go @@ -133,7 +133,7 @@ func TestPluginList(t *testing.T) { } cc.Unlock() - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) rootCmd.SetArgs(spec.args) b := bytes.NewBufferString("") @@ -280,7 +280,7 @@ func TestDeletePlugin(t *testing.T) { } cupdater.Unlock() - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) rootCmd.SetArgs(spec.args) @@ -393,7 +393,7 @@ func TestInstallPlugin(t *testing.T) { for _, spec := range tests { t.Run(spec.test, func(t *testing.T) { - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) rootCmd.SetArgs(spec.args) @@ -444,7 +444,7 @@ func TestUpgradePlugin(t *testing.T) { for _, spec := range tests { t.Run(spec.test, func(t *testing.T) { - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) rootCmd.SetArgs(spec.args) @@ -947,7 +947,7 @@ func TestCompletionPlugin(t *testing.T) { t.Run(spec.test, func(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer diff --git a/pkg/command/root.go b/pkg/command/root.go index cd4b65d16..066f2bba6 100644 --- a/pkg/command/root.go +++ b/pkg/command/root.go @@ -12,6 +12,7 @@ import ( "os/signal" "strconv" "strings" + "sync" "syscall" "time" @@ -80,8 +81,9 @@ func convertInvokedAs(plugins []cli.PluginInfo) { } } -// NewRootCmd creates a root command. -func NewRootCmd() (*cobra.Command, error) { //nolint: gocyclo +// createRootCmd creates a root command. +// NOTE: Do not use this function directly as it will bypass the locking. Use `NewRootCmd` instead +func createRootCmd() (*cobra.Command, error) { //nolint: gocyclo go interruptHandle() var rootCmd = newRootCmd() uFunc := cli.NewMainUsage().UsageFunc() @@ -816,13 +818,28 @@ func shouldSkipGlobalInit(cmd *cobra.Command) bool { return isSkipCommand(skipGlobalInitCommands, cmd.CommandPath()) } +var globalRootCmd *cobra.Command +var globalRootCmdLock = &sync.Mutex{} + +// NewRootCmd create a new root command instance if it was not created before +func NewRootCmd() (*cobra.Command, error) { + globalRootCmdLock.Lock() + defer globalRootCmdLock.Unlock() + + var err error + if globalRootCmd == nil { + globalRootCmd, err = createRootCmd() + } + return globalRootCmd, err +} + // Execute executes the CLI. func Execute() error { - root, err := NewRootCmd() + rootCmd, err := NewRootCmd() if err != nil { return err } - executionErr := root.Execute() + executionErr := rootCmd.Execute() exitCode := 0 if executionErr != nil { exitCode = 1 @@ -909,3 +926,9 @@ func printShortDescOfCmdInActiveHelp(cmd *cobra.Command, args []string) { } } } + +// NewRootCmdForTest creates a new instance of the root command for unit test purpose +// Note: This must not be used as part of the production code and only used for unit tests +func NewRootCmdForTest() (*cobra.Command, error) { + return createRootCmd() +} diff --git a/pkg/command/root_test.go b/pkg/command/root_test.go index 16331624a..3e5751142 100644 --- a/pkg/command/root_test.go +++ b/pkg/command/root_test.go @@ -180,7 +180,7 @@ func tearDownTestCLIEnvironment(env testCLIEnvironment) { func TestRootCmdWithNoAdditionalPlugins(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) err = rootCmd.Execute() assert.Nil(err) @@ -188,7 +188,7 @@ func TestRootCmdWithNoAdditionalPlugins(t *testing.T) { func TestSubcommandNonexistent(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) rootCmd.SetArgs([]string{"nonexistent", "say", "hello"}) err = rootCmd.Execute() @@ -415,7 +415,7 @@ func TestSubcommands(t *testing.T) { cc.Unlock() assert.Nil(err) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) rootCmd.SetArgs(spec.args) @@ -799,7 +799,7 @@ func TestTargetCommands(t *testing.T) { assert.Nil(err) } - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) rootCmd.SetArgs(spec.args) @@ -895,7 +895,7 @@ func TestGlobalInit(t *testing.T) { }, ) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) rootCmd.SetArgs(spec.args) @@ -943,7 +943,7 @@ func TestSetLastVersion(t *testing.T) { buildinfo.Version = spec.version - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) // Execute any command to trigger the version update rootCmd.SetArgs([]string{"plugin", "list"}) @@ -967,7 +967,7 @@ func TestEnvVarsSet(t *testing.T) { err := config.ConfigureFeatureFlags(constants.DefaultCliFeatureFlags) assert.Nil(err) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) envVarName := "SOME_TEST_ENV_VAR" @@ -983,7 +983,7 @@ func TestEnvVarsSet(t *testing.T) { // Re-initialize the CLI with the config files containing the variable. // It is in this call that the CLI creates the OS variables. - _, err = NewRootCmd() + _, err = NewRootCmdForTest() assert.Nil(err) // Make sure the variable is now set during the call to the CLI assert.Equal(envVarValue, os.Getenv(envVarName)) @@ -1051,7 +1051,7 @@ func TestCompletionShortHelpInActiveHelp(t *testing.T) { t.Run(spec.test, func(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer @@ -2213,7 +2213,7 @@ func TestCommandRemapping(t *testing.T) { assert.Nil(err) } - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) // To be able to test the "tanzu help" command, we need to set the os.Args // instead of using the cobra command's SetArgs method. diff --git a/pkg/command/version_test.go b/pkg/command/version_test.go index 6730e8b32..a0409a83c 100644 --- a/pkg/command/version_test.go +++ b/pkg/command/version_test.go @@ -86,7 +86,7 @@ func TestCompletionVersion(t *testing.T) { t.Run(spec.test, func(t *testing.T) { assert := assert.New(t) - rootCmd, err := NewRootCmd() + rootCmd, err := NewRootCmdForTest() assert.Nil(err) var out bytes.Buffer