Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve goroutine handling, server shutdown #505

Merged
merged 4 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pkg/keycloak/proxy/oauth_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/gogatekeeper/gatekeeper/pkg/proxy/core"
"github.com/gogatekeeper/gatekeeper/pkg/storage"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
)

type PAT struct {
Expand All @@ -36,9 +37,12 @@ type OauthProxy struct {
Router http.Handler
adminRouter http.Handler
Server *http.Server
HTTPServer *http.Server
AdminServer *http.Server
Store storage.Storage
Upstream core.ReverseProxy
pat *PAT
rpt *RPT
Cm *cookie.Manager
ErrGroup *errgroup.Group
}
101 changes: 67 additions & 34 deletions pkg/keycloak/proxy/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"time"

"golang.org/x/net/http/httpproxy"
"golang.org/x/sync/errgroup"

"go.uber.org/zap/zapcore"

Expand Down Expand Up @@ -875,11 +876,10 @@ func (r *OauthProxy) createForwardingProxy() error {
// Run starts the proxy service
//
//nolint:cyclop
func (r *OauthProxy) Run() error {
func (r *OauthProxy) Run() (context.Context, error) {
listener, err := r.createHTTPListener(makeListenerConfig(r.Config))

if err != nil {
return err
return nil, err
}

// step: create the main http(s) server
Expand All @@ -894,18 +894,24 @@ func (r *OauthProxy) Run() error {
r.Server = server
r.Listener = listener

go func() {
r.Log.Info(
"Gatekeeper proxy service starting",
zap.String("interface", r.Config.Listen),
)
errGroup, ctx := errgroup.WithContext(context.Background())
r.ErrGroup = errGroup
r.ErrGroup.Go(
func() error {
r.Log.Info(
"Gatekeeper proxy service starting",
zap.String("interface", r.Config.Listen),
)

if err = server.Serve(listener); err != nil {
if err != http.ErrServerClosed {
r.Log.Fatal("failed to start the http service", zap.Error(err))
if err := server.Serve(listener); err != nil {
if !errors.Is(err, http.ErrServerClosed) {
r.Log.Fatal("failed to start the http service", zap.Error(err))
return err
}
}
}
}()
return nil
},
)

// step: are we running http service as well?
if r.Config.ListenHTTP != "" {
Expand All @@ -918,9 +924,8 @@ func (r *OauthProxy) Run() error {
listen: r.Config.ListenHTTP,
proxyProtocol: r.Config.EnableProxyProtocol,
})

if err != nil {
return err
return nil, err
}

httpsvc := &http.Server{
Expand All @@ -931,18 +936,23 @@ func (r *OauthProxy) Run() error {
IdleTimeout: r.Config.ServerIdleTimeout,
}

go func() {
r.HTTPServer = httpsvc
r.ErrGroup.Go(func() error {
if err := httpsvc.Serve(httpListener); err != nil {
r.Log.Fatal("failed to start the http redirect service", zap.Error(err))
if !errors.Is(err, http.ErrServerClosed) {
r.Log.Error("failed to start the http redirect service", zap.Error(err))
return err
}
}
}()
return nil
})
}

// step: are we running specific admin service as well?
// if not, admin endpoints are added as routes in the main service
if r.Config.ListenAdmin != "" {
r.Log.Info(
"keycloak proxy admin service starting",
"Gatekeeper proxy admin service starting",
zap.String("interface", r.Config.ListenAdmin),
)

Expand All @@ -957,9 +967,8 @@ func (r *OauthProxy) Run() error {
listen: r.Config.ListenAdmin,
proxyProtocol: r.Config.EnableProxyProtocol,
})

if err != nil {
return err
return nil, err
}
} else {
adminListenerConfig := makeListenerConfig(r.Config)
Expand All @@ -973,18 +982,16 @@ func (r *OauthProxy) Run() error {
adminListenerConfig.certificate = r.Config.TLSAdminCertificate
adminListenerConfig.privateKey = r.Config.TLSAdminPrivateKey
}

if r.Config.TLSAdminCaCertificate != "" {
adminListenerConfig.ca = r.Config.TLSAdminCaCertificate
}

if r.Config.TLSAdminClientCertificate != "" {
adminListenerConfig.clientCert = r.Config.TLSAdminClientCertificate
}

adminListener, err = r.createHTTPListener(adminListenerConfig)
if err != nil {
return err
return nil, err
}
}

Expand All @@ -996,26 +1003,52 @@ func (r *OauthProxy) Run() error {
IdleTimeout: r.Config.ServerIdleTimeout,
}

go func() {
if ers := adminsvc.Serve(adminListener); err != nil {
r.Log.Fatal("failed to start the admin service", zap.Error(ers))
r.AdminServer = adminsvc
r.ErrGroup.Go(func() error {
if err := adminsvc.Serve(adminListener); err != nil {
if !errors.Is(err, http.ErrServerClosed) {
r.Log.Error("failed to start the admin service", zap.Error(err))
return err
}
}
}()
return nil
})
}

return nil
return ctx, nil
}

// Shutdown finishes the proxy service with gracefully period
func (r *OauthProxy) Shutdown() error {
ctx, cancel := context.WithTimeout(context.Background(), r.Config.ServerGraceTimeout)
ctx, cancel := context.WithTimeout(
context.Background(),
r.Config.ServerGraceTimeout,
)
defer cancel()

err := r.Server.Shutdown(ctx)
if err == nil {
return nil
var err error
servers := []*http.Server{
r.Server,
r.HTTPServer,
r.AdminServer,
}
for idx, srv := range servers {
if srv != nil {
r.Log.Debug("Shutdown http server", zap.Int("num", idx))
if errShut := srv.Shutdown(ctx); errShut != nil {
if closeErr := srv.Close(); closeErr != nil {
err = errors.Join(err, closeErr)
}
}
}
}

r.Log.Debug("Waiting for goroutines to finish")
if routineErr := r.ErrGroup.Wait(); routineErr != nil {
err = errors.Join(err, routineErr)
}
return r.Server.Close()

return err
}

// listenerConfig encapsulate listener options
Expand Down
25 changes: 21 additions & 4 deletions pkg/proxy/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ limitations under the License.
package proxy

import (
"context"
"errors"
"fmt"
"os"
"os/signal"
Expand All @@ -33,6 +35,8 @@ import (
)

// newOauthProxyApp creates a new cli application and runs it
//
//nolint:cyclop
func NewOauthProxyApp[T proxycore.KeycloakProvider | proxycore.GoogleProvider](provider T) *cli.App {
cfg := config.ProduceConfig(provider)
app := cli.NewApp()
Expand Down Expand Up @@ -77,21 +81,34 @@ func NewOauthProxyApp[T proxycore.KeycloakProvider | proxycore.GoogleProvider](p
// step: create the proxy
proxy, err := ProduceProxy(cfg)
if err != nil {
if errShut := proxy.Shutdown(); errShut != nil {
err = errors.Join(err, errShut)
}
return utils.PrintError(err.Error())
}

// step: start the service
if err := proxy.Run(); err != nil {
var errGroupCtx context.Context
if errGroupCtx, err = proxy.Run(); err != nil {
if errShut := proxy.Shutdown(); errShut != nil {
err = errors.Join(err, errShut)
}
return utils.PrintError(err.Error())
}

// step: setup the termination signals
signalChannel := make(chan os.Signal, 1)
signal.Notify(signalChannel, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
<-signalChannel

if err := proxy.Shutdown(); err != nil {
return utils.PrintError(err.Error())
select {
case <-errGroupCtx.Done():
if err := proxy.Shutdown(); err != nil {
return utils.PrintError(err.Error())
}
case <-signalChannel:
if err := proxy.Shutdown(); err != nil {
return utils.PrintError(err.Error())
}
}

return nil
Expand Down
3 changes: 2 additions & 1 deletion pkg/proxy/core/core.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package core

import (
"context"
"fmt"
"net/http"
"strconv"
Expand Down Expand Up @@ -32,7 +33,7 @@ func GetVersion() string {

type OauthProxies interface {
CreateReverseProxy() error
Run() error
Run() (context.Context, error)
Shutdown() error
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/testsuite/fake_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func newFakeProxy(cfg *config.Config, authConfig *fakeAuthConfig) *fakeProxy {
}

// proxy.log = zap.NewNop()
if err = oProxy.Run(); err != nil {
if _, err = oProxy.Run(); err != nil {
panic("failed to create the proxy service, error: " + err.Error())
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/testsuite/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2527,7 +2527,7 @@ func TestLogRealIP(t *testing.T) {
_ = cfg.Update()

proxy, _ := proxy.NewProxy(cfg, testLog, &FakeUpstreamService{})
_ = proxy.Run()
_, _ = proxy.Run()

cfg.RedirectionURL = "http://" + proxy.Listener.Addr().String()
fp := &fakeProxy{cfg, auth, proxy, make(map[string]*http.Cookie)}
Expand Down
6 changes: 4 additions & 2 deletions pkg/testsuite/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ func TestNewKeycloakProxy(t *testing.T) {
assert.NotNil(t, proxy.Config)
assert.NotNil(t, proxy.Router)
assert.NotNil(t, proxy.Endpoint)
require.NoError(t, proxy.Run())
_, err = proxy.Run()
require.NoError(t, err)
}

func TestNewKeycloakProxyWithLegacyDiscoveryURI(t *testing.T) {
Expand All @@ -78,7 +79,8 @@ func TestNewKeycloakProxyWithLegacyDiscoveryURI(t *testing.T) {
assert.NotNil(t, proxy.Config)
assert.NotNil(t, proxy.Router)
assert.NotNil(t, proxy.Endpoint)
require.NoError(t, proxy.Run())
_, err = proxy.Run()
require.NoError(t, err)
}

func TestReverseProxyHeaders(t *testing.T) {
Expand Down
Loading