diff --git a/go.mod b/go.mod index 0da386c..ad5df06 100644 --- a/go.mod +++ b/go.mod @@ -3,18 +3,22 @@ module github.com/cloudhut/common go 1.19 require ( + github.com/fsnotify/fsnotify v1.6.0 github.com/go-chi/chi/v5 v5.0.8 github.com/google/go-cmp v0.5.9 github.com/prometheus/client_golang v1.15.1 + github.com/stretchr/testify v1.8.0 go.uber.org/zap v1.24.0 ) require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect - github.com/fsnotify/fsnotify v1.6.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang/protobuf v1.5.3 // indirect + github.com/kr/text v0.2.0 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/common v0.43.0 // indirect github.com/prometheus/procfs v0.9.0 // indirect @@ -22,4 +26,5 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/sys v0.8.0 // indirect google.golang.org/protobuf v1.30.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 3fb5631..67db262 100644 --- a/go.sum +++ b/go.sum @@ -3,7 +3,10 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/go-chi/chi/v5 v5.0.8 h1:lD+NLqFcAi1ovnVZpsnObHGW4xb4J8lNmoYVfECH1Y0= @@ -15,10 +18,14 @@ github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo= github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v1.15.1 h1:8tXpTmJbyH5lydzFPoxSIJ0J46jdh3tylbvM1xCv0LI= github.com/prometheus/client_golang v1.15.1/go.mod h1:e9yaBhRPU2pPNsZwE+JdQl0KEt1N9XgF6zxWmaC0xOk= github.com/prometheus/client_model v0.4.0 h1:5lQXD3cAg1OXBf4Wq03gTrXHeaV0TQvGfUooCfx1yqY= @@ -27,7 +34,12 @@ github.com/prometheus/common v0.43.0 h1:iq+BVjvYLei5f27wiuNiB1DN6DYQkp1c8Bx0Vykh github.com/prometheus/common v0.43.0/go.mod h1:NCvr5cQIh3Y/gy73/RdVtC9r8xxrxwJnB+2lB3BxrFc= github.com/prometheus/procfs v0.9.0 h1:wzCHvIvM5SxWqYvwgVL7yJY8Lz3PKn49KQtpgMYJfhI= github.com/prometheus/procfs v0.9.0/go.mod h1:+pB4zwohETzFnmlpe6yd2lSc+0/46IYZRB/chUwxUZY= +github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= @@ -44,4 +56,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/rest/server.go b/rest/server.go index 4cb567a..1a43a59 100644 --- a/rest/server.go +++ b/rest/server.go @@ -2,7 +2,6 @@ package rest import ( "context" - "crypto/tls" "fmt" "net" "net/http" @@ -12,7 +11,7 @@ import ( "sync" "syscall" - "github.com/fsnotify/fsnotify" + "github.com/cloudhut/common/tls" "github.com/go-chi/chi/v5" "go.uber.org/zap" ) @@ -42,7 +41,7 @@ func NewServer(cfg *Config, logger *zap.Logger, router *chi.Mux) (*Server, error } if cfg.TLS.Enabled { - tlsCfg, err := buildServerTLSConfig(logger, cfg.TLS) + tlsCfg, err := tls.BuildWatchedTLSConfig(logger, cfg.TLS.CertFilepath, cfg.TLS.KeyFilepath, nil) if err != nil { return nil, fmt.Errorf("failed to create TLS config: %w", err) } @@ -109,93 +108,3 @@ func (s *Server) Start() error { return nil } - -func buildServerTLSConfig(logger *zap.Logger, cfg TLSConfig) (*tls.Config, error) { - cert, err := tls.LoadX509KeyPair(cfg.CertFilepath, cfg.KeyFilepath) - if err != nil { - return nil, fmt.Errorf("failed loading TLS cert: %w", err) - } - - signalCh := make(chan os.Signal, 1) - signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM) - - watcher, err := fsnotify.NewWatcher() - if err != nil { - return nil, fmt.Errorf("failed to setup file watcher for hot reloading tls certificates: %w", err) - } - - var lock sync.RWMutex - go func() { - defer watcher.Close() - for { - select { - case event, ok := <-watcher.Events: - if !ok { - return - } - - // Ignore all events that are neither remove nor write - isRelevantEvent := event.Has(fsnotify.Remove) || event.Has(fsnotify.Write) - if !isRelevantEvent { - continue - } - - if event.Has(fsnotify.Remove) { - // Kubernetes uses symbolic links to create the illusion of atomic writes. - // Thus, we have to watch for the remove event and reconfigure our watcher. - // See: https://ahmet.im/blog/kubernetes-inotify/ - _ = watcher.Remove(event.Name) - err = watcher.Add(cfg.CertFilepath) - if err != nil { - logger.Error("failed to re-add file watcher", - zap.String("file_path", cfg.CertFilepath), - zap.Error(err)) - } - - err = watcher.Add(cfg.KeyFilepath) - if err != nil { - logger.Warn("failed to re-add file watcher", - zap.String("file_path", cfg.KeyFilepath), - zap.Error(err)) - } - } - - logger.Info("hot reloading the TLS certificate") - - newCert, err := tls.LoadX509KeyPair(cfg.CertFilepath, cfg.KeyFilepath) - if err != nil { - logger.Error("failed to load certificates", zap.Error(err)) - continue - } - lock.Lock() - cert = newCert - lock.Unlock() - - logger.Info("successfully hot reloaded the TLS certificate") - case err, ok := <-watcher.Errors: - if !ok { - return - } - logger.Error("tls certificate watcher error", zap.Error(err)) - case <-signalCh: - return - } - } - }() - if err := watcher.Add(cfg.CertFilepath); err != nil { - return nil, fmt.Errorf("failed to setup watcher for cert file: %w", err) - } - if err = watcher.Add(cfg.KeyFilepath); err != nil { - return nil, fmt.Errorf("failed to setup watcher for key file: %w", err) - } - - return &tls.Config{ - MinVersion: tls.VersionTLS12, - GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - lock.RLock() - defer lock.RUnlock() - - return &cert, nil - }, - }, nil -} diff --git a/tls/testdata/certs/localhost.crt b/tls/testdata/certs/localhost.crt new file mode 100644 index 0000000..4e81bfa --- /dev/null +++ b/tls/testdata/certs/localhost.crt @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC5TCCAc2gAwIBAgIJAJ6N+ougLqO+MA0GCSqGSIb3DQEBCwUAMBQxEjAQBgNV +BAMMCWxvY2FsaG9zdDAeFw0yMzA2MDUxMzIxMzVaFw0yNDA1MzAxMzIxMzVaMBQx +EjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC +ggEBAMau0ywozyQbvIrf4HlOuBahwHOrJ37Ho7gTDiYfh8/WecipEbQzVFl85iK5 +iFvMPQTeN+ld0FsXzTlDwAU7j2L1rF9FnNJeUUnBsT5Fd0uhWixrUbM0jfBhg+A8 +9uNJ3M084YxmRmuZ/MMSbu3RLMQP8YCJLCWphfDnr6EK04ggIV2C/aTZ+D7eUuQF +aUD3OmfGx0mXFYdegwwQnPdqeOvq8V0//F1KCPllwiOuK3pjbBC/mRJXCbAkjTZs +kq6qeHURCTWscTp3fv/5UBJBVlZxUwI96IJlDWJg4aYrAvuofrry8nE+Tko3h6ZN +2z0Jm/CrV+39fmPUruUVea2DVpUCAwEAAaM6MDgwFAYDVR0RBA0wC4IJbG9jYWxo +b3N0MAsGA1UdDwQEAwIHgDATBgNVHSUEDDAKBggrBgEFBQcDATANBgkqhkiG9w0B +AQsFAAOCAQEArSGyWFWA6fSG9blCDzs/E0HOQayBuyuTnYMcSH1WVCX34eqGr4Kd +oJIjXc1sf9XZ4oRB3NYWKV0HFPnR7EtuY9K1QkcoLlpFvJgYcIu9zW9bl6yVehBt +gDN7uZ6ly9URxbM3yRLTT+Hy5tO+AoOkKvfjf5dp7ieIRPyyn4AjpLuJID9BqP/s +L7bATFt7RiYp4BTmKJ19R7X2lABzGa2JvyHV0Y55JuYaABsRl7rarIPu6PFfVnvs +z/l5vB95FBPIQpBcxXA77hLzwj1RczFXzlI8uCyXn3EfqYn7UaoOqw6/5W3F8g+Y +VwwxpeIhldfeahJNAah+1V6tO26hr8gJQw== +-----END CERTIFICATE----- diff --git a/tls/testdata/certs/localhost.key b/tls/testdata/certs/localhost.key new file mode 100644 index 0000000..f5172ce --- /dev/null +++ b/tls/testdata/certs/localhost.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDGrtMsKM8kG7yK +3+B5TrgWocBzqyd+x6O4Ew4mH4fP1nnIqRG0M1RZfOYiuYhbzD0E3jfpXdBbF805 +Q8AFO49i9axfRZzSXlFJwbE+RXdLoVosa1GzNI3wYYPgPPbjSdzNPOGMZkZrmfzD +Em7t0SzED/GAiSwlqYXw56+hCtOIICFdgv2k2fg+3lLkBWlA9zpnxsdJlxWHXoMM +EJz3anjr6vFdP/xdSgj5ZcIjrit6Y2wQv5kSVwmwJI02bJKuqnh1EQk1rHE6d37/ ++VASQVZWcVMCPeiCZQ1iYOGmKwL7qH668vJxPk5KN4emTds9CZvwq1ft/X5j1K7l +FXmtg1aVAgMBAAECggEBAKHPM9CdE8Y2iKEZn3lsMOTNqy0I0UuhT6bUbguCVltg +MyLG/tIhk6ql28+gBnuspG1YhXSboNrvUYY3tSUN0sMnjdCxovx5L/6/rpgmfver +WwMeDBXE0WxaHsr7G58UQq0rzg1IJkXvzTkZxBoO50RuL6MdFEVAAQOnzRN8+7W5 +9dX0V2KKYWm8DoouxpgCFkpfCviwgaHUQWha85zbLgGRKN7lVcM5uUFMcl5ypxGz +g3+s1TBw+g6Al/o5x1QSTrhydvnAKsWgUri9GHBSt0NoA4BSNSHLmhv5q7abJE7L +St+TvbghhsbKGReP2zgikQ7UVtZGqhtf99irP4Gu4OkCgYEA400BbglVF63mAwjs +ge9VHSHAOZcFtBzwWt4+nC3wYEE1sutnLuuHEh1AR18RmXWsnQ3I//i2S6KXoaKp +qKBTDa/N0sLEN5mT9pLhnzNJF0Seo4MHkCig2i+zHW7VlLJJD/9/kedwlmtmhfTe +PlVI1fTbfhDVqqLchEHo6AxYGGcCgYEA38TPMp8RyM392HlGPCI2ngY8yLtxkdQT +ONHB8fA/LIeaTjlzU/NI0/J/sNZy+B8LEKYMA62m3qbAi7army+IrmogWJaTrDRa +FlXSUg9F31MOhTX0T8tBAskkH11RD879wHejFbwylLsZQmDS+F8gG/6bNT2pDtgz +2Kyn8IaJq6MCgYEArlYv1I//3guZMZa0n+xLYe6zGvjEfSL9DxUK/IsXpRwe7b40 +A/7OOIyK8rLuMr/YxxT9p6bBWz24A1dZvWZKjWLcAN011ldK74I03wBc/SW6bzte +n6kpxm9zeA28bzJXa5fR5ryW1ChIGFJ562FKXiBSAV00JI6JiD9tPh3Jq90CgYAZ +0PH6rCF4IlPcCrnQrD3S43NV0VJb+bSyBHk0uXwAXjCuP7CPiezoDv0uYL9o4uP6 +6r1OG1W6MFDcjZmk0MobHUFYFx84ad3O393g+8Qa7NErCzuBjTiV4rDZMYHtqfra +nrLhChJn2GIkp1kPsKHauPgdH10GymjI4bqKZGszswKBgQCB2eT638rKIJobA2MA +kCYwhMldehXIlmBGy/QLhDEdJ2qridOb79pDbnKTd+vyBEmU+V1W/UK1kG/NtA9S +yXkWIRWxzC8qf9/D9tifnybCLIpdn62YBGLcAEJKCsP1o0U9VR8v2HjkDRghjUkI +vvlr90zAjfyTcXmG3kWplyeglQ== +-----END PRIVATE KEY----- diff --git a/tls/testdata/certs/localhost2.crt b/tls/testdata/certs/localhost2.crt new file mode 100644 index 0000000..2a4a24f --- /dev/null +++ b/tls/testdata/certs/localhost2.crt @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC5TCCAc2gAwIBAgIJAK/t3663m69zMA0GCSqGSIb3DQEBCwUAMBQxEjAQBgNV +BAMMCWxvY2FsaG9zdDAeFw0yMzA2MDUxMzIxMzVaFw0yNDA1MjAxMzIxMzVaMBQx +EjAQBgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC +ggEBAMHf7WSZoHsDZ1UBVImL/cw77N+yJ2kVufaYz7Uc3knuVProXXw98RhUG5cl +JXiAgw35+QF4+TyEzeifF6RplrX90AGStmwgDUBveCjqhC5QPVsJ8tbIXX0X50DA +dNDo/2mT3Kb4Tlzhq5D7cVRrJXrCdB3WNX5cNwiYN1MpbnXbydARlHlcdzs9u0yk +BFoiYajmkseYtcBAr91CxCAn8Mxo2Un0CVaVW5VGIIbA4qh0i8cnzW4GLfpCIxsQ +dIbt4wk3j10v2eLq3U+dtHWbTQQyMerB4hDwb95AyUhjdvO4xNXEyHki69LBwV8a ++4+nJddvLw9hAfdWWyiEz7B0me0CAwEAAaM6MDgwFAYDVR0RBA0wC4IJbG9jYWxo +b3N0MAsGA1UdDwQEAwIHgDATBgNVHSUEDDAKBggrBgEFBQcDATANBgkqhkiG9w0B +AQsFAAOCAQEAO7aNbohyfAU3+8U5ODF38GUJICxwpGzJF8okUSaxa8C6A28NsPep +3GOaoVi39yekq1YrqVfCGnVFqisYV/VOBdNZVORHtJpB40700IKrMaBmERkiQpT1 +VxzgSr1piXPVXHFJZHnrnvA1hBXxSSSXguNZRavFrj0i89OQouHy13jFSTFr91K3 +ktajwLt4YEKelsEqTurQcjX75i+Q0YlsD+UjOHp9F4P82mgFPtE0WWeU05+5n1G4 +/T5xKUH9NrSJdcDgFSV4yPNfdsPdJfA5Ohfks/NJoo1F84MajPvksAoitOUcSlgc +QA8a9lT9fnAC3piDSXQybCFZV9cNVx4E1A== +-----END CERTIFICATE----- diff --git a/tls/testdata/certs/localhost2.key b/tls/testdata/certs/localhost2.key new file mode 100644 index 0000000..976f47a --- /dev/null +++ b/tls/testdata/certs/localhost2.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDB3+1kmaB7A2dV +AVSJi/3MO+zfsidpFbn2mM+1HN5J7lT66F18PfEYVBuXJSV4gIMN+fkBePk8hM3o +nxekaZa1/dABkrZsIA1Ab3go6oQuUD1bCfLWyF19F+dAwHTQ6P9pk9ym+E5c4auQ ++3FUayV6wnQd1jV+XDcImDdTKW5128nQEZR5XHc7PbtMpARaImGo5pLHmLXAQK/d +QsQgJ/DMaNlJ9AlWlVuVRiCGwOKodIvHJ81uBi36QiMbEHSG7eMJN49dL9ni6t1P +nbR1m00EMjHqweIQ8G/eQMlIY3bzuMTVxMh5IuvSwcFfGvuPpyXXby8PYQH3Vlso +hM+wdJntAgMBAAECggEBAIWV6vfnVwmL1dZfnUVNPWpNXDDii38/5iwBLRVJN+1P +GCTumQOzln1B7uTdRo1aV3L469dU6L8Hbu27OUojKyJpKbr7wVCNYTQl2nCu7rcO +uMgS+c1+r9Qy9TfLpHISKXMw29f8vdoH8PRsHLGjRmbot6ObZq6TkaQNZgmaQa9Y +tsbXRK7L2hT4gNoNNlonTnD/7xw2eqJU5btmEZmQ0FVeizvxjbpM0/QAYjffC50o +qKokUz0ARzuUC84ke0K5HbNYyygDARkekrLAJIIJFHNBjsDrUd1Hd8F1tveI8fUc +jHfvUszdX08I6Lg4tFQ0ESSUcOPbM9gOmMRKUGE+v20CgYEA9HTW5/Qbu8Icr5YO +uuQPkWvtcG6oGmoag7NUcV3WXKoAFLrV2Yi93XEY3JtRoX7kFGGYTId7aqFYaFnq +D8uR/NWJwzzSgA/3CTJ6Bx0HsIO84hjBbddFOMy79koPVHXV9gZyFVZDoIuu5PRX +MPQJfLAlnxxgdHbKRMp5J/4uzQsCgYEAywee86rgLrFyyLOBJG8nR/2f2cmjcJHg +NyBM3fXgewCQu0r/xKW6y1U3560XAxiFfIADP2F8L7xgPulmJtRnKU3z0cHHwmsc +/MJvO8rvrMgU7O/fNLX5N1i8uZDrxmQcf8S0u0PAicQTiXorycOB1EH5Zf8CQVMo +ZxoPpG733+cCgYBgEUafcywu9lLFoif5xERl9s8h3yrK7qWq2h+2SZVDZz+O5fnC +el17F8YYdCV5XN+PLudmM9wJhIy0vZkhSfP+M4DnLBDhaOTBRYf1IbBy6uKgy+/A +FdhLQRIg8OvjWkeSXugYgIUlI5/AtFFLmKvdx2+RftpdCo3kyNkiIV8NDwKBgAQa +B1AM57KJyzPazIUb6cM+kHgp5q9jgxAaCvOBACP8AvCFt10VrAxnkFWR3aEmYav+ +OhKRuZyNRbR/qpymNd9Tv9VBAPQgjdldZDnlA6qN8D5JKk06T+qaVFW7Y8gCRcEf +DDesSrt9xpdEbJYK6RiMrKku2bDQKUTL9fzwcPmJAoGAbUEi8TemjMGkThiu9do+ +yuZx7DlBG/+NrE/xkHO7RSK6IxX6KorYp2vqGSmz6dm6zp1GwoFOqC+ArD40de4V +gDHZV9mI1IllXAlbTY05qiIeJlyTvert3TTW/PGDrqnr3+EDdx4fecU+FdJ0b3Yk +7F3iknvaFE055+75cb/fm8k= +-----END PRIVATE KEY----- diff --git a/tls/tls.go b/tls/tls.go new file mode 100644 index 0000000..480dd7c --- /dev/null +++ b/tls/tls.go @@ -0,0 +1,109 @@ +package tls + +import ( + "crypto/tls" + "fmt" + "os" + "os/signal" + "sync" + "syscall" + + "github.com/fsnotify/fsnotify" + "go.uber.org/zap" +) + +// BuildWatchedTLSConfig builds TLS with integrated watched to reload certificate files when they are changed. +func BuildWatchedTLSConfig(logger *zap.Logger, certFile, keyFile string, notify chan<- *tls.Certificate) (*tls.Config, error) { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, fmt.Errorf("failed loading TLS cert: %w", err) + } + + signalCh := make(chan os.Signal, 1) + signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM) + + watcher, err := fsnotify.NewWatcher() + if err != nil { + return nil, fmt.Errorf("failed to setup file watcher for hot reloading tls certificates: %w", err) + } + + var lock sync.RWMutex + go func() { + defer watcher.Close() + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + + // Ignore all events that are neither remove nor write + isRelevantEvent := event.Has(fsnotify.Remove) || event.Has(fsnotify.Write) + if !isRelevantEvent { + continue + } + + if event.Has(fsnotify.Remove) { + // Kubernetes uses symbolic links to create the illusion of atomic writes. + // Thus, we have to watch for the remove event and reconfigure our watcher. + // See: https://ahmet.im/blog/kubernetes-inotify/ + _ = watcher.Remove(event.Name) + err = watcher.Add(certFile) + if err != nil { + logger.Error("failed to re-add file watcher", + zap.String("file_path", certFile), + zap.Error(err)) + } + + err = watcher.Add(keyFile) + if err != nil { + logger.Warn("failed to re-add file watcher", + zap.String("file_path", keyFile), + zap.Error(err)) + } + } + + logger.Info("hot reloading the TLS certificate") + + newCert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + logger.Error("failed to load certificates", zap.Error(err)) + continue + } + + lock.Lock() + cert = newCert + lock.Unlock() + + if notify != nil { + notify <- &newCert + } + + logger.Info("successfully hot reloaded the TLS certificate") + case err, ok := <-watcher.Errors: + if !ok { + return + } + logger.Error("tls certificate watcher error", zap.Error(err)) + case <-signalCh: + return + } + } + }() + if err := watcher.Add(certFile); err != nil { + return nil, fmt.Errorf("failed to setup watcher for cert file: %w", err) + } + if err = watcher.Add(keyFile); err != nil { + return nil, fmt.Errorf("failed to setup watcher for key file: %w", err) + } + + return &tls.Config{ + MinVersion: tls.VersionTLS12, + GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + lock.RLock() + defer lock.RUnlock() + + return &cert, nil + }, + }, nil +} diff --git a/tls/tls_test.go b/tls/tls_test.go new file mode 100644 index 0000000..651ebdc --- /dev/null +++ b/tls/tls_test.go @@ -0,0 +1,283 @@ +package tls + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "net/http" + "os" + "path/filepath" + "strconv" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "go.uber.org/zap" +) + +func TestBuildWatchedTLSConfig(t *testing.T) { + log := zap.NewExample() + + t.Run("no update", func(t *testing.T) { + l, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + port := l.Addr().(*net.TCPAddr).Port + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + }) + + tlsServerConfig, err := BuildWatchedTLSConfig(log, + "./testdata/certs/localhost.crt", "./testdata/certs/localhost.key", + nil) + require.NoError(t, err) + + ts := &http.Server{ + Handler: handler, + TLSConfig: tlsServerConfig, + } + + go func() { + ts.ServeTLS(l, "", "") + }() + + timer1 := time.NewTimer(10 * time.Millisecond) + <-timer1.C + + t.Cleanup(func() { + ts.Shutdown(context.Background()) + }) + + tlsConfig := newTLSConfig(t, "./testdata/certs/localhost.crt", "./testdata/certs/localhost.key") + tlsConfig.InsecureSkipVerify = true + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + } + + res, err := client.Get("https://localhost:" + strconv.Itoa(port) + "/") + assert.NoError(t, err) + assert.NotNil(t, res) + assert.Equal(t, "HTTP/1.1", res.Proto) + + pc := res.TLS.PeerCertificates[0] + dur := pc.NotAfter.Sub(pc.NotBefore) + assert.Equal(t, 360*24.0, dur.Hours()) + }) + + t.Run("update no channel", func(t *testing.T) { + l, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + port := l.Addr().(*net.TCPAddr).Port + + // temp directory + dname, err := os.MkdirTemp("", "certs") + require.NoError(t, err) + + t.Cleanup(func() { + os.RemoveAll(dname) + }) + + // copy first set of certs to temp dir + certFile := filepath.Join(dname, "secure.crt") + keyFile := filepath.Join(dname, "secure.key") + + copyFile(t, "./testdata/certs/localhost.crt", certFile) + copyFile(t, "./testdata/certs/localhost.key", keyFile) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + }) + + tlsServerConfig, err := BuildWatchedTLSConfig(log, certFile, keyFile, nil) + require.NoError(t, err) + + ts := &http.Server{ + Handler: handler, + TLSConfig: tlsServerConfig, + } + + go func() { + ts.ServeTLS(l, "", "") + }() + + timer1 := time.NewTimer(10 * time.Millisecond) + <-timer1.C + + t.Cleanup(func() { + ts.Shutdown(context.Background()) + }) + + tlsConfig := newTLSConfig(t, "./testdata/certs/localhost.crt", "./testdata/certs/localhost.key") + tlsConfig.InsecureSkipVerify = true + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + } + + res, err := client.Get("https://localhost:" + strconv.Itoa(port) + "/") + assert.NoError(t, err) + assert.NotNil(t, res) + assert.Equal(t, "HTTP/1.1", res.Proto) + + pc := res.TLS.PeerCertificates[0] + dur := pc.NotAfter.Sub(pc.NotBefore) + assert.Equal(t, 360*24.0, dur.Hours()) + + // update certs + + timer1 = time.NewTimer(10 * time.Millisecond) + <-timer1.C + + copyFile(t, "./testdata/certs/localhost2.crt", certFile) + copyFile(t, "./testdata/certs/localhost2.key", keyFile) + + // allow for hot reload + timer1 = time.NewTimer(200 * time.Millisecond) + <-timer1.C + + // check + res, err = client.Get("https://localhost:" + strconv.Itoa(port) + "/") + assert.NoError(t, err) + assert.NotNil(t, res) + assert.Equal(t, "HTTP/1.1", res.Proto) + + pc = res.TLS.PeerCertificates[0] + dur = pc.NotAfter.Sub(pc.NotBefore) + assert.Equal(t, 350*24.0, dur.Hours()) + }) + + t.Run("update with channel", func(t *testing.T) { + l, err := net.Listen("tcp", ":0") + require.NoError(t, err) + + port := l.Addr().(*net.TCPAddr).Port + + // temp directory + dname, err := os.MkdirTemp("", "certs") + require.NoError(t, err) + + t.Cleanup(func() { + os.RemoveAll(dname) + }) + + // copy first set of certs to temp dir + certFile := filepath.Join(dname, "secure.crt") + keyFile := filepath.Join(dname, "secure.key") + + copyFile(t, "./testdata/certs/localhost.crt", certFile) + copyFile(t, "./testdata/certs/localhost.key", keyFile) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + }) + + signalCh := make(chan *tls.Certificate, 10) + + count := atomic.Int32{} + go func() { + update := <-signalCh + assert.NotEmpty(t, update) + count.Add(1) + }() + + tlsServerConfig, err := BuildWatchedTLSConfig(log, certFile, keyFile, signalCh) + require.NoError(t, err) + + ts := &http.Server{ + Handler: handler, + TLSConfig: tlsServerConfig, + } + + go func() { + ts.ServeTLS(l, "", "") + }() + + timer1 := time.NewTimer(10 * time.Millisecond) + <-timer1.C + + t.Cleanup(func() { + ts.Shutdown(context.Background()) + }) + + tlsConfig := newTLSConfig(t, "./testdata/certs/localhost.crt", "./testdata/certs/localhost.key") + tlsConfig.InsecureSkipVerify = true + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfig, + }, + } + + res, err := client.Get("https://localhost:" + strconv.Itoa(port) + "/") + assert.NoError(t, err) + assert.NotNil(t, res) + assert.Equal(t, "HTTP/1.1", res.Proto) + + pc := res.TLS.PeerCertificates[0] + dur := pc.NotAfter.Sub(pc.NotBefore) + assert.Equal(t, 360*24.0, dur.Hours()) + + // update certs + + timer1 = time.NewTimer(10 * time.Millisecond) + <-timer1.C + + copyFile(t, "./testdata/certs/localhost2.crt", certFile) + copyFile(t, "./testdata/certs/localhost2.key", keyFile) + + // allow for hot reload + timer1 = time.NewTimer(200 * time.Millisecond) + <-timer1.C + + // check + res, err = client.Get("https://localhost:" + strconv.Itoa(port) + "/") + assert.NoError(t, err) + assert.NotNil(t, res) + assert.Equal(t, "HTTP/1.1", res.Proto) + + pc = res.TLS.PeerCertificates[0] + dur = pc.NotAfter.Sub(pc.NotBefore) + assert.Equal(t, 350*24.0, dur.Hours()) + + v := count.Load() + + assert.Equal(t, int32(1), v) + + close(signalCh) + }) +} + +func newTLSConfig(t *testing.T, certFile, keyFile string) *tls.Config { + t.Helper() + + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + require.NoError(t, err) + + caCertPool, err := x509.SystemCertPool() + require.NoError(t, err) + + return &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{cert}, + RootCAs: caCertPool, + } +} + +func copyFile(t *testing.T, src, dst string) { + t.Helper() + + dat, err := os.ReadFile(src) + require.NoError(t, err) + + err = os.WriteFile(dst, dat, 0644) + require.NoError(t, err) +}