-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
auth/clientcredentials: add schema-based scope enforcement interceptor (
#61) Adds an interceptor, `clientcredentials.Interceptor`, that enforces required SAMS scopes based on a proto schema extension. This will reduce boilerplate in RPC implementations and reduce the chance of mistakes/out-of-sync problems between a schema and the implementation Usage example: ```proto extend google.protobuf.MethodOptions { // The SAMS scopes required to use this RPC. // // The range 50000-99999 is reserved for internal use within individual organizations // so you can use numbers in this range freely for in-house applications. repeated string sams_required_scopes = 50001; } ``` Allows you to set required scopes as method options: ```proto rpc GetUserRoles(GetUserRolesRequest) returns (GetUserRolesResponse) { option (sams_required_scopes) = "sams::user.roles::read"; }; ``` This generates `E_SamsRequiredScopes` that can be used to point to where we can extract `sams_required_scopes`. ## Test plan Unit tests --------- Co-authored-by: Joe Chen <[email protected]>
- Loading branch information
Showing
5 changed files
with
594 additions
and
145 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
package clientcredentials | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"strings" | ||
|
||
"connectrpc.com/connect" | ||
"go.opentelemetry.io/otel" | ||
"go.opentelemetry.io/otel/attribute" | ||
otelcodes "go.opentelemetry.io/otel/codes" | ||
"go.opentelemetry.io/otel/trace" | ||
"google.golang.org/protobuf/reflect/protoreflect" | ||
"google.golang.org/protobuf/runtime/protoimpl" | ||
|
||
"github.com/sourcegraph/log" | ||
sams "github.com/sourcegraph/sourcegraph-accounts-sdk-go" | ||
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/scopes" | ||
"github.com/sourcegraph/sourcegraph/lib/errors" | ||
) | ||
|
||
var tracer = otel.Tracer("sams/auth/clientcredentials") | ||
|
||
type TokenIntrospector interface { | ||
// IntrospectToken takes a SAMS access token and returns relevant metadata. | ||
// This is generally implemented by *sams.TokensServiceV1. | ||
// | ||
// 🚨 SECURITY: SAMS will return a successful result if the token is valid, but | ||
// is no longer active. It is critical that the caller not honor tokens where | ||
// `.Active == false`. | ||
IntrospectToken(ctx context.Context, token string) (*sams.IntrospectTokenResponse, error) | ||
} | ||
|
||
// See clientcredentials.NewInterceptor. | ||
type Interceptor struct { | ||
logger log.Logger | ||
introspector TokenIntrospector | ||
extension *protoimpl.ExtensionInfo | ||
} | ||
|
||
// NewInterceptor creates a serverside handler interceptor that ensures every | ||
// incoming request has a valid client credential token with the required scopes | ||
// indicated in the RPC method options. When used, required scopes CANNOT be | ||
// empty - if no scopes are required, declare a separate service that does not | ||
// use this interceptor. | ||
// | ||
// To declare required SAMS scopes in your RPC, add the following to your proto | ||
// schema: | ||
// | ||
// extend google.protobuf.MethodOptions { | ||
// // The SAMS scopes required to use this RPC. | ||
// // | ||
// // The range 50000-99999 is reserved for internal use within individual organizations | ||
// // so you can use numbers in this range freely for in-house applications. | ||
// repeated string sams_required_scopes = 50001; | ||
// } | ||
// | ||
// In your RPCs, add the `(sams_required_scopes)` option as a comma-delimited | ||
// list: | ||
// | ||
// rpc GetUserRoles(GetUserRolesRequest) returns (GetUserRolesResponse) { | ||
// option (sams_required_scopes) = "sams::user.roles::read"; | ||
// }; | ||
// | ||
// This will generate a variable called `E_SamsRequiredScopes` in your generated | ||
// proto bindings. This variable should be provided to NewInterceptor to allow | ||
// it to identify where to source the required scopes from. | ||
// | ||
// The provided logger is used to record internal-server errors. | ||
func NewInterceptor( | ||
logger log.Logger, | ||
introspector TokenIntrospector, | ||
methodOptionsRequiredScopesExtension *protoimpl.ExtensionInfo, | ||
) *Interceptor { | ||
return &Interceptor{ | ||
logger: logger.Scoped("clientcredentials"), | ||
introspector: introspector, | ||
extension: methodOptionsRequiredScopesExtension, | ||
} | ||
} | ||
|
||
var _ connect.Interceptor = (*Interceptor)(nil) | ||
|
||
func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { | ||
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { | ||
if req.Spec().IsClient { | ||
return next(ctx, req) // no-op for clients | ||
} | ||
requiredScopes, err := extractSchemaRequiredScopes(req.Spec(), i.extension) | ||
if err != nil { | ||
return nil, internalError(ctx, i.logger, err, "internal schema error") // invalid schema is internal error | ||
} | ||
info, err := i.requireScope(ctx, req.Header(), requiredScopes) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return next(WithClientInfo(ctx, info), req) | ||
} | ||
} | ||
|
||
func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { | ||
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { | ||
return next(ctx, spec) // no-op for clients | ||
} | ||
} | ||
|
||
func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { | ||
return func(ctx context.Context, conn connect.StreamingHandlerConn) error { | ||
if conn.Spec().IsClient { | ||
return next(ctx, conn) // no-op for clients | ||
} | ||
requiredScopes, err := extractSchemaRequiredScopes(conn.Spec(), i.extension) | ||
if err != nil { | ||
return internalError(ctx, i.logger, err, "internal schema error") // invalid schema is internal error | ||
} | ||
info, err := i.requireScope(ctx, conn.RequestHeader(), requiredScopes) | ||
if err != nil { | ||
return err | ||
} | ||
return next(WithClientInfo(ctx, info), conn) | ||
} | ||
} | ||
|
||
// RequireScope ensures the request context has a valid SAMS M2M token | ||
// with requiredScope. It returns a ConnectRPC status error suitable to be | ||
// returned directly from a ConnectRPC implementation. | ||
func (i *Interceptor) requireScope(ctx context.Context, headers http.Header, requiredScopes scopes.Scopes) (_ *ClientInfo, err error) { | ||
var span trace.Span | ||
ctx, span = tracer.Start(ctx, "clientcredentials.requireScope") | ||
defer func() { | ||
if err != nil { | ||
span.RecordError(err) | ||
span.SetStatus(otelcodes.Error, "check failed") | ||
} | ||
span.End() | ||
}() | ||
|
||
token, err := extractBearerContents(headers) | ||
if err != nil { | ||
return nil, connect.NewError(connect.CodeUnauthenticated, | ||
errors.Wrap(err, "invalid authorization header")) | ||
} | ||
|
||
result, err := i.introspector.IntrospectToken(ctx, token) | ||
if err != nil { | ||
return nil, internalError(ctx, i.logger, err, "unable to validate token") | ||
} | ||
span.SetAttributes( | ||
attribute.String("client_id", result.ClientID), | ||
attribute.String("token_expires_at", result.ExpiresAt.String()), | ||
attribute.StringSlice("token_scopes", scopes.ToStrings(result.Scopes))) | ||
info := &ClientInfo{ | ||
ClientID: result.ClientID, | ||
TokenExpiresAt: result.ExpiresAt, | ||
TokenScopes: result.Scopes, | ||
} | ||
|
||
// Active encapsulates whether the token is active, including expiration. | ||
if !result.Active { | ||
// Record detailed error in span, and return an opaque one | ||
span.SetAttributes(attribute.String("full_error", "inactive token")) | ||
return info, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied")) | ||
} | ||
|
||
// Check for our required scope. | ||
for _, required := range requiredScopes { | ||
if !result.Scopes.Match(required) { | ||
err = errors.Newf("got scopes %+v, required: %+v", result.Scopes, requiredScopes) | ||
span.SetAttributes(attribute.String("full_error", err.Error())) | ||
return info, connect.NewError(connect.CodePermissionDenied, | ||
errors.Wrap(err, "insufficient scopes")) | ||
} | ||
} | ||
|
||
return info, nil | ||
} | ||
|
||
func extractSchemaRequiredScopes(spec connect.Spec, extension *protoimpl.ExtensionInfo) (scopes.Scopes, error) { | ||
method, ok := spec.Schema.(protoreflect.MethodDescriptor) | ||
if !ok { | ||
return nil, errors.Newf("expected protoreflect.MethodDescriptor, got %T", spec.Schema) | ||
} | ||
|
||
value := method.Options().ProtoReflect().Get(extension.TypeDescriptor()) | ||
if !value.IsValid() { | ||
return nil, errors.Newf("extension field %s not valid", extension.TypeDescriptor().FullName()) | ||
} | ||
list := value.List() | ||
if list.Len() == 0 { | ||
return nil, errors.Newf("extension field %s cannot be empty", extension.TypeDescriptor().FullName()) | ||
} | ||
|
||
requiredScopes := make(scopes.Scopes, list.Len()) | ||
for i := 0; i < list.Len(); i++ { | ||
requiredScopes[i] = scopes.Scope(list.Get(i).String()) | ||
} | ||
return requiredScopes, nil | ||
} | ||
|
||
func extractBearerContents(h http.Header) (string, error) { | ||
authHeader := h.Get("Authorization") | ||
if authHeader == "" { | ||
return "", errors.New("no token provided in Authorization header") | ||
} | ||
typ := strings.SplitN(authHeader, " ", 2) | ||
if len(typ) != 2 { | ||
return "", errors.New("token type missing in Authorization header") | ||
} | ||
if !strings.EqualFold(typ[0], "bearer") { | ||
return "", errors.Newf("invalid token type %s in Authorization header", typ[0]) | ||
} | ||
return typ[1], nil | ||
} | ||
|
||
// internalError logs an error, adds it to the trace, and returns a connect | ||
// error with a safe message. | ||
func internalError(ctx context.Context, logger log.Logger, err error, safeMsg string) error { | ||
trace.SpanFromContext(ctx). | ||
SetAttributes( | ||
attribute.String("full_error", err.Error()), | ||
) | ||
logger.WithTrace(log.TraceContext{ | ||
TraceID: trace.SpanContextFromContext(ctx).TraceID().String(), | ||
SpanID: trace.SpanContextFromContext(ctx).SpanID().String(), | ||
}). | ||
AddCallerSkip(1). | ||
Error(safeMsg, | ||
log.String("code", connect.CodeInternal.String()), | ||
log.Error(err), | ||
) | ||
return connect.NewError(connect.CodeInternal, errors.New(safeMsg)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
package clientcredentials | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"connectrpc.com/connect" | ||
"github.com/hexops/autogold/v2" | ||
"github.com/stretchr/testify/require" | ||
"golang.org/x/oauth2" | ||
|
||
"github.com/sourcegraph/log/logtest" | ||
sams "github.com/sourcegraph/sourcegraph-accounts-sdk-go" | ||
clientsv1 "github.com/sourcegraph/sourcegraph-accounts-sdk-go/clients/v1" | ||
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/clients/v1/clientsv1connect" | ||
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/scopes" | ||
"github.com/sourcegraph/sourcegraph/lib/errors" | ||
) | ||
|
||
type mockTokenIntrospector struct { | ||
response *sams.IntrospectTokenResponse | ||
} | ||
|
||
func (m *mockTokenIntrospector) IntrospectToken(ctx context.Context, token string) (*sams.IntrospectTokenResponse, error) { | ||
return m.response, nil | ||
} | ||
|
||
func TestInterceptor(t *testing.T) { | ||
// All tests based on UsersService.GetUser() | ||
for _, tc := range []struct { | ||
name string | ||
token *sams.IntrospectTokenResponse | ||
wantError autogold.Value | ||
wantLogs autogold.Value | ||
}{{ | ||
name: "inactive token", | ||
token: &sams.IntrospectTokenResponse{ | ||
Active: false, | ||
}, | ||
wantError: autogold.Expect("permission_denied: permission denied"), | ||
wantLogs: autogold.Expect([]string{}), | ||
}, { | ||
name: "insufficient scopes", | ||
token: &sams.IntrospectTokenResponse{ | ||
Active: true, | ||
}, | ||
wantError: autogold.Expect("permission_denied: insufficient scopes: got scopes [], required: [profile]"), | ||
wantLogs: autogold.Expect([]string{}), | ||
}, { | ||
name: "matches required scope", | ||
token: &sams.IntrospectTokenResponse{ | ||
Active: true, | ||
Scopes: scopes.Scopes{"profile"}, | ||
}, | ||
wantError: autogold.Expect(nil), // should not error! | ||
wantLogs: autogold.Expect([]string{}), | ||
}, { | ||
name: "wrong scope", | ||
token: &sams.IntrospectTokenResponse{ | ||
Active: true, | ||
Scopes: scopes.Scopes{"not-a-scope"}, | ||
}, | ||
wantError: autogold.Expect("permission_denied: insufficient scopes: got scopes [not-a-scope], required: [profile]"), | ||
wantLogs: autogold.Expect([]string{}), | ||
}} { | ||
t.Run(tc.name, func(t *testing.T) { | ||
logger, exportLogs := logtest.Captured(t) | ||
interceptor := NewInterceptor( | ||
logger, | ||
&mockTokenIntrospector{ | ||
response: tc.token, | ||
}, | ||
clientsv1.E_SamsRequiredScopes, | ||
) | ||
mux := http.NewServeMux() | ||
mux.Handle( | ||
clientsv1connect.NewUsersServiceHandler(clientsv1connect.UnimplementedUsersServiceHandler{}, | ||
connect.WithInterceptors(interceptor)), | ||
) | ||
srv := httptest.NewServer(mux) | ||
c := clientsv1connect.NewUsersServiceClient( | ||
oauth2.NewClient( | ||
context.Background(), | ||
oauth2.StaticTokenSource(&oauth2.Token{ | ||
AccessToken: "foobar", | ||
TokenType: "bearer", | ||
}), | ||
), | ||
srv.URL) | ||
_, err := c.GetUser(context.Background(), connect.NewRequest(&clientsv1.GetUserRequest{})) | ||
|
||
// Success cases are connect.CodeUnimplemented | ||
require.Error(t, err) | ||
|
||
var connectErr *connect.Error | ||
if errors.As(err, &connectErr) { | ||
if connectErr.Code() == connect.CodeUnimplemented { | ||
tc.wantError.Equal(t, nil) // should not expect an error | ||
} else { | ||
tc.wantError.Equal(t, err.Error()) | ||
} | ||
} else { | ||
t.Errorf("error %q is not *connect.Error", err.Error()) | ||
} | ||
|
||
tc.wantLogs.Equal(t, exportLogs().Messages()) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
package clientcredentials | ||
|
||
import ( | ||
"context" | ||
"strings" | ||
"time" | ||
|
||
"github.com/sourcegraph/log" | ||
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/scopes" | ||
) | ||
|
||
type contextKey int | ||
|
||
const ( | ||
clientInfoKey contextKey = iota | ||
) | ||
|
||
type ClientInfo struct { | ||
ClientID string | ||
TokenExpiresAt time.Time | ||
TokenScopes scopes.Scopes | ||
} | ||
|
||
// LogFields represents a standard log representation of a client, for use in | ||
// propagting in loggers for auditing purposes. It is safe to use on a nil | ||
// *ClientInfo. | ||
func (c *ClientInfo) LogFields() []log.Field { | ||
if c == nil { | ||
return []log.Field{log.Stringp("client", nil)} | ||
} | ||
return []log.Field{ | ||
log.String("client.clientID", c.ClientID), | ||
log.Time("client.tokenExpiresAt", c.TokenExpiresAt), | ||
log.String("client.tokenScopes", strings.Join(scopes.ToStrings(c.TokenScopes), " ")), | ||
} | ||
} | ||
|
||
// ClientInfoFromContext returns client info from the given context. This is | ||
// generally set by clientcredentials.Interceptor. | ||
func ClientInfoFromContext(ctx context.Context) *ClientInfo { | ||
return ctx.Value(clientInfoKey).(*ClientInfo) | ||
} | ||
|
||
// WithClientInfo returns a new context with the given client info. | ||
func WithClientInfo(ctx context.Context, info *ClientInfo) context.Context { | ||
return context.WithValue(ctx, clientInfoKey, info) | ||
} |
Oops, something went wrong.