diff --git a/internal/server/authz.go b/internal/server/authz.go index a59c839..a9cae03 100644 --- a/internal/server/authz.go +++ b/internal/server/authz.go @@ -80,7 +80,6 @@ func (e *ExtAuthZFilter) Register(server *grpc.Server) { // Check is the implementation of the Envoy AuthorizationServer interface. func (e *ExtAuthZFilter) Check(ctx context.Context, req *envoy.CheckRequest) (response *envoy.CheckResponse, err error) { - ctx = propagateRequestID(ctx, req) // Push the original request id tot eh context to include it in all logs log := e.log.Context(ctx) // If there are no trigger rules, allow the request with no check executions. @@ -118,7 +117,6 @@ func (e *ExtAuthZFilter) Check(ctx context.Context, req *envoy.CheckRequest) (re case *configv1.Filter_Mock: h = authz.NewMockHandler(ft.Mock) case *configv1.Filter_Oidc: - // TODO(nacx): Check if the Oidc setting is enough or we have to pull the default Oidc settings if h, err = authz.NewOIDCHandler(ft.Oidc, e.jwks, e.sessions, oidc.Clock{}, oidc.NewRandomGenerator()); err != nil { return nil, err } @@ -148,15 +146,6 @@ func (e *ExtAuthZFilter) Check(ctx context.Context, req *envoy.CheckRequest) (re return deny(codes.PermissionDenied, "no chains matched"), nil } -// propagateRequestID propagates the request id from the request headers to the context. -func propagateRequestID(ctx context.Context, req *envoy.CheckRequest) context.Context { - headers := req.GetAttributes().GetRequest().GetHttp().GetHeaders() - if headers == nil || headers[EnvoyXRequestID] == "" { - return ctx - } - return telemetry.KeyValuesToContext(ctx, EnvoyXRequestID, headers[EnvoyXRequestID]) -} - // matches returns true if the given request matches the given match configuration func matches(m *configv1.Match, req *envoy.CheckRequest) bool { if m == nil { diff --git a/internal/server/authz_test.go b/internal/server/authz_test.go index 49684fe..17120b0 100644 --- a/internal/server/authz_test.go +++ b/internal/server/authz_test.go @@ -130,7 +130,7 @@ func TestGrpcNoChainsMatched(t *testing.T) { require.NoError(t, err) client := envoy.NewAuthorizationClient(conn) - ok, err := client.Check(context.Background(), &envoy.CheckRequest{}) + ok, err := client.Check(context.Background(), header("test")) require.NoError(t, err) require.Equal(t, int32(codes.PermissionDenied), ok.Status.Code) } @@ -325,6 +325,7 @@ func header(value string) *envoy.CheckRequest { Request: &envoy.AttributeContext_Request{ Http: &envoy.AttributeContext_HttpRequest{ Headers: map[string]string{ + "x-request-id": "test-request-id", "x-test-headers": value, }, }, diff --git a/internal/server/requestid.go b/internal/server/requestid.go new file mode 100644 index 0000000..8ddfb39 --- /dev/null +++ b/internal/server/requestid.go @@ -0,0 +1,45 @@ +// Copyright 2024 Tetrate +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "context" + + envoy "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3" + "github.com/tetratelabs/telemetry" + "google.golang.org/grpc" +) + +// PropagateRequestID is a gRPC middleware that propagates the request id from an Envoy CheckRequest +// to the logging context. +func PropagateRequestID( + ctx context.Context, + req interface{}, + _ *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, +) (interface{}, error) { + check, ok := req.(*envoy.CheckRequest) + if !ok { + return handler(ctx, req) + } + + headers := check.GetAttributes().GetRequest().GetHttp().GetHeaders() + if headers == nil || headers[EnvoyXRequestID] == "" { + return handler(ctx, req) + } + + ctx = telemetry.KeyValuesToContext(ctx, EnvoyXRequestID, headers[EnvoyXRequestID]) + return handler(ctx, req) +} diff --git a/internal/server/server.go b/internal/server/server.go index a132597..ae0fc10 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -74,7 +74,10 @@ func (s *Server) PreRun() error { // Initialize the gRPC server s.server = grpc.NewServer( // TODO(nacx): Expose the right flags for secure connections - grpc.ChainUnaryInterceptor(logMiddleware.UnaryServerInterceptor), + grpc.ChainUnaryInterceptor( + PropagateRequestID, + logMiddleware.UnaryServerInterceptor, + ), ) for _, h := range s.registerHandlers {