From 9ec7c54baa42e694f8e4b36a1a046b89d278ffe7 Mon Sep 17 00:00:00 2001 From: bjoern-m Date: Thu, 12 Dec 2024 17:21:08 +0100 Subject: [PATCH] feat: enhance session response --- backend/dto/session.go | 68 +++++++++++++++++++++- backend/handler/session.go | 114 +++++++++++++++---------------------- 2 files changed, 111 insertions(+), 71 deletions(-) diff --git a/backend/dto/session.go b/backend/dto/session.go index 60ed33e14..fb68d37ad 100644 --- a/backend/dto/session.go +++ b/backend/dto/session.go @@ -1,8 +1,10 @@ package dto import ( + "encoding/json" "fmt" "github.com/gofrs/uuid" + "github.com/lestrrat-go/jwx/v2/jwt" "github.com/mileusna/useragent" "github.com/teamhanko/hanko/backend/persistence/models" "time" @@ -33,10 +35,72 @@ func FromSessionModel(model models.Session, current bool) SessionData { } } +type Claims struct { + Subject uuid.UUID `json:"subject"` + IssuedAt *time.Time `json:"issued_at,omitempty"` + Expiration time.Time `json:"expiration"` + Audience []string `json:"audience,omitempty"` + Issuer *string `json:"issuer,omitempty"` + Email *EmailJwt `json:"email,omitempty"` + SessionID uuid.UUID `json:"session_id"` +} + type ValidateSessionResponse struct { - IsValid bool `json:"is_valid"` + IsValid bool `json:"is_valid"` + Claims *Claims `json:"claims,omitempty"` + // deprecated ExpirationTime *time.Time `json:"expiration_time,omitempty"` - UserID *uuid.UUID `json:"user_id,omitempty"` + // deprecated + UserID *uuid.UUID `json:"user_id,omitempty"` +} + +func GetClaimsFromToken(token jwt.Token) (*Claims, error) { + claims := &Claims{} + + if subject := token.Subject(); len(subject) > 0 { + s, err := uuid.FromString(subject) + if err != nil { + return nil, fmt.Errorf("'subject' is not a uuid: %w", err) + } + claims.Subject = s + } + + if sessionID, valid := token.Get("session_id"); valid { + s, err := uuid.FromString(sessionID.(string)) + if err != nil { + return nil, fmt.Errorf("'session_id' is not a uuid: %w", err) + } + claims.SessionID = s + } + + if issuedAt := token.IssuedAt(); !issuedAt.IsZero() { + claims.IssuedAt = &issuedAt + } + + if audience := token.Audience(); len(audience) > 0 { + claims.Audience = audience + } + + if issuer := token.Issuer(); len(issuer) > 0 { + claims.Issuer = &issuer + } + + if email, valid := token.Get("email"); valid { + if data, ok := email.(map[string]interface{}); ok { + jsonData, err := json.Marshal(data) + if err != nil { + return nil, fmt.Errorf("failed to marshal 'email' claim: %w", err) + } + err = json.Unmarshal(jsonData, &claims.Email) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal 'email' claim: %w", err) + } + } + } + + claims.Expiration = token.Expiration() + + return claims, nil } type ValidateSessionRequest struct { diff --git a/backend/handler/session.go b/backend/handler/session.go index c27e41d47..dee875e54 100644 --- a/backend/handler/session.go +++ b/backend/handler/session.go @@ -2,10 +2,8 @@ package handler import ( "fmt" - "github.com/gofrs/uuid" echojwt "github.com/labstack/echo-jwt/v4" "github.com/labstack/echo/v4" - "github.com/lestrrat-go/jwx/v2/jwt" "github.com/teamhanko/hanko/backend/config" "github.com/teamhanko/hanko/backend/dto" "github.com/teamhanko/hanko/backend/persistence" @@ -36,61 +34,47 @@ func (h *SessionHandler) ValidateSession(c echo.Context) error { return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false}) } - var token jwt.Token for _, extractor := range extractors { auths, extractorErr := extractor(c) if extractorErr != nil { continue } for _, auth := range auths { - t, tokenErr := h.sessionManager.Verify(auth) + token, tokenErr := h.sessionManager.Verify(auth) if tokenErr != nil { continue } - if h.cfg.Session.ServerSide.Enabled { - // check that the session id is stored in the database - sessionId, ok := t.Get("session_id") - if !ok { - continue - } - sessionID, err := uuid.FromString(sessionId.(string)) - if err != nil { - continue - } - - sessionModel, err := h.persister.GetSessionPersister().Get(sessionID) - if err != nil { - return fmt.Errorf("failed to get session from database: %w", err) - } - if sessionModel == nil { - continue - } - - // Update lastUsed field - sessionModel.LastUsed = time.Now().UTC() - err = h.persister.GetSessionPersister().Update(*sessionModel) - if err != nil { - return dto.ToHttpError(err) - } + claims, err := dto.GetClaimsFromToken(token) + if err != nil { + return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false}) } - token = t - break + sessionModel, err := h.persister.GetSessionPersister().Get(claims.SessionID) + if err != nil { + return fmt.Errorf("failed to get session from database: %w", err) + } + if sessionModel == nil { + continue + } + + // Update lastUsed field + sessionModel.LastUsed = time.Now().UTC() + err = h.persister.GetSessionPersister().Update(*sessionModel) + if err != nil { + return dto.ToHttpError(err) + } + + return c.JSON(http.StatusOK, dto.ValidateSessionResponse{ + IsValid: true, + Claims: claims, + ExpirationTime: &claims.Expiration, + UserID: &claims.Subject, + }) } } - if token != nil { - expirationTime := token.Expiration() - userID := uuid.FromStringOrNil(token.Subject()) - return c.JSON(http.StatusOK, dto.ValidateSessionResponse{ - IsValid: true, - ExpirationTime: &expirationTime, - UserID: &userID, - }) - } else { - return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false}) - } + return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false}) } func (h *SessionHandler) ValidateSessionFromBody(c echo.Context) error { @@ -110,39 +94,31 @@ func (h *SessionHandler) ValidateSessionFromBody(c echo.Context) error { return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false}) } - if h.cfg.Session.ServerSide.Enabled { - // check that the session id is stored in the database - sessionId, ok := token.Get("session_id") - if !ok { - return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false}) - } - sessionID, err := uuid.FromString(sessionId.(string)) - if err != nil { - return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false}) - } + claims, err := dto.GetClaimsFromToken(token) + if err != nil { + return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false}) + } - sessionModel, err := h.persister.GetSessionPersister().Get(sessionID) - if err != nil { - return dto.ToHttpError(err) - } + sessionModel, err := h.persister.GetSessionPersister().Get(claims.SessionID) + if err != nil { + return dto.ToHttpError(err) + } - if sessionModel == nil { - return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false}) - } + if sessionModel == nil { + return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false}) + } - // update lastUsed field - sessionModel.LastUsed = time.Now().UTC() - err = h.persister.GetSessionPersister().Update(*sessionModel) - if err != nil { - return dto.ToHttpError(err) - } + // update lastUsed field + sessionModel.LastUsed = time.Now().UTC() + err = h.persister.GetSessionPersister().Update(*sessionModel) + if err != nil { + return dto.ToHttpError(err) } - expirationTime := token.Expiration() - userID := uuid.FromStringOrNil(token.Subject()) return c.JSON(http.StatusOK, dto.ValidateSessionResponse{ IsValid: true, - ExpirationTime: &expirationTime, - UserID: &userID, + Claims: claims, + ExpirationTime: &claims.Expiration, + UserID: &claims.Subject, }) }