diff --git a/go.mod b/go.mod index 4a94b96c..af5a193b 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,8 @@ require ( github.com/fission/fission v1.19.0 github.com/fxamacker/cbor/v2 v2.4.0 github.com/go-chi/chi/v5 v5.0.7 + github.com/go-jose/go-jose/v3 v3.0.1 + github.com/golang-jwt/jwt v3.2.2+incompatible github.com/golang-migrate/migrate/v4 v4.16.2 github.com/google/uuid v1.3.0 github.com/gorilla/schema v1.2.0 diff --git a/go.sum b/go.sum index c5aedaeb..0a6d784c 100644 --- a/go.sum +++ b/go.sum @@ -100,6 +100,8 @@ github.com/fxamacker/cbor/v2 v2.4.0 h1:ri0ArlOR+5XunOP8CRUowT0pSJOwhW098ZCUyskZD github.com/fxamacker/cbor/v2 v2.4.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8= github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA= +github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= @@ -118,6 +120,8 @@ github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x github.com/godbus/dbus/v5 v5.0.6/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= @@ -151,6 +155,7 @@ github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5a github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= @@ -307,6 +312,7 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= @@ -347,6 +353,7 @@ go.uber.org/zap v1.19.0/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= diff --git a/pkg/auth/errors.go b/pkg/auth/errors.go new file mode 100644 index 00000000..54af6b3d --- /dev/null +++ b/pkg/auth/errors.go @@ -0,0 +1,19 @@ +package auth + +import ( + "net/http" + + "sensorbucket.nl/sensorbucket/internal/web" +) + +var ( + // Authorization errors + ErrUnauthorized = web.NewError(http.StatusUnauthorized, "Unauthorized", "UNAUTHORIZED") + ErrNoTenantIDFound = web.NewError(http.StatusForbidden, "Forbidden", "FORBIDDEN") + ErrNoPermissions = web.NewError(http.StatusForbidden, "Forbidden", "FORBIDDEN") + ErrPermissionsNotGranted = web.NewError(http.StatusForbidden, "Forbidden", "FORBIDDEN") + ErrNoUserID = web.NewError(http.StatusForbidden, "Forbidden", "FORBIDDEN") + + // Request and server errors + ErrAuthHeaderInvalidFormat = web.NewError(http.StatusBadRequest, "Authorization header must be formatted as 'Bearer {token}'", "AUTH_HEADER_INVALID_FORMAT") +) diff --git a/pkg/auth/middleware.go b/pkg/auth/middleware.go new file mode 100644 index 00000000..69294ef4 --- /dev/null +++ b/pkg/auth/middleware.go @@ -0,0 +1,169 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "strings" + "time" + + "github.com/go-jose/go-jose/v3" + "github.com/golang-jwt/jwt" + "sensorbucket.nl/sensorbucket/internal/web" +) + +type ctxKey int + +type claims struct { + TenantID int64 `json:"tid"` + Permissions []permission `json:"perms"` + UserID int64 `json:"uid"` + Expiration int64 `json:"exp"` +} + +func (c *claims) Valid() error { + for _, permission := range c.Permissions { + if permission.Valid() != nil { + return fmt.Errorf("invalid permissions") + } + } + if c.TenantID > 0 && c.UserID > 0 && c.Expiration > time.Now().Unix() { + return nil + } + return fmt.Errorf("claims not valid") +} + +type jwksClient interface { + Get() (jose.JSONWebKeySet, error) +} + +type jwksHttpClient struct { + issuer string + httpClient http.Client +} + +func (c *jwksHttpClient) Get() (jose.JSONWebKeySet, error) { + res, err := c.httpClient.Get(fmt.Sprintf("%s/.well-known/jwks.json", c.issuer)) + + if err != nil { + return jose.JSONWebKeySet{}, fmt.Errorf("failed to fetch jwks: %w", err) + } + var jwks jose.JSONWebKeySet + if err := json.NewDecoder(res.Body).Decode(&jwks); err != nil { + return jose.JSONWebKeySet{}, fmt.Errorf("failed to decode jwks: %w", err) + } + return jwks, nil +} + +type contextBuilder struct { + c context.Context +} + +func (cb *contextBuilder) With(key ctxKey, value any) *contextBuilder { + cb.c = context.WithValue(cb.c, key, value) + return cb +} + +func (cb *contextBuilder) Finish() context.Context { + return cb.c +} + +const ( + ctxUserID ctxKey = iota + ctxCurrentTenantID + ctxPermissions +) + +func NewJWKSHttpClient(issuer string) *jwksHttpClient { + return &jwksHttpClient{ + issuer: issuer, + httpClient: http.Client{}, + } +} + +func Protect() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, tenantIDPresent := fromRequestContext[[]int64](r.Context(), ctxCurrentTenantID) + _, permissionsPresent := fromRequestContext[[]permission](r.Context(), ctxPermissions) + _, userIDPresent := fromRequestContext[int64](r.Context(), ctxUserID) + if tenantIDPresent && permissionsPresent && userIDPresent { + // All required authentication values are present, allow the request + next.ServeHTTP(w, r) + return + } + web.HTTPError(w, ErrUnauthorized) + }) + } +} + +// Authentication middleware for checking the validity of any present JWT +// Checks if the JWT is signed using the given secret +// Serves the next HTTP handler if there is no JWT or if the JWT is OK +// Anonymous requests are allowed by this handler +func Authenticate(keyClient jwksClient) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth == "" { + // Allow anonymous requests + next.ServeHTTP(w, r) + return + } + tokenStr, ok := strings.CutPrefix(auth, "Bearer ") + if !ok { + web.HTTPError(w, ErrAuthHeaderInvalidFormat) + return + } + + // Retrieve the JWT and ensure it was signed by us + c := claims{} + token, err := jwt.ParseWithClaims(tokenStr, &c, validateJWTFunc(keyClient)) + if err == nil && token.Valid { + // JWT itself is validated, pass it to the actual endpoint for further authorization + // First fill the context with user information + cb := contextBuilder{c: r.Context()} + next.ServeHTTP(w, r.WithContext( + cb. + With(ctxCurrentTenantID, []int64{c.TenantID}). + With(ctxUserID, c.UserID). + With(ctxPermissions, c.Permissions). + Finish())) + return + } + log.Printf("[Error] authentication failed err: %s", err) + web.HTTPError(w, ErrUnauthorized) + }) + } +} + +func validateJWTFunc(jwksClient jwksClient) func(token *jwt.Token) (any, error) { + return func(token *jwt.Token) (any, error) { + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + // Retrieve JWKS + jwks, err := jwksClient.Get() + if err != nil { + return nil, fmt.Errorf("failed to retrieve jwks: %w", err) + } + + // Look for the key as indicated by the token key id + kid, ok := token.Header["kid"].(string) + if !ok { + return nil, fmt.Errorf("no kid in token") + } + keys := jwks.Key(kid) + if len(keys) == 0 { + return nil, fmt.Errorf("no keys found for token") + } + key := keys[0] + if key.Algorithm != token.Method.Alg() { + return nil, fmt.Errorf("key alg differs from token alg: %s vs %s", key.Algorithm, token.Method.Alg()) + } + return key.Public().Key, nil + } +} diff --git a/pkg/auth/middleware_test.go b/pkg/auth/middleware_test.go new file mode 100644 index 00000000..f85a238d --- /dev/null +++ b/pkg/auth/middleware_test.go @@ -0,0 +1,488 @@ +package auth + +import ( + "context" + "crypto/x509" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/go-jose/go-jose/v3" + "github.com/golang-jwt/jwt" + "github.com/stretchr/testify/assert" +) + +// test jwks is unreachable +func TestAuthenticateWellKnownUnreachable(t *testing.T) { + type testCase struct { + token string + expectedStatusCode int + } + + // Arrange + client := jwksClientMock{ + GetFunc: func() (jose.JSONWebKeySet, error) { + return jose.JSONWebKeySet{}, fmt.Errorf("connection refused") + }, + } + auth := Authenticate(&client) + s := http.ServeMux{} + s.Handle("/", auth(nil)) + + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + rr := httptest.NewRecorder() + token := createToken(jwt.MapClaims{ + "tid": 11, + "perms": []string{ + "READ_DEVICES", + "READ_API_KEYS", + }, + "uid": 431, + "exp": time.Now().Add(time.Hour * 24).Unix(), + }) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + // Act + s.ServeHTTP(rr, req) + + // Assert + assert.Equal(t, 401, rr.Result().StatusCode) +} + +func TestProtectAndAuthenticatePassClaimsToNext(t *testing.T) { + + type testCase struct { + token string + expectedStatusCode int + } + + // Arrange + protect := Protect() + client := jwksClientMock{ + GetFunc: func() (jose.JSONWebKeySet, error) { + return jwks(), nil + }, + } + auth := Authenticate(&client) + next := HandlerMock{ + ServeHTTPFunc: func(responseWriter http.ResponseWriter, request *http.Request) { + assert.Equal(t, context.WithValue(context.WithValue(context.WithValue(context.Background(), ctxCurrentTenantID, []int64{11}), ctxUserID, int64(431)), ctxPermissions, []permission{ + READ_DEVICES, + READ_API_KEYS, + }), request.Context()) + }, + } + s := http.ServeMux{} + s.Handle("/", auth(protect(&next))) + + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + rr := httptest.NewRecorder() + token := createToken(jwt.MapClaims{ + "tid": 11, + "perms": []string{ + "READ_DEVICES", + "READ_API_KEYS", + }, + "uid": 431, + "exp": time.Now().Add(time.Hour * 24).Unix(), + }) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + // Act + s.ServeHTTP(rr, req) + + // Assert + assert.Equal(t, 200, rr.Result().StatusCode) + assert.Len(t, next.ServeHTTPCalls(), 1) +} + +func TestProtect(t *testing.T) { + type testCase struct { + values map[ctxKey]interface{} + expectedStatusCode int + expectedNextCalls int + } + scenarios := map[string]testCase{ + "all required values present": { + values: map[ctxKey]interface{}{ + ctxCurrentTenantID: []int64{12, 54, 13}, + ctxPermissions: []permission{READ_API_KEYS}, + ctxUserID: int64(124), + }, + expectedStatusCode: 200, + expectedNextCalls: 1, + }, + "tid is missing": { + values: map[ctxKey]interface{}{ + ctxPermissions: []permission{READ_API_KEYS}, + ctxUserID: int64(124), + }, + expectedStatusCode: 401, + expectedNextCalls: 0, + }, + "perms is missing": { + values: map[ctxKey]interface{}{ + ctxCurrentTenantID: []int64{12, 54, 13}, + ctxUserID: int64(124), + }, + expectedStatusCode: 401, + expectedNextCalls: 0, + }, + "uid is missing": { + values: map[ctxKey]interface{}{ + ctxCurrentTenantID: []int64{12, 54, 13}, + ctxPermissions: []permission{READ_API_KEYS}, + }, + expectedStatusCode: 401, + expectedNextCalls: 0, + }, + "all required values are missing": { + values: map[ctxKey]interface{}{}, + expectedStatusCode: 401, + expectedNextCalls: 0, + }, + "tid is wrong type": { + values: map[ctxKey]interface{}{ + ctxCurrentTenantID: "123", // should be []int64! + ctxPermissions: []permission{READ_API_KEYS}, + ctxUserID: int64(124), + }, + expectedStatusCode: 401, + expectedNextCalls: 0, + }, + "perms is wrong type": { + values: map[ctxKey]interface{}{ + ctxCurrentTenantID: []int64{12, 54, 13}, + ctxPermissions: 54325, + ctxUserID: int64(124), + }, + expectedStatusCode: 401, + expectedNextCalls: 0, + }, + "uid is wrong type": { + values: map[ctxKey]interface{}{ + ctxCurrentTenantID: []int64{12, 54, 13}, + ctxPermissions: []permission{READ_API_KEYS}, + ctxUserID: "asdasdsad", + }, + expectedStatusCode: 401, + expectedNextCalls: 0, + }, + "tid is nil": { + values: map[ctxKey]interface{}{ + ctxCurrentTenantID: nil, + ctxPermissions: []permission{READ_API_KEYS}, + ctxUserID: int64(124), + }, + expectedStatusCode: 401, + expectedNextCalls: 0, + }, + "perms is nil": { + values: map[ctxKey]interface{}{ + ctxCurrentTenantID: []int64{12, 54, 13}, + ctxPermissions: nil, + ctxUserID: int64(124), + }, + expectedStatusCode: 401, + expectedNextCalls: 0, + }, + "uid is nil": { + values: map[ctxKey]interface{}{ + ctxCurrentTenantID: []int64{12, 54, 13}, + ctxPermissions: []permission{READ_API_KEYS}, + ctxUserID: nil, + }, + expectedStatusCode: 401, + expectedNextCalls: 0, + }, + } + + for scene, cfg := range scenarios { + t.Run(scene, func(t *testing.T) { + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + rr := httptest.NewRecorder() + ctx := testAccumulateContext(context.Background(), cfg.values) + + next := HandlerMock{ + ServeHTTPFunc: func(responseWriter http.ResponseWriter, request *http.Request) {}, + } + + handler := Protect() + s := http.ServeMux{} + s.Handle("/", handler(&next)) + + // Act + s.ServeHTTP(rr, req.WithContext(ctx)) + + // Assert + assert.Equal(t, cfg.expectedStatusCode, rr.Result().StatusCode) + assert.Len(t, next.ServeHTTPCalls(), cfg.expectedNextCalls) + }) + } +} + +func TestAuthenticate(t *testing.T) { + in24Hours := time.Now().Add(time.Hour * 24).Unix() + var nilSlice []permission + type testCase struct { + authHeader string + expectedStatusCode int + expectedNextCalls int + expectedContext context.Context + } + scenarios := map[string]testCase{ + "auth header is invalid": { + authHeader: "blabla", + expectedStatusCode: 400, + expectedNextCalls: 0, + }, + "bearer token is invalid": { + authHeader: "Bearer blabla", + expectedStatusCode: 401, + expectedNextCalls: 0, + }, + "anonymous request is done": { + authHeader: "", + expectedStatusCode: 200, + expectedNextCalls: 1, + expectedContext: context.Background(), + }, + "bearer token is valid and contains all claims": { + authHeader: fmt.Sprintf("Bearer %s", createToken( + jwt.MapClaims{ + "tid": 11, + "perms": []string{ + "READ_DEVICES", + "READ_API_KEYS", + }, + "uid": 431, + "exp": in24Hours, + }, + )), + expectedStatusCode: 200, + expectedNextCalls: 1, + expectedContext: context.WithValue(context.WithValue(context.WithValue(context.Background(), ctxCurrentTenantID, []int64{11}), ctxUserID, int64(431)), ctxPermissions, []permission{ + READ_DEVICES, + READ_API_KEYS, + }), + }, + "bearer token contains invalid permission": { + authHeader: fmt.Sprintf("Bearer %s", createToken( + jwt.MapClaims{ + "tid": 11, + "perms": []string{ + "READ_DEVICES", + "READ_API_KEYS", + "DOES_NOT_EXIST", + }, + "uid": 431, + "exp": in24Hours, + }, + )), + expectedStatusCode: 401, + expectedNextCalls: 0, + }, + "bearer token is valid and but claims are missing": { + authHeader: fmt.Sprintf("Bearer %s", createToken( + jwt.MapClaims{}, + )), + expectedStatusCode: 401, + expectedNextCalls: 0, + }, + "bearer token is valid but tid is missing": { + authHeader: fmt.Sprintf("Bearer %s", createToken( + jwt.MapClaims{ + "perms": []string{ + "READ_DEVICES", + "READ_API_KEYS", + }, + "uid": 431, + "exp": in24Hours, + }, + )), + expectedStatusCode: 401, + expectedNextCalls: 0, + }, + "bearer token is valid but perms is missing": { + authHeader: fmt.Sprintf("Bearer %s", createToken( + jwt.MapClaims{ + "tid": 11, + "uid": 431, + "exp": in24Hours, + }, + )), + expectedContext: context.WithValue(context.WithValue(context.WithValue(context.Background(), ctxCurrentTenantID, []int64{11}), ctxUserID, int64(431)), ctxPermissions, nilSlice), + expectedStatusCode: 200, + expectedNextCalls: 1, + }, + "bearer token is valid but uid is missing": { + authHeader: fmt.Sprintf("Bearer %s", createToken( + jwt.MapClaims{ + "tid": 11, + "perms": []string{ + "READ_DEVICES", + "READ_API_KEYS", + }, + "exp": in24Hours, + }, + )), + expectedStatusCode: 401, + expectedNextCalls: 0, + }, + "bearer token is valid all claims are present but the token is expired": { + authHeader: fmt.Sprintf("Bearer %s", createToken( + jwt.MapClaims{ + "tid": 11, + "perms": []string{ + "READ_DEVICES", + "READ_API_KEYS", + }, + "uid": 431, + "exp": time.Now().Add(-time.Hour * 24).Unix(), + }, + )), + expectedStatusCode: 401, + expectedNextCalls: 0, + }, + } + + for scene, cfg := range scenarios { + t.Run(scene, func(t *testing.T) { + req, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Authorization", cfg.authHeader) + rr := httptest.NewRecorder() + + next := HandlerMock{ + ServeHTTPFunc: func(responseWriter http.ResponseWriter, request *http.Request) { + assert.Equal(t, cfg.expectedContext, request.Context()) + }, + } + + client := jwksClientMock{ + GetFunc: func() (jose.JSONWebKeySet, error) { + return jwks(), nil + }, + } + + handler := Authenticate(&client) + s := http.ServeMux{} + s.Handle("/", handler(&next)) + + // Act + s.ServeHTTP(rr, req) + + // Assert + assert.Equal(t, cfg.expectedStatusCode, rr.Result().StatusCode) + assert.Len(t, next.ServeHTTPCalls(), cfg.expectedNextCalls) + }) + } +} + +func jsonPrivateKey() any { + block, _ := pem.Decode([]byte(key)) + if block == nil { + panic("failed to parse PEM block containing the private key") + } + privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + panic(err) + } + return privateKey +} + +func testAccumulateContext(ctx context.Context, values map[ctxKey]interface{}) context.Context { + for key, val := range values { + ctx = context.WithValue(ctx, key, val) + } + return ctx +} + +func createToken(claims jwt.MapClaims) string { + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header["kid"] = "test-key" + tokenString, err := token.SignedString(jsonPrivateKey()) + if err != nil { + panic(err) + } + return tokenString +} + +func jwks() jose.JSONWebKeySet { + var jwks jose.JSONWebKeySet + if err := json.NewDecoder(io.NopCloser(strings.NewReader(jsonWebKeySet))).Decode(&jwks); err != nil { + return jose.JSONWebKeySet{} + } + return jwks +} + +// Keys below are for testing purposes only!! +const key = `-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDNoUGP4FABzt8m +XO/uoSrQ/thVTHDG2Lb3pQLmC6BkgCzygBtO8eTeORNkQHirNKC47yk8mllF2RdJ +doHiDFyfRSa+V8AJv4KfF7Yb65J8a78yAcmnj6yTSQqM+7E2U7WMRTbYw9HyE4Zp +Xp42pPCGFArlsT6CkPSAL+eLIvVjbSsv71DIy6UsDsRuAiK+27JC9FEjqJst9fLk +QXqawC8gNeE0lXN91Wj62sP1a9i7D+MFD3p92UI2F3FNilOKrCPrsQI9Y5Il9qHK +u+HM3AJ//7Ym3/RN69jsBBZAclRaCiBlJhhczMkzxUffiLkxe1hiNhTZBOm7n0a2 +MmFyqDLNAgMBAAECggEAHXODquIzQ1cIUfvMp45wzfc6L9lfa7N9XTHGnQE8Sziq +d18OyjtODt/43Yp4XfkPLf2fF915fM4PjkeJacFggLVMS8XQrPS/dh7Ux+HxHJ3o +B/cGlVe4HW5AMxoXcxMBNSJyrRA64SOXxD63hVcRVfrH5scAj33IbxWtYZmzsLYf +1/TWaY5DEd3i67W65tDNzSVoCYu8Wsg5z6lmN5SJmxR1zjMyypCoGNdcm9Pa/vvq +Hb2xHKOX3Io4vSY2VTurWk9/iIfEVLuqiuq1s5dJ2vd3OCHslw2JshOGM+kGU9Lt +z6+lcBJcr7jPAPL8y4EMgs1oqsNBfUIXkr59UrYyZwKBgQDnwv3UHFvuTvKgAdqj +UW8fxuWoJ46KBBNAVCkuO8RNHpoFG5dsfHom8hLMPi0d/+9udN6k4Aac4AhpV4at +RFKjdHjBVsm06TKSf8fPGUselicWCBuqUFHH3Pi20Aw+i9R5aAxxYR12gOiZJW0D +eLxnHVKgtVwDFfg3JAR0dOik/wKBgQDjIqHrxh98mshdYUSmWFYYjCSONCsIoGmo +DVdl8LKfNgMzegRCcKjteUjESXipm5Z2uitiQWNpe0HR4bltCyAyzkenTBqOf+Yh +TfJTB94ko7RR22Xj71WeI9WRnCOINQXIvHNBSf3gYXccZjBV20cr9xEKq0qjS6YV +ZffCYoysMwKBgDjUQXVvdsNarHe7vKbrYvpBxTKUcIk7MpVFjct+cEYQyOeTum+p +njJKjX1ziZCfn1BQa/+1xylUbfuWsLlv1WurNakC5Pbtb68okhAgPaFEZFUsq8v5 +YfRGJN5+6WG02+bhMpvimlzigyZ6XN7LDjeiow4xKly/WFv9AvKjcCB1AoGAGThm +PFTSeDaDmwLK6aGTZcRh5rxaLuoI8VUR6ErSuqT3tAaPZIU37K5z6v+xezvAeExx +tsZF8Jd0Fob23OnIWHvZLvVfWYVQG1CZYKjV/MGEqzYuWSHhIt8dvr5Un7Irgz+R +mKVLoFeSL0AVi+L+Qx568PFWJ02mEmgxG49vyUsCgYBm6R13DaGv5mylpYc/CWbx +rF3IpRWYewlcO2xrgiCEvp+9Eh0epSuK/kKaEwwv90pMHReIrpcMujBOpUJT7/NZ +fJA0UGp5r4Z2az1b4i4sF70Uark9TatJ3XH7AcP3tFfo2TQeiST4qgKyx35iT/0r +mxiuHhps1ig5jCN3YGj2zQ== +-----END PRIVATE KEY----- +` + +const publicKey = `-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAzaFBj+BQAc7fJlzv7qEq +0P7YVUxwxti296UC5gugZIAs8oAbTvHk3jkTZEB4qzSguO8pPJpZRdkXSXaB4gxc +n0UmvlfACb+Cnxe2G+uSfGu/MgHJp4+sk0kKjPuxNlO1jEU22MPR8hOGaV6eNqTw +hhQK5bE+gpD0gC/niyL1Y20rL+9QyMulLA7EbgIivtuyQvRRI6ibLfXy5EF6msAv +IDXhNJVzfdVo+trD9WvYuw/jBQ96fdlCNhdxTYpTiqwj67ECPWOSJfahyrvhzNwC +f/+2Jt/0TevY7AQWQHJUWgogZSYYXMzJM8VH34i5MXtYYjYU2QTpu59GtjJhcqgy +zQIDAQAB +-----END PUBLIC KEY----- +` + +const jsonWebKeySet = `{ + "keys":[ + { + "alg":"RS256", + "e":"AQAB", + "kid":"test-key", + "kty":"RSA", + "n":"zaFBj-BQAc7fJlzv7qEq0P7YVUxwxti296UC5gugZIAs8oAbTvHk3jkTZEB4qzSguO8pPJpZRdkXSXaB4gxcn0UmvlfACb-Cnxe2G-uSfGu_MgHJp4-sk0kKjPuxNlO1jEU22MPR8hOGaV6eNqTwhhQK5bE-gpD0gC_niyL1Y20rL-9QyMulLA7EbgIivtuyQvRRI6ibLfXy5EF6msAvIDXhNJVzfdVo-trD9WvYuw_jBQ96fdlCNhdxTYpTiqwj67ECPWOSJfahyrvhzNwCf_-2Jt_0TevY7AQWQHJUWgogZSYYXMzJM8VH34i5MXtYYjYU2QTpu59GtjJhcqgyzQ" + } + ] + }` diff --git a/pkg/auth/mock_test.go b/pkg/auth/mock_test.go new file mode 100644 index 00000000..972e6eef --- /dev/null +++ b/pkg/auth/mock_test.go @@ -0,0 +1,141 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package auth + +import ( + "github.com/go-jose/go-jose/v3" + "net/http" + "sync" +) + +// Ensure, that jwksClientMock does implement jwksClient. +// If this is not the case, regenerate this file with moq. +var _ jwksClient = &jwksClientMock{} + +// jwksClientMock is a mock implementation of jwksClient. +// +// func TestSomethingThatUsesjwksClient(t *testing.T) { +// +// // make and configure a mocked jwksClient +// mockedjwksClient := &jwksClientMock{ +// GetFunc: func() (jose.JSONWebKeySet, error) { +// panic("mock out the Get method") +// }, +// } +// +// // use mockedjwksClient in code that requires jwksClient +// // and then make assertions. +// +// } +type jwksClientMock struct { + // GetFunc mocks the Get method. + GetFunc func() (jose.JSONWebKeySet, error) + + // calls tracks calls to the methods. + calls struct { + // Get holds details about calls to the Get method. + Get []struct { + } + } + lockGet sync.RWMutex +} + +// Get calls GetFunc. +func (mock *jwksClientMock) Get() (jose.JSONWebKeySet, error) { + if mock.GetFunc == nil { + panic("jwksClientMock.GetFunc: method is nil but jwksClient.Get was just called") + } + callInfo := struct { + }{} + mock.lockGet.Lock() + mock.calls.Get = append(mock.calls.Get, callInfo) + mock.lockGet.Unlock() + return mock.GetFunc() +} + +// GetCalls gets all the calls that were made to Get. +// Check the length with: +// +// len(mockedjwksClient.GetCalls()) +func (mock *jwksClientMock) GetCalls() []struct { +} { + var calls []struct { + } + mock.lockGet.RLock() + calls = mock.calls.Get + mock.lockGet.RUnlock() + return calls +} + +// Ensure, that HandlerMock does implement Handler. +// If this is not the case, regenerate this file with moq. +var _ http.Handler = &HandlerMock{} + +// HandlerMock is a mock implementation of Handler. +// +// func TestSomethingThatUsesHandler(t *testing.T) { +// +// // make and configure a mocked Handler +// mockedHandler := &HandlerMock{ +// ServeHTTPFunc: func(responseWriter http.ResponseWriter, request *http.Request) { +// panic("mock out the ServeHTTP method") +// }, +// } +// +// // use mockedHandler in code that requires Handler +// // and then make assertions. +// +// } +type HandlerMock struct { + // ServeHTTPFunc mocks the ServeHTTP method. + ServeHTTPFunc func(responseWriter http.ResponseWriter, request *http.Request) + + // calls tracks calls to the methods. + calls struct { + // ServeHTTP holds details about calls to the ServeHTTP method. + ServeHTTP []struct { + // ResponseWriter is the responseWriter argument value. + ResponseWriter http.ResponseWriter + // Request is the request argument value. + Request *http.Request + } + } + lockServeHTTP sync.RWMutex +} + +// ServeHTTP calls ServeHTTPFunc. +func (mock *HandlerMock) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { + if mock.ServeHTTPFunc == nil { + panic("HandlerMock.ServeHTTPFunc: method is nil but Handler.ServeHTTP was just called") + } + callInfo := struct { + ResponseWriter http.ResponseWriter + Request *http.Request + }{ + ResponseWriter: responseWriter, + Request: request, + } + mock.lockServeHTTP.Lock() + mock.calls.ServeHTTP = append(mock.calls.ServeHTTP, callInfo) + mock.lockServeHTTP.Unlock() + mock.ServeHTTPFunc(responseWriter, request) +} + +// ServeHTTPCalls gets all the calls that were made to ServeHTTP. +// Check the length with: +// +// len(mockedHandler.ServeHTTPCalls()) +func (mock *HandlerMock) ServeHTTPCalls() []struct { + ResponseWriter http.ResponseWriter + Request *http.Request +} { + var calls []struct { + ResponseWriter http.ResponseWriter + Request *http.Request + } + mock.lockServeHTTP.RLock() + calls = mock.calls.ServeHTTP + mock.lockServeHTTP.RUnlock() + return calls +} diff --git a/pkg/auth/permissions.go b/pkg/auth/permissions.go new file mode 100644 index 00000000..026f8e74 --- /dev/null +++ b/pkg/auth/permissions.go @@ -0,0 +1,90 @@ +package auth + +import "fmt" + +const ( + // Device permissions + READ_DEVICES permission = "READ_DEVICES" + WRITE_DEVICES permission = "WRITE_DEVICES" + + // API Key permissions + READ_API_KEYS permission = "READ_API_KEYS" + WRITE_API_KEYS permission = "WRITE_API_KEYS" + + // Tenant permissions + READ_TENANTS permission = "READ_TENANTS" + WRITE_TENANTS permission = "WRITE_TENANTS" + + // Measurement permissions + READ_MEASUREMENTS permission = "READ_MEASUREMENTS" + WRITE_MEASUREMENTS permission = "WRITE_MEASUREMENTS" + + // Tracing permissions + READ_TRACING permission = "READ_TRACING" + + // User worker permissions + READ_USER_WORKERS permission = "READ_USER_WORKERS" + WRITE_USER_WORKERS permission = "WRITE_USER_WORKERS" +) + +var allowedPermissions = []permission{ + READ_DEVICES, + WRITE_DEVICES, + READ_API_KEYS, + WRITE_API_KEYS, + READ_TENANTS, + WRITE_TENANTS, + READ_MEASUREMENTS, + WRITE_MEASUREMENTS, + READ_TRACING, + READ_USER_WORKERS, + WRITE_USER_WORKERS, +} + +var SuperUserRole = role(allowedPermissions) + +type Role interface { + Permissions() []permission + HasPermissions(permission permission, permissions ...permission) bool +} + +type permission string + +func (p permission) String() string { + return string(p) +} + +func (p permission) Valid() error { + for _, allowed := range allowedPermissions { + if allowed == p { + return nil + } + } + return fmt.Errorf("%s is not a valid permission", p) +} + +type role []permission + +func (r role) HasPermissions(permission permission, permissions ...permission) bool { + permissions = append(permissions, permission) + for _, rolePermission := range r { + found := false + for _, p := range permissions { + if rolePermission == p { + found = true + } + } + if !found { + return false + } + } + return true +} + +func (r role) Permissions() []permission { + return r +} + +func NewRole(permissions ...permission) Role { + return role(permissions) +} diff --git a/pkg/auth/permissions_test.go b/pkg/auth/permissions_test.go new file mode 100644 index 00000000..97c40a7b --- /dev/null +++ b/pkg/auth/permissions_test.go @@ -0,0 +1,30 @@ +package auth + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPermissionsValid(t *testing.T) { + type testCase struct { + permission permission + expectedErr error + } + scenarios := map[string]testCase{ + "valid permission": { + permission: permission("WRITE_USER_WORKERS"), + expectedErr: nil, + }, + "invalid permission": { + permission: permission("WEIRD_PERMISSION"), + expectedErr: fmt.Errorf("WEIRD_PERMISSION is not a valid permission"), + }, + } + for scene, tc := range scenarios { + t.Run(scene, func(t *testing.T) { + assert.Equal(t, tc.expectedErr, tc.permission.Valid()) + }) + } +} diff --git a/pkg/auth/utils.go b/pkg/auth/utils.go new file mode 100644 index 00000000..f80df843 --- /dev/null +++ b/pkg/auth/utils.go @@ -0,0 +1,67 @@ +package auth + +import ( + "context" + + "github.com/samber/lo" +) + +// Checks if the given context contains said permissions +// returns nil if all is OK +func MustHavePermissions(c context.Context, perm permission, permissions ...permission) error { + permissions = append(permissions, perm) + permissionsFromContext, ok := fromRequestContext[[]permission](c, ctxPermissions) + if !ok { + return ErrNoPermissions + } + if lo.Every(permissionsFromContext, permissions) { + return nil + } + return ErrPermissionsNotGranted +} + +func HasRole(ctx context.Context, r Role) bool { + permissionsFromContext, ok := fromRequestContext[[]permission](ctx, ctxPermissions) + if !ok { + return false + } + if len(permissionsFromContext) == 0 { + return false + } + if len(permissionsFromContext) > 1 { + return r.HasPermissions(permissionsFromContext[0], permissionsFromContext...) + } + return r.HasPermissions(permissionsFromContext[0]) +} + +func GetUser(ctx context.Context) (int64, error) { + val, ok := fromRequestContext[int64](ctx, ctxUserID) + if !ok { + return -1, ErrNoUserID + } + return val, nil +} + +func GetTenants(ctx context.Context) ([]int64, error) { + val, ok := fromRequestContext[[]int64](ctx, ctxCurrentTenantID) + if !ok { + return nil, ErrNoTenantIDFound + } + return val, nil +} + +func HasPermissionsFor(ctx context.Context, tenantIDs ...int64) bool { + tenants, err := GetTenants(ctx) + if err != nil { + return false + } + if len(tenantIDs) == 0 { + return false + } + return lo.Every(tenants, tenantIDs) +} + +func fromRequestContext[T any](c context.Context, key ctxKey) (T, bool) { + val, ok := c.Value(key).(T) + return val, ok +} diff --git a/pkg/auth/utils_test.go b/pkg/auth/utils_test.go new file mode 100644 index 00000000..eaa4cc07 --- /dev/null +++ b/pkg/auth/utils_test.go @@ -0,0 +1,233 @@ +package auth + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestMustHavePermissionsRequestedPermissions(t *testing.T) { + // Arrange + type testCase struct { + permissionsInCtx []permission + permissionsInput []permission + expectedErr error + } + + scenarios := map[string]testCase{ + "no permissions present": { + permissionsInput: []permission{READ_API_KEYS, READ_DEVICES}, + permissionsInCtx: []permission{}, // empty! + expectedErr: ErrPermissionsNotGranted, + }, + "no permissions present in context": { + permissionsInput: []permission{READ_API_KEYS, READ_DEVICES}, + permissionsInCtx: nil, + expectedErr: ErrNoPermissions, + }, + "some requested permissions are missing": { + permissionsInput: []permission{READ_API_KEYS, READ_DEVICES, WRITE_API_KEYS, WRITE_DEVICES}, + permissionsInCtx: []permission{READ_API_KEYS, READ_DEVICES}, + expectedErr: ErrPermissionsNotGranted, + }, + "only 1 requested permission is missing": { + permissionsInput: []permission{READ_API_KEYS, READ_DEVICES, WRITE_API_KEYS, WRITE_DEVICES}, + permissionsInCtx: []permission{READ_API_KEYS, READ_DEVICES, WRITE_API_KEYS}, + expectedErr: ErrPermissionsNotGranted, + }, + "only 1 requested permission is present": { + permissionsInput: []permission{READ_API_KEYS, READ_DEVICES, WRITE_API_KEYS, WRITE_DEVICES}, + permissionsInCtx: []permission{READ_API_KEYS}, + expectedErr: ErrPermissionsNotGranted, + }, + "all requested permissions are present": { + permissionsInput: []permission{READ_API_KEYS, READ_DEVICES, WRITE_API_KEYS, WRITE_DEVICES}, + permissionsInCtx: []permission{READ_API_KEYS, READ_DEVICES, WRITE_API_KEYS, WRITE_DEVICES}, + expectedErr: nil, + }, + "all requested permissions are missing": { + permissionsInput: []permission{READ_API_KEYS, READ_DEVICES, WRITE_API_KEYS, WRITE_DEVICES}, + permissionsInCtx: []permission{READ_API_KEYS}, + expectedErr: ErrPermissionsNotGranted, + }, + "more permissions are present than are requested": { + permissionsInput: []permission{READ_API_KEYS, READ_DEVICES}, + permissionsInCtx: []permission{READ_API_KEYS, WRITE_API_KEYS, WRITE_DEVICES, READ_DEVICES}, + expectedErr: nil, + }, + } + for testC, cfg := range scenarios { + t.Run(testC, func(t *testing.T) { + ctx := context.Background() + if cfg.permissionsInCtx != nil { + ctx = context.WithValue(ctx, ctxPermissions, cfg.permissionsInCtx) + } + + // Act + err := MustHavePermissions(ctx, cfg.permissionsInput[0], cfg.permissionsInput[1:]...) + + // Assert + assert.ErrorIs(t, err, cfg.expectedErr) + }) + } +} + +func TestHasRole(t *testing.T) { + // Arrange + type testCase struct { + permissionsInCtx []permission + roleInput role + expectedRes bool + } + scenarios := map[string]testCase{ + "does not have requested role": { + permissionsInCtx: []permission{READ_DEVICES, WRITE_DEVICES}, + roleInput: role([]permission{READ_API_KEYS, WRITE_API_KEYS}), + expectedRes: false, + }, + "has only the requested role": { + permissionsInCtx: []permission{READ_API_KEYS, WRITE_API_KEYS}, + roleInput: role([]permission{READ_API_KEYS, WRITE_API_KEYS}), + expectedRes: true, + }, + "has requested role and more permissions": { + permissionsInCtx: []permission{READ_API_KEYS, READ_DEVICES, WRITE_API_KEYS, WRITE_DEVICES}, + roleInput: role([]permission{READ_API_KEYS, WRITE_API_KEYS}), + expectedRes: true, + }, + "has no permissions": { + permissionsInCtx: []permission{}, + roleInput: role([]permission{READ_API_KEYS, WRITE_API_KEYS}), + expectedRes: false, + }, + "has no permissions in context": { + permissionsInCtx: nil, + roleInput: role([]permission{READ_API_KEYS, WRITE_API_KEYS}), + expectedRes: false, + }, + "has only 1 role in context": { + permissionsInCtx: []permission{}, + roleInput: role([]permission{READ_API_KEYS}), + expectedRes: false, + }, + } + for scene, cfg := range scenarios { + t.Run(scene, func(t *testing.T) { + // Act + ctx := context.Background() + if cfg.permissionsInCtx != nil { + ctx = context.WithValue(ctx, ctxPermissions, cfg.permissionsInCtx) + } + result := HasRole(ctx, cfg.roleInput) + + // Assert + assert.Equal(t, cfg.expectedRes, result) + }) + } +} + +func TestGetTenants(t *testing.T) { + // Arrange + type testCase struct { + tenantsInContext []int64 + expectedRes []int64 + expectedErr error + } + + scenarios := map[string]testCase{ + "no tenants in context": { + tenantsInContext: nil, + expectedRes: nil, + expectedErr: ErrNoTenantIDFound, + }, + "no tenants": { + tenantsInContext: []int64{}, + expectedRes: []int64{}, + expectedErr: nil, + }, + "multiple tenants in context": { + tenantsInContext: []int64{541, 241, 21}, + expectedRes: []int64{541, 241, 21}, + expectedErr: nil, + }, + "only 1 tenant in context": { + tenantsInContext: []int64{143}, + expectedRes: []int64{143}, + expectedErr: nil, + }, + } + + for scene, cfg := range scenarios { + t.Run(scene, func(t *testing.T) { + ctx := context.Background() + if cfg.tenantsInContext != nil { + ctx = context.WithValue(ctx, ctxCurrentTenantID, cfg.tenantsInContext) + } + + // Act + result, err := GetTenants(ctx) + + // Assert + assert.Equal(t, cfg.expectedRes, result) + assert.Equal(t, cfg.expectedErr, err) + }) + } +} + +func TestHasPermissionsFor(t *testing.T) { + // Arrange + type testCase struct { + tenantsInContext []int64 + tenantsInput []int64 + expectedRes bool + } + + scenarios := map[string]testCase{ + "no tenants in context": { + tenantsInContext: nil, + tenantsInput: []int64{123, 54, 21, 53}, + expectedRes: false, + }, + "no tenants": { + tenantsInContext: []int64{}, + tenantsInput: []int64{123, 54, 21, 53}, + expectedRes: false, + }, + "has permissions for 1 tenant": { + tenantsInContext: []int64{123}, + tenantsInput: []int64{123, 54, 21, 53}, + expectedRes: false, + }, + "has permissions for some tenants": { + tenantsInContext: []int64{123, 54}, + tenantsInput: []int64{123, 54, 21, 53}, + expectedRes: false, + }, + "has permissions for all tenants": { + tenantsInContext: []int64{123, 54, 21, 53}, + tenantsInput: []int64{123, 54, 21, 53}, + expectedRes: true, + }, + "has permissions for 1 tenant and 1 tenant is requested": { + tenantsInContext: []int64{123}, + tenantsInput: []int64{123}, + expectedRes: true, + }, + } + + for scene, cfg := range scenarios { + t.Run(scene, func(t *testing.T) { + ctx := context.Background() + if cfg.tenantsInContext != nil { + ctx = context.WithValue(ctx, ctxCurrentTenantID, cfg.tenantsInContext) + } + + // Act + result := HasPermissionsFor(ctx, cfg.tenantsInput...) + + // Assert + assert.Equal(t, cfg.expectedRes, result) + }) + } +}