Skip to content

Commit

Permalink
auth/clientcredentials: add schema-based scope enforcement interceptor (
Browse files Browse the repository at this point in the history
#61)

Adds an interceptor, `clientcredentials.Interceptor`, that enforces
required SAMS scopes based on a proto schema extension. This will reduce
boilerplate in RPC implementations and reduce the chance of
mistakes/out-of-sync problems between a schema and the implementation

Usage example:

```proto
extend google.protobuf.MethodOptions {
	// The SAMS scopes required to use this RPC.
	//
	// The range 50000-99999 is reserved for internal use within individual organizations
	// so you can use numbers in this range freely for in-house applications.
	repeated string sams_required_scopes = 50001;
}
```

Allows you to set required scopes as method options:

```proto
rpc GetUserRoles(GetUserRolesRequest) returns (GetUserRolesResponse) {
	option (sams_required_scopes) = "sams::user.roles::read";
};
```

This generates `E_SamsRequiredScopes` that can be used to point to where
we can extract `sams_required_scopes`.

## Test plan

Unit tests

---------

Co-authored-by: Joe Chen <[email protected]>
  • Loading branch information
bobheadxi and unknwon authored Sep 3, 2024
1 parent 8ac8aea commit 35e9505
Show file tree
Hide file tree
Showing 5 changed files with 594 additions and 145 deletions.
232 changes: 232 additions & 0 deletions auth/clientcredentials/clientcredentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
package clientcredentials

import (
"context"
"net/http"
"strings"

"connectrpc.com/connect"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
otelcodes "go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoimpl"

"github.com/sourcegraph/log"
sams "github.com/sourcegraph/sourcegraph-accounts-sdk-go"
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/scopes"
"github.com/sourcegraph/sourcegraph/lib/errors"
)

var tracer = otel.Tracer("sams/auth/clientcredentials")

type TokenIntrospector interface {
// IntrospectToken takes a SAMS access token and returns relevant metadata.
// This is generally implemented by *sams.TokensServiceV1.
//
// 🚨 SECURITY: SAMS will return a successful result if the token is valid, but
// is no longer active. It is critical that the caller not honor tokens where
// `.Active == false`.
IntrospectToken(ctx context.Context, token string) (*sams.IntrospectTokenResponse, error)
}

// See clientcredentials.NewInterceptor.
type Interceptor struct {
logger log.Logger
introspector TokenIntrospector
extension *protoimpl.ExtensionInfo
}

// NewInterceptor creates a serverside handler interceptor that ensures every
// incoming request has a valid client credential token with the required scopes
// indicated in the RPC method options. When used, required scopes CANNOT be
// empty - if no scopes are required, declare a separate service that does not
// use this interceptor.
//
// To declare required SAMS scopes in your RPC, add the following to your proto
// schema:
//
// extend google.protobuf.MethodOptions {
// // The SAMS scopes required to use this RPC.
// //
// // The range 50000-99999 is reserved for internal use within individual organizations
// // so you can use numbers in this range freely for in-house applications.
// repeated string sams_required_scopes = 50001;
// }
//
// In your RPCs, add the `(sams_required_scopes)` option as a comma-delimited
// list:
//
// rpc GetUserRoles(GetUserRolesRequest) returns (GetUserRolesResponse) {
// option (sams_required_scopes) = "sams::user.roles::read";
// };
//
// This will generate a variable called `E_SamsRequiredScopes` in your generated
// proto bindings. This variable should be provided to NewInterceptor to allow
// it to identify where to source the required scopes from.
//
// The provided logger is used to record internal-server errors.
func NewInterceptor(
logger log.Logger,
introspector TokenIntrospector,
methodOptionsRequiredScopesExtension *protoimpl.ExtensionInfo,
) *Interceptor {
return &Interceptor{
logger: logger.Scoped("clientcredentials"),
introspector: introspector,
extension: methodOptionsRequiredScopesExtension,
}
}

var _ connect.Interceptor = (*Interceptor)(nil)

func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
if req.Spec().IsClient {
return next(ctx, req) // no-op for clients
}
requiredScopes, err := extractSchemaRequiredScopes(req.Spec(), i.extension)
if err != nil {
return nil, internalError(ctx, i.logger, err, "internal schema error") // invalid schema is internal error
}
info, err := i.requireScope(ctx, req.Header(), requiredScopes)
if err != nil {
return nil, err
}
return next(WithClientInfo(ctx, info), req)
}
}

func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
return next(ctx, spec) // no-op for clients
}
}

