diff --git a/app.go b/app.go index 9bea11efde..3fb5c169ad 100644 --- a/app.go +++ b/app.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" "sort" + "strings" "time" ) @@ -20,6 +21,7 @@ var ( errInvalidActionType = NewExitError("ERROR invalid Action type. "+ fmt.Sprintf("Must be `func(*Context`)` or `func(*Context) error). %s", contactSysadmin)+ fmt.Sprintf("See %s", appActionDeprecationURL), 2) + ignoreFlagPrefix = "test." // this is to ignore test flags when adding flags from other packages SuggestFlag SuggestFlagFunc = suggestFlag SuggestCommand SuggestCommandFunc = suggestCommand @@ -197,6 +199,14 @@ func (a *App) Setup() { a.ErrWriter = os.Stderr } + // add global flags added by other packages + flag.VisitAll(func(f *flag.Flag) { + // skip test flags + if !strings.HasPrefix(f.Name, ignoreFlagPrefix) { + a.Flags = append(a.Flags, &extFlag{f}) + } + }) + var newCommands []*Command for _, c := range a.Commands { diff --git a/app_test.go b/app_test.go index 9b1a347583..d24fb63503 100644 --- a/app_test.go +++ b/app_test.go @@ -643,6 +643,42 @@ func TestApp_RunDefaultCommandWithFlags(t *testing.T) { } } +func TestApp_FlagsFromExtPackage(t *testing.T) { + + var someint int + flag.IntVar(&someint, "epflag", 2, "ext package flag usage") + + // Based on source code we can reset the global flag parsing this way + defer func() { + flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) + }() + + a := &App{ + Flags: []Flag{ + &StringFlag{ + Name: "carly", + Aliases: []string{"c"}, + Required: false, + }, + &BoolFlag{ + Name: "jimbob", + Aliases: []string{"j"}, + Required: false, + Value: true, + }, + }, + } + + err := a.Run([]string{"foo", "-c", "cly", "--epflag", "10"}) + if err != nil { + t.Error(err) + } + + if someint != 10 { + t.Errorf("Expected 10 got %d for someint", someint) + } +} + func TestApp_Setup_defaultsReader(t *testing.T) { app := &App{} app.Setup() diff --git a/errors.go b/errors.go index 225e1bb378..a818727dbb 100644 --- a/errors.go +++ b/errors.go @@ -83,7 +83,7 @@ type ExitCoder interface { type exitError struct { exitCode int - message interface{} + err error } // NewExitError calls Exit to create a new ExitCoder. @@ -101,20 +101,35 @@ func NewExitError(message interface{}, exitCode int) ExitCoder { // by overriding the ExitErrHandler function on an App or the package-global // OsExiter function. func Exit(message interface{}, exitCode int) ExitCoder { + var err error + + switch e := message.(type) { + case ErrorFormatter: + err = fmt.Errorf("%+v", message) + case error: + err = e + default: + err = fmt.Errorf("%+v", message) + } + return &exitError{ - message: message, + err: err, exitCode: exitCode, } } func (ee *exitError) Error() string { - return fmt.Sprintf("%v", ee.message) + return ee.err.Error() } func (ee *exitError) ExitCode() int { return ee.exitCode } +func (ee *exitError) Unwrap() error { + return ee.err +} + // HandleExitCoder handles errors implementing ExitCoder by printing their // message and calling OsExiter with the given exit code. // diff --git a/errors_test.go b/errors_test.go index d0b1b4fb13..337009c809 100644 --- a/errors_test.go +++ b/errors_test.go @@ -45,6 +45,25 @@ func TestHandleExitCoder_ExitCoder(t *testing.T) { expect(t, called, true) } +func TestHandleExitCoder_ErrorExitCoder(t *testing.T) { + exitCode := 0 + called := false + + OsExiter = func(rc int) { + if !called { + exitCode = rc + called = true + } + } + + defer func() { OsExiter = fakeOsExiter }() + + HandleExitCoder(Exit(errors.New("galactic perimeter breach"), 9)) + + expect(t, exitCode, 9) + expect(t, called, true) +} + func TestHandleExitCoder_MultiErrorWithExitCoder(t *testing.T) { exitCode := 0 called := false diff --git a/flag_ext.go b/flag_ext.go new file mode 100644 index 0000000000..64da59ea93 --- /dev/null +++ b/flag_ext.go @@ -0,0 +1,48 @@ +package cli + +import "flag" + +type extFlag struct { + f *flag.Flag +} + +func (e *extFlag) Apply(fs *flag.FlagSet) error { + fs.Var(e.f.Value, e.f.Name, e.f.Usage) + return nil +} + +func (e *extFlag) Names() []string { + return []string{e.f.Name} +} + +func (e *extFlag) IsSet() bool { + return false +} + +func (e *extFlag) String() string { + return FlagStringer(e) +} + +func (e *extFlag) IsVisible() bool { + return true +} + +func (e *extFlag) TakesValue() bool { + return false +} + +func (e *extFlag) GetUsage() string { + return e.f.Usage +} + +func (e *extFlag) GetValue() string { + return e.f.Value.String() +} + +func (e *extFlag) GetDefaultText() string { + return e.f.DefValue +} + +func (e *extFlag) GetEnvVars() []string { + return nil +}