From bb058b4082947dd5e1f9e81613b2d606daccfa66 Mon Sep 17 00:00:00 2001 From: Scott Leggett Date: Sun, 22 Sep 2024 06:22:48 +0200 Subject: [PATCH] feat: avoid deprecated NATS API Manually marshal JSON as per recommendation from upstream. --- cmd/ssh-portal/serve.go | 16 ++++++++------ internal/sshportalapi/server.go | 15 +++++++------ internal/sshportalapi/sshportal.go | 28 ++++++++++++++++++------- internal/sshportalapi/sshportal_test.go | 27 ++++++++++++++++++++++++ internal/sshserver/authhandler.go | 22 ++++++++++++++----- internal/sshserver/serve.go | 2 +- 6 files changed, 83 insertions(+), 27 deletions(-) create mode 100644 internal/sshportalapi/sshportal_test.go diff --git a/cmd/ssh-portal/serve.go b/cmd/ssh-portal/serve.go index 5c46625c..dd498d28 100644 --- a/cmd/ssh-portal/serve.go +++ b/cmd/ssh-portal/serve.go @@ -36,7 +36,7 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error { ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM) defer stop() // get nats server connection - nconn, err := nats.Connect(cmd.NATSServer, + nc, err := nats.Connect(cmd.NATSServer, nats.Name("ssh-portal"), // exit on connection close nats.ClosedHandler(func(_ *nats.Conn) { @@ -52,10 +52,6 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error { if err != nil { return fmt.Errorf("couldn't connect to NATS server: %v", err) } - nc, err := nats.NewEncodedConn(nconn, "json") - if err != nil { - return fmt.Errorf("couldn't get encoded conn: %v", err) - } defer nc.Close() // start listening on TCP port l, err := net.Listen("tcp", fmt.Sprintf(":%d", cmd.SSHServerPort)) @@ -83,7 +79,15 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error { eg.Go(func() error { // start serving SSH connection requests return sshserver.Serve( - ctx, log, nc, l, c, hostkeys, cmd.LogAccessEnabled, cmd.Banner) + ctx, + log, + nc, + l, + c, + hostkeys, + cmd.LogAccessEnabled, + cmd.Banner, + ) }) return eg.Wait() } diff --git a/internal/sshportalapi/server.go b/internal/sshportalapi/server.go index 826a7aec..d1dd6da2 100644 --- a/internal/sshportalapi/server.go +++ b/internal/sshportalapi/server.go @@ -50,7 +50,7 @@ func ServeNATS( wg := sync.WaitGroup{} wg.Add(1) // connect to NATS server - nconn, err := nats.Connect(natsURL, + nc, err := nats.Connect(natsURL, nats.Name("ssh-portal-api"), // synchronise exiting ServeNATS() nats.ClosedHandler(func(_ *nats.Conn) { @@ -67,14 +67,13 @@ func ServeNATS( if err != nil { return fmt.Errorf("couldn't connect to NATS server: %v", err) } - nc, err := nats.NewEncodedConn(nconn, "json") - if err != nil { - return fmt.Errorf("couldn't get encoded conn: %v", err) - } defer nc.Close() - // set up request/response callback for sshportal - _, err = nc.QueueSubscribe(bus.SubjectSSHAccessQuery, queue, - sshportal(ctx, log, nc, p, l, k)) + // configure callback + _, err = nc.QueueSubscribe( + bus.SubjectSSHAccessQuery, + queue, + sshportal(ctx, log, nc, p, l, k), + ) if err != nil { return fmt.Errorf("couldn't subscribe to queue: %v", err) } diff --git a/internal/sshportalapi/sshportal.go b/internal/sshportalapi/sshportal.go index ef1be73a..583a3f41 100644 --- a/internal/sshportalapi/sshportal.go +++ b/internal/sshportalapi/sshportal.go @@ -2,6 +2,7 @@ package sshportalapi import ( "context" + "encoding/json" "errors" "log/slog" "time" @@ -23,20 +24,30 @@ var ( }) ) +var ( + falseResponse = []byte(`false`) + trueResponse = []byte(`true`) +) + func sshportal( ctx context.Context, log *slog.Logger, - c *nats.EncodedConn, + c *nats.Conn, p *rbac.Permission, l LagoonDBService, k KeycloakService, -) nats.Handler { - return func(_, replySubject string, query *bus.SSHAccessQuery) { +) nats.MsgHandler { + return func(msg *nats.Msg) { var realmRoles, userGroups []string // set up tracing and update metrics ctx, span := otel.Tracer(pkgName).Start(ctx, bus.SubjectSSHAccessQuery) defer span.End() requestsCounter.Inc() + var query bus.SSHAccessQuery + if err := json.Unmarshal(msg.Data, &query); err != nil { + log.Warn("couldn't unmarshal query", slog.Any("query", msg.Data)) + return + } log := log.With(slog.Any("query", query)) // sanity check the query if query.SSHFingerprint == "" || query.NamespaceName == "" { @@ -48,7 +59,7 @@ func sshportal( if err != nil { if errors.Is(err, lagoondb.ErrNoResult) { log.Warn("unknown namespace name", slog.Any("error", err)) - if err = c.Publish(replySubject, false); err != nil { + if err = c.Publish(msg.Reply, falseResponse); err != nil { log.Error("couldn't publish reply", slog.Any("error", err)) } return @@ -65,7 +76,7 @@ func sshportal( log.Warn("ID mismatch in environment identification", slog.Any("env", env), slog.Any("error", err)) - if err = c.Publish(replySubject, false); err != nil { + if err = c.Publish(msg.Reply, falseResponse); err != nil { log.Error("couldn't publish reply", slog.Any("error", err)) } return @@ -75,7 +86,7 @@ func sshportal( if err != nil { if errors.Is(err, lagoondb.ErrNoResult) { log.Debug("unknown SSH Fingerprint", slog.Any("error", err)) - if err = c.Publish(replySubject, false); err != nil { + if err = c.Publish(msg.Reply, falseResponse); err != nil { log.Error("couldn't publish reply", slog.Any("error", err)) } return @@ -115,10 +126,13 @@ func sshportal( ok := p.UserCanSSHToEnvironment( ctx, env, realmRoles, userGroups, groupNameProjectIDsMap) var logMsg string + var response []byte if ok { logMsg = "SSH access authorized" + response = trueResponse } else { logMsg = "SSH access not authorized" + response = falseResponse } log.Info(logMsg, slog.Int("environmentID", env.ID), @@ -127,7 +141,7 @@ func sshportal( slog.String("projectName", env.ProjectName), slog.String("userUUID", user.UUID.String()), ) - if err = c.Publish(replySubject, ok); err != nil { + if err = c.Publish(msg.Reply, response); err != nil { log.Error("couldn't publish reply", slog.String("userUUID", user.UUID.String()), slog.Any("error", err)) diff --git a/internal/sshportalapi/sshportal_test.go b/internal/sshportalapi/sshportal_test.go new file mode 100644 index 00000000..5bfff60f --- /dev/null +++ b/internal/sshportalapi/sshportal_test.go @@ -0,0 +1,27 @@ +package sshportalapi + +import ( + "encoding/json" + "testing" +) + +func TestResponseMarshal(t *testing.T) { + var testCases = map[string]struct { + input []byte + expect bool + }{ + "true": {input: trueResponse, expect: true}, + "false": {input: falseResponse, expect: false}, + } + for name, tc := range testCases { + t.Run(name, func(tt *testing.T) { + var value bool + if err := json.Unmarshal(tc.input, &value); err != nil { + tt.Fatalf("error unmarshaling data %v to bool", tc.input) + } + if value != tc.expect { + tt.Fatalf("expected %v, got %v", tc.expect, value) + } + }) + } +} diff --git a/internal/sshserver/authhandler.go b/internal/sshserver/authhandler.go index f1e35dcc..0694cfbb 100644 --- a/internal/sshserver/authhandler.go +++ b/internal/sshserver/authhandler.go @@ -1,6 +1,7 @@ package sshserver import ( + "encoding/json" "log/slog" "time" @@ -40,8 +41,11 @@ var ( // pubKeyAuth returns a ssh.PublicKeyHandler which queries the remote // ssh-portal-api for Lagoon SSH authorization. -func pubKeyAuth(log *slog.Logger, nc *nats.EncodedConn, - c *k8s.Client) ssh.PublicKeyHandler { +func pubKeyAuth( + log *slog.Logger, + nc *nats.Conn, + c *k8s.Client, +) ssh.PublicKeyHandler { return func(ctx ssh.Context, key ssh.PublicKey) bool { authAttemptsTotal.Inc() log := log.With(slog.String("sessionID", ctx.SessionID())) @@ -60,21 +64,29 @@ func pubKeyAuth(log *slog.Logger, nc *nats.EncodedConn, } // construct ssh access query fingerprint := gossh.FingerprintSHA256(pubKey) - q := bus.SSHAccessQuery{ + queryData, err := json.Marshal(bus.SSHAccessQuery{ SSHFingerprint: fingerprint, NamespaceName: ctx.User(), ProjectID: pid, EnvironmentID: eid, SessionID: ctx.SessionID(), + }) + if err != nil { + log.Warn("couldn't marshal NATS request", slog.Any("error", err)) + return false } // send query - var ok bool - err = nc.Request(bus.SubjectSSHAccessQuery, q, &ok, natsTimeout) + msg, err := nc.Request(bus.SubjectSSHAccessQuery, queryData, natsTimeout) if err != nil { log.Warn("couldn't make NATS request", slog.Any("error", err)) return false } // handle response + var ok bool + if err := json.Unmarshal(msg.Data, &ok); err != nil { + log.Warn("couldn't unmarshal response", slog.Any("response", msg.Data)) + return false + } if !ok { log.Debug("SSH access not authorized", slog.String("fingerprint", fingerprint), diff --git a/internal/sshserver/serve.go b/internal/sshserver/serve.go index 3a868a20..137baddf 100644 --- a/internal/sshserver/serve.go +++ b/internal/sshserver/serve.go @@ -40,7 +40,7 @@ func disableSHA1Kex(_ ssh.Context) *gossh.ServerConfig { func Serve( ctx context.Context, log *slog.Logger, - nc *nats.EncodedConn, + nc *nats.Conn, l net.Listener, c *k8s.Client, hostKeys [][]byte,