func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
if conn.Spec().IsClient {
return next(ctx, conn) // no-op for clients
}
requiredScopes, err := extractSchemaRequiredScopes(conn.Spec(), i.extension)
if err != nil {
return internalError(ctx, i.logger, err, "internal schema error") // invalid schema is internal error
}
info, err := i.requireScope(ctx, conn.RequestHeader(), requiredScopes)
if err != nil {
return err
}
return next(WithClientInfo(ctx, info), conn)
}
}

// RequireScope ensures the request context has a valid SAMS M2M token
// with requiredScope. It returns a ConnectRPC status error suitable to be
// returned directly from a ConnectRPC implementation.
func (i *Interceptor) requireScope(ctx context.Context, headers http.Header, requiredScopes scopes.Scopes) (_ *ClientInfo, err error) {
var span trace.Span
ctx, span = tracer.Start(ctx, "clientcredentials.requireScope")
defer func() {
if err != nil {
span.RecordError(err)
span.SetStatus(otelcodes.Error, "check failed")
}
span.End()
}()

token, err := extractBearerContents(headers)
if err != nil {
return nil, connect.NewError(connect.CodeUnauthenticated,
errors.Wrap(err, "invalid authorization header"))
}

result, err := i.introspector.IntrospectToken(ctx, token)
if err != nil {
return nil, internalError(ctx, i.logger, err, "unable to validate token")
}
span.SetAttributes(
attribute.String("client_id", result.ClientID),
attribute.String("token_expires_at", result.ExpiresAt.String()),
attribute.StringSlice("token_scopes", scopes.ToStrings(result.Scopes)))
info := &ClientInfo{
ClientID: result.ClientID,
TokenExpiresAt: result.ExpiresAt,
TokenScopes: result.Scopes,
}

// Active encapsulates whether the token is active, including expiration.
if !result.Active {
// Record detailed error in span, and return an opaque one
span.SetAttributes(attribute.String("full_error", "inactive token"))
return info, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
}

// Check for our required scope.
for _, required := range requiredScopes {
if !result.Scopes.Match(required) {
err = errors.Newf("got scopes %+v, required: %+v", result.Scopes, requiredScopes)
span.SetAttributes(attribute.String("full_error", err.Error()))
return info, connect.NewError(connect.CodePermissionDenied,
errors.Wrap(err, "insufficient scopes"))
}
}

return info, nil
}

func extractSchemaRequiredScopes(spec connect.Spec, extension *protoimpl.ExtensionInfo) (scopes.Scopes, error) {
method, ok := spec.Schema.(protoreflect.MethodDescriptor)
if !ok {
return nil, errors.Newf("expected protoreflect.MethodDescriptor, got %T", spec.Schema)
}

value := method.Options().ProtoReflect().Get(extension.TypeDescriptor())
if !value.IsValid() {
return nil, errors.Newf("extension field %s not valid", extension.TypeDescriptor().FullName())
}
list := value.List()
if list.Len() == 0 {
return nil, errors.Newf("extension field %s cannot be empty", extension.TypeDescriptor().FullName())
}

requiredScopes := make(scopes.Scopes, list.Len())
for i := 0; i < list.Len(); i++ {
requiredScopes[i] = scopes.Scope(list.Get(i).String())
}
return requiredScopes, nil
}

func extractBearerContents(h http.Header) (string, error) {
authHeader := h.Get("Authorization")
if authHeader == "" {
return "", errors.New("no token provided in Authorization header")
}
typ := strings.SplitN(authHeader, " ", 2)
if len(typ) != 2 {
return "", errors.New("token type missing in Authorization header")
}
if !strings.EqualFold(typ[0], "bearer") {
return "", errors.Newf("invalid token type %s in Authorization header", typ[0])
}
return typ[1], nil
}

// internalError logs an error, adds it to the trace, and returns a connect
// error with a safe message.
func internalError(ctx context.Context, logger log.Logger, err error, safeMsg string) error {
trace.SpanFromContext(ctx).
SetAttributes(
attribute.String("full_error", err.Error()),
)
logger.WithTrace(log.TraceContext{
TraceID: trace.SpanContextFromContext(ctx).TraceID().String(),
SpanID: trace.SpanContextFromContext(ctx).SpanID().String(),
}).
AddCallerSkip(1).
Error(safeMsg,
log.String("code", connect.CodeInternal.String()),
log.Error(err),
)
return connect.NewError(connect.CodeInternal, errors.New(safeMsg))
}
111 changes: 111 additions & 0 deletions auth/clientcredentials/clientcredentials_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package clientcredentials

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"connectrpc.com/connect"
"github.com/hexops/autogold/v2"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"

"github.com/sourcegraph/log/logtest"
sams "github.com/sourcegraph/sourcegraph-accounts-sdk-go"
clientsv1 "github.com/sourcegraph/sourcegraph-accounts-sdk-go/clients/v1"
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/clients/v1/clientsv1connect"
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/scopes"
"github.com/sourcegraph/sourcegraph/lib/errors"
)

type mockTokenIntrospector struct {
response *sams.IntrospectTokenResponse
}

func (m *mockTokenIntrospector) IntrospectToken(ctx context.Context, token string) (*sams.IntrospectTokenResponse, error) {
return m.response, nil
}

func TestInterceptor(t *testing.T) {
// All tests based on UsersService.GetUser()
for _, tc := range []struct {
name string
token *sams.IntrospectTokenResponse
wantError autogold.Value
wantLogs autogold.Value
}{{
name: "inactive token",
token: &sams.IntrospectTokenResponse{
Active: false,
},
wantError: autogold.Expect("permission_denied: permission denied"),
wantLogs: autogold.Expect([]string{}),
}, {
name: "insufficient scopes",
token: &sams.IntrospectTokenResponse{
Active: true,
},
wantError: autogold.Expect("permission_denied: insufficient scopes: got scopes [], required: [profile]"),
wantLogs: autogold.Expect([]string{}),
}, {
name: "matches required scope",
token: &sams.IntrospectTokenResponse{
Active: true,
Scopes: scopes.Scopes{"profile"},
},
wantError: autogold.Expect(nil), // should not error!
wantLogs: autogold.Expect([]string{}),
}, {
name: "wrong scope",
token: &sams.IntrospectTokenResponse{
Active: true,
Scopes: scopes.Scopes{"not-a-scope"},
},
wantError: autogold.Expect("permission_denied: insufficient scopes: got scopes [not-a-scope], required: [profile]"),
wantLogs: autogold.Expect([]string{}),
}} {
t.Run(tc.name, func(t *testing.T) {
logger, exportLogs := logtest.Captured(t)
interceptor := NewInterceptor(
logger,
&mockTokenIntrospector{
response: tc.token,
},
clientsv1.E_SamsRequiredScopes,
)
mux := http.NewServeMux()
mux.Handle(
clientsv1connect.NewUsersServiceHandler(clientsv1connect.UnimplementedUsersServiceHandler{},
connect.WithInterceptors(interceptor)),
)
srv := httptest.NewServer(mux)
c := clientsv1connect.NewUsersServiceClient(
oauth2.NewClient(
context.Background(),
oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: "foobar",
TokenType: "bearer",
}),
),
srv.URL)
_, err := c.GetUser(context.Background(), connect.NewRequest(&clientsv1.GetUserRequest{}))

// Success cases are connect.CodeUnimplemented
require.Error(t, err)

var connectErr *connect.Error
if errors.As(err, &connectErr) {
if connectErr.Code() == connect.CodeUnimplemented {
tc.wantError.Equal(t, nil) // should not expect an error
} else {
tc.wantError.Equal(t, err.Error())
}
} else {
t.Errorf("error %q is not *connect.Error", err.Error())
}

tc.wantLogs.Equal(t, exportLogs().Messages())
})
}
}
47 changes: 47 additions & 0 deletions auth/clientcredentials/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package clientcredentials

import (
"context"
"strings"
"time"

"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/scopes"
)

type contextKey int

const (
clientInfoKey contextKey = iota
)

type ClientInfo struct {
ClientID string
TokenExpiresAt time.Time
TokenScopes scopes.Scopes
}

// LogFields represents a standard log representation of a client, for use in
// propagting in loggers for auditing purposes. It is safe to use on a nil
// *ClientInfo.
func (c *ClientInfo) LogFields() []log.Field {
if c == nil {
return []log.Field{log.Stringp("client", nil)}
}
return []log.Field{
log.String("client.clientID", c.ClientID),
log.Time("client.tokenExpiresAt", c.TokenExpiresAt),
log.String("client.tokenScopes", strings.Join(scopes.ToStrings(c.TokenScopes), " ")),
}
}

// ClientInfoFromContext returns client info from the given context. This is
// generally set by clientcredentials.Interceptor.
func ClientInfoFromContext(ctx context.Context) *ClientInfo {
return ctx.Value(clientInfoKey).(*ClientInfo)
}

// WithClientInfo returns a new context with the given client info.
func WithClientInfo(ctx context.Context, info *ClientInfo) context.Context {
return context.WithValue(ctx, clientInfoKey, info)
}
Loading

0 comments on commit 35e9505

Please sign in to comment.