From 53e03169157d1c1c1374630f8cb589fbd6439a4b Mon Sep 17 00:00:00 2001 From: Krishna Iyer Easwaran Date: Mon, 4 Mar 2024 14:11:29 +0100 Subject: [PATCH] dev: Use system certs if not insecure --- pkg/source/chirpstack/config/config.go | 46 ++++++++++++++------------ 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/pkg/source/chirpstack/config/config.go b/pkg/source/chirpstack/config/config.go index 503e0d7..5d082c0 100644 --- a/pkg/source/chirpstack/config/config.go +++ b/pkg/source/chirpstack/config/config.go @@ -105,13 +105,6 @@ func (c *Config) Initialize() error { if err := c.JoinEUI.UnmarshalText([]byte(c.joinEUI)); err != nil { return errInvalidJoinEUI.WithAttributes("join_eui", c.joinEUI) } - - if !c.insecure && c.caPath != "" { - if err := setCustomCA(c.caPath); err != nil { - return err - } - } - err := c.dialGRPC( grpc.FailOnNonTempDialError(true), grpc.WithBlock(), @@ -125,11 +118,14 @@ func (c *Config) Initialize() error { } func (c *Config) dialGRPC(opts ...grpc.DialOption) error { - if c.insecure || c.caPath == "" { + if c.insecure { opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) - } - if tls := http.DefaultTransport.(*http.Transport).TLSClientConfig; tls != nil { - opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tls))) + } else { + tlsConfig, err := generateTLSConfig(c.caPath) + if err != nil { + return err + } + opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) } ctx, cancel := context.WithTimeout(context.Background(), dialTimeout) @@ -143,18 +139,24 @@ func (c *Config) dialGRPC(opts ...grpc.DialOption) error { return nil } -func setCustomCA(path string) error { - pemBytes, err := os.ReadFile(path) - if err != nil { - return err +// GenerateTLSConfig generates a TLS configuration. +func generateTLSConfig(caPath string) (cfg *tls.Config, err error) { + cfg = http.DefaultTransport.(*http.Transport).TLSClientConfig + if cfg == nil { + cfg = &tls.Config{} } - rootCAs := http.DefaultTransport.(*http.Transport).TLSClientConfig.RootCAs - if rootCAs == nil { - if rootCAs, err = x509.SystemCertPool(); err != nil { - rootCAs = x509.NewCertPool() + if cfg.RootCAs == nil { + if cfg.RootCAs, err = x509.SystemCertPool(); err != nil { + cfg.RootCAs = x509.NewCertPool() } } - rootCAs.AppendCertsFromPEM(pemBytes) - http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{RootCAs: rootCAs} - return nil + if caPath == "" { + return cfg, nil + } + pemBytes, err := os.ReadFile(caPath) + if err != nil { + return nil, err + } + cfg.RootCAs.AppendCertsFromPEM(pemBytes) + return cfg, nil }