diff --git a/auth.go b/auth.go index c5003b0..a92bcab 100644 --- a/auth.go +++ b/auth.go @@ -34,16 +34,18 @@ type PersistentStore interface { // This is not immediately obvious from the documentation. // See https://developer.tdameritrade.com/content/authentication-faq type Authenticator struct { - Store PersistentStore - OAuth2 oauth2.Config + Store PersistentStore + OAuth2 oauth2.Config + AuthOpts []oauth2.AuthCodeOption } // NewAuthenticator will automatically append @AMER.OAUTHAP to the client ID to save callers hours of frustration. -func NewAuthenticator(store PersistentStore, oauth2 oauth2.Config) *Authenticator { +func NewAuthenticator(store PersistentStore, oauth2 oauth2.Config, opts ...oauth2.AuthCodeOption) *Authenticator { oauth2.ClientID = oauth2.ClientID + "@AMER.OAUTHAP" return &Authenticator{ - Store: store, - OAuth2: oauth2, + Store: store, + OAuth2: oauth2, + AuthOpts: opts, } } @@ -75,7 +77,7 @@ func (a *Authenticator) StartOAuth2Flow(w http.ResponseWriter, req *http.Request return "", err } - return a.OAuth2.AuthCodeURL(state), nil + return a.OAuth2.AuthCodeURL(state, a.AuthOpts...), nil } // FinishOAuth2Flow finishes authenticating a user returning from TD Ameritrade. @@ -104,7 +106,7 @@ func (a *Authenticator) FinishOAuth2Flow(ctx context.Context, w http.ResponseWri if state[0] != expectedState { return nil, fmt.Errorf("invalid state. expected: '%v', got '%v'", expectedState, state[0]) } - token, err := a.OAuth2.Exchange(ctx, code[0]) + token, err := a.OAuth2.Exchange(ctx, code[0], a.AuthOpts...) if err != nil { return nil, err }