Skip to content

Commit

Permalink
feat: use random port for OAuth2 callback (#371)
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr authored Aug 9, 2024
1 parent ee938b2 commit aaefcc6
Showing 1 changed file with 15 additions and 43 deletions.
58 changes: 15 additions & 43 deletions cmd/cloudx/client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
stderrors "errors"
"fmt"
"io"
"math/rand/v2"
"net"
"net/http"
"net/url"
Expand All @@ -17,9 +16,7 @@ import (
"time"

"github.com/gofrs/uuid"
"github.com/pkg/errors"
"golang.org/x/oauth2"
"golang.org/x/sync/errgroup"

cloud "github.com/ory/client-go"
"github.com/ory/x/randx"
Expand Down Expand Up @@ -178,27 +175,17 @@ func (h *CommandHelper) loginOAuth2(ctx context.Context) (*Config, error) {

func (h *CommandHelper) oAuth2DanceWithServer(ctx context.Context, client *oauth2.Config) (token *oauth2.Token, err error) {
var (
l net.Listener
state = randx.MustString(32, randx.AlphaNum)
pkceVerifier = oauth2.GenerateVerifier()
ports = []int{12345, 15793, 17628, 19834, 23730, 27462, 34525, 36209, 42827, 46718, 49763, 51238, 52213, 57923, 59724, 60582, 62125, 65321, 49876, 54321, 59876, 60987, 62345, 63456, 64567, 65123, 65234, 65432, 65500, 65510, 65520, 65530}
serverErr = make(chan error)
serverToken = make(chan *oauth2.Token)
)
rand.Shuffle(len(ports), func(i, j int) { ports[i], ports[j] = ports[j], ports[i] })
for _, port := range ports {
l, err = net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
if err == nil {
client.RedirectURL = fmt.Sprintf("http://localhost:%d/callback", port)
break
}
}
if l == nil {
return nil, fmt.Errorf("failed to allocate port for OAuth2 callback handler, try again later: last error: %w", err)
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, fmt.Errorf("failed to allocate port for OAuth2 callback handler, try again later: %w", err)
}
client.RedirectURL = fmt.Sprintf("http://%s/callback", l.Addr().String())

var (
serverErr = make(chan error)
serverToken = make(chan *oauth2.Token)
)
srv := http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// for retries the user has to start from the beginning
Expand Down Expand Up @@ -242,27 +229,8 @@ func (h *CommandHelper) oAuth2DanceWithServer(ctx context.Context, client *oauth
redirectOK(w, r)
}),
}
ctx, cancel := context.WithCancel(ctx)
defer cancel()

eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() (err error) {
if err := srv.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("failed to serve OAuth2 callback handler: %w", err)
}
return nil
})
eg.Go(func() (err error) {
select {
case <-ctx.Done():
err = ctx.Err()
case token = <-serverToken:
case err = <-serverErr:
}
ctx, cancel := context.WithDeadline(context.WithoutCancel(ctx), time.Now().Add(20*time.Second))
defer cancel()
return stderrors.Join(err, srv.Shutdown(ctx))
})
go func() { _ = srv.Serve(l) }()
defer srv.Close()

u := client.AuthCodeURL(state,
oauth2.S256ChallengeOption(pkceVerifier),
Expand All @@ -282,10 +250,14 @@ If no browser opened, visit the below page to continue:
`, u)

if err := eg.Wait(); err != nil {
return nil, fmt.Errorf("failed to authenticate, please try again: %w", err)
select {
case <-ctx.Done():
return nil, ctx.Err()
case token := <-serverToken:
return token, nil
case err := <-serverErr:
return nil, err
}
return token, nil
}

func redirectOK(w http.ResponseWriter, r *http.Request) {
Expand Down

0 comments on commit aaefcc6

Please sign in to comment.