Skip to content

Commit

Permalink
feat: enhance session response
Browse files Browse the repository at this point in the history
  • Loading branch information
bjoern-m committed Dec 12, 2024
1 parent 3961837 commit 9ec7c54
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 71 deletions.
68 changes: 66 additions & 2 deletions backend/dto/session.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package dto

import (
"encoding/json"
"fmt"
"github.com/gofrs/uuid"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/mileusna/useragent"
"github.com/teamhanko/hanko/backend/persistence/models"
"time"
Expand Down Expand Up @@ -33,10 +35,72 @@ func FromSessionModel(model models.Session, current bool) SessionData {
}
}

type Claims struct {
Subject uuid.UUID `json:"subject"`
IssuedAt *time.Time `json:"issued_at,omitempty"`
Expiration time.Time `json:"expiration"`
Audience []string `json:"audience,omitempty"`
Issuer *string `json:"issuer,omitempty"`
Email *EmailJwt `json:"email,omitempty"`
SessionID uuid.UUID `json:"session_id"`
}

type ValidateSessionResponse struct {
IsValid bool `json:"is_valid"`
IsValid bool `json:"is_valid"`
Claims *Claims `json:"claims,omitempty"`
// deprecated
ExpirationTime *time.Time `json:"expiration_time,omitempty"`
UserID *uuid.UUID `json:"user_id,omitempty"`
// deprecated
UserID *uuid.UUID `json:"user_id,omitempty"`
}

func GetClaimsFromToken(token jwt.Token) (*Claims, error) {
claims := &Claims{}

if subject := token.Subject(); len(subject) > 0 {
s, err := uuid.FromString(subject)
if err != nil {
return nil, fmt.Errorf("'subject' is not a uuid: %w", err)
}
claims.Subject = s
}

if sessionID, valid := token.Get("session_id"); valid {
s, err := uuid.FromString(sessionID.(string))
if err != nil {
return nil, fmt.Errorf("'session_id' is not a uuid: %w", err)
}
claims.SessionID = s
}

if issuedAt := token.IssuedAt(); !issuedAt.IsZero() {
claims.IssuedAt = &issuedAt
}

if audience := token.Audience(); len(audience) > 0 {
claims.Audience = audience
}

if issuer := token.Issuer(); len(issuer) > 0 {
claims.Issuer = &issuer
}

if email, valid := token.Get("email"); valid {
if data, ok := email.(map[string]interface{}); ok {
jsonData, err := json.Marshal(data)
if err != nil {
return nil, fmt.Errorf("failed to marshal 'email' claim: %w", err)
}
err = json.Unmarshal(jsonData, &claims.Email)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal 'email' claim: %w", err)
}
}
}

claims.Expiration = token.Expiration()

return claims, nil
}

type ValidateSessionRequest struct {
Expand Down
114 changes: 45 additions & 69 deletions backend/handler/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ package handler

import (
"fmt"
"github.com/gofrs/uuid"
echojwt "github.com/labstack/echo-jwt/v4"
"github.com/labstack/echo/v4"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/teamhanko/hanko/backend/config"
"github.com/teamhanko/hanko/backend/dto"
"github.com/teamhanko/hanko/backend/persistence"
Expand Down Expand Up @@ -36,61 +34,47 @@ func (h *SessionHandler) ValidateSession(c echo.Context) error {
return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false})
}

var token jwt.Token
for _, extractor := range extractors {
auths, extractorErr := extractor(c)
if extractorErr != nil {
continue
}
for _, auth := range auths {
t, tokenErr := h.sessionManager.Verify(auth)
token, tokenErr := h.sessionManager.Verify(auth)
if tokenErr != nil {
continue
}

if h.cfg.Session.ServerSide.Enabled {
// check that the session id is stored in the database
sessionId, ok := t.Get("session_id")
if !ok {
continue
}
sessionID, err := uuid.FromString(sessionId.(string))
if err != nil {
continue
}

sessionModel, err := h.persister.GetSessionPersister().Get(sessionID)
if err != nil {
return fmt.Errorf("failed to get session from database: %w", err)
}
if sessionModel == nil {
continue
}

// Update lastUsed field
sessionModel.LastUsed = time.Now().UTC()
err = h.persister.GetSessionPersister().Update(*sessionModel)
if err != nil {
return dto.ToHttpError(err)
}
claims, err := dto.GetClaimsFromToken(token)
if err != nil {
return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false})
}

token = t
break
sessionModel, err := h.persister.GetSessionPersister().Get(claims.SessionID)
if err != nil {
return fmt.Errorf("failed to get session from database: %w", err)
}
if sessionModel == nil {
continue
}

// Update lastUsed field
sessionModel.LastUsed = time.Now().UTC()
err = h.persister.GetSessionPersister().Update(*sessionModel)
if err != nil {
return dto.ToHttpError(err)
}

return c.JSON(http.StatusOK, dto.ValidateSessionResponse{
IsValid: true,
Claims: claims,
ExpirationTime: &claims.Expiration,
UserID: &claims.Subject,
})
}
}

if token != nil {
expirationTime := token.Expiration()
userID := uuid.FromStringOrNil(token.Subject())
return c.JSON(http.StatusOK, dto.ValidateSessionResponse{
IsValid: true,
ExpirationTime: &expirationTime,
UserID: &userID,
})
} else {
return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false})
}
return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false})
}

func (h *SessionHandler) ValidateSessionFromBody(c echo.Context) error {
Expand All @@ -110,39 +94,31 @@ func (h *SessionHandler) ValidateSessionFromBody(c echo.Context) error {
return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false})
}

if h.cfg.Session.ServerSide.Enabled {
// check that the session id is stored in the database
sessionId, ok := token.Get("session_id")
if !ok {
return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false})
}
sessionID, err := uuid.FromString(sessionId.(string))
if err != nil {
return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false})
}
claims, err := dto.GetClaimsFromToken(token)
if err != nil {
return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false})
}

sessionModel, err := h.persister.GetSessionPersister().Get(sessionID)
if err != nil {
return dto.ToHttpError(err)
}
sessionModel, err := h.persister.GetSessionPersister().Get(claims.SessionID)
if err != nil {
return dto.ToHttpError(err)
}

if sessionModel == nil {
return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false})
}
if sessionModel == nil {
return c.JSON(http.StatusOK, dto.ValidateSessionResponse{IsValid: false})
}

// update lastUsed field
sessionModel.LastUsed = time.Now().UTC()
err = h.persister.GetSessionPersister().Update(*sessionModel)
if err != nil {
return dto.ToHttpError(err)
}
// update lastUsed field
sessionModel.LastUsed = time.Now().UTC()
err = h.persister.GetSessionPersister().Update(*sessionModel)
if err != nil {
return dto.ToHttpError(err)
}

expirationTime := token.Expiration()
userID := uuid.FromStringOrNil(token.Subject())
return c.JSON(http.StatusOK, dto.ValidateSessionResponse{
IsValid: true,
ExpirationTime: &expirationTime,
UserID: &userID,
Claims: claims,
ExpirationTime: &claims.Expiration,
UserID: &claims.Subject,
})
}

0 comments on commit 9ec7c54

Please sign in to comment.