Skip to content

Commit

Permalink
mt-broker-filter: reject request for wrong audience
Browse files Browse the repository at this point in the history
Signed-off-by: pingjiang <[email protected]>
  • Loading branch information
xiangpingjiang committed Nov 5, 2023
1 parent a7628df commit 13bd9a5
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 12 deletions.
5 changes: 4 additions & 1 deletion cmd/broker/filter/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"knative.dev/eventing/pkg/auth"
"knative.dev/eventing/pkg/broker/filter"
triggerinformer "knative.dev/eventing/pkg/client/injection/informers/eventing/v1/trigger"
subscriptioninformer "knative.dev/eventing/pkg/client/injection/informers/messaging/v1/subscription"
"knative.dev/eventing/pkg/reconciler/names"
)

Expand Down Expand Up @@ -77,6 +78,7 @@ func main() {
log.Printf("Registering %d informers", len(injection.Default.GetInformers()))

ctx, informers := injection.Default.SetupInformers(ctx, cfg)
ctx = injection.WithConfig(ctx, cfg)
kubeClient := kubeclient.Get(ctx)

loggingConfig, err := broker.GetLoggingConfig(ctx, system.Namespace(), logging.ConfigMapName())
Expand Down Expand Up @@ -123,7 +125,8 @@ func main() {
oidcTokenProvider := auth.NewOIDCTokenProvider(ctx)
// We are running both the receiver (takes messages in from the Broker) and the dispatcher (send
// the messages to the triggers' subscribers) in this binary.
handler, err := filter.NewHandler(logger, oidcTokenProvider, triggerinformer.Get(ctx), reporter, ctxFunc)
oidcTokenVerifier := auth.NewOIDCTokenVerifier(ctx)
handler, err := filter.NewHandler(logger, oidcTokenVerifier, oidcTokenProvider, triggerinformer.Get(ctx), subscriptioninformer.Get(ctx), reporter, ctxFunc)
if err != nil {
logger.Fatal("Error creating Handler", zap.Error(err))
}
Expand Down
70 changes: 59 additions & 11 deletions pkg/broker/filter/filter_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,19 @@ import (

eventingv1 "knative.dev/eventing/pkg/apis/eventing/v1"
"knative.dev/eventing/pkg/apis/feature"
messagingv1 "knative.dev/eventing/pkg/apis/messaging/v1"
"knative.dev/eventing/pkg/broker"
v1 "knative.dev/eventing/pkg/client/informers/externalversions/eventing/v1"
messaginginformersv1 "knative.dev/eventing/pkg/client/informers/externalversions/messaging/v1"
eventinglisters "knative.dev/eventing/pkg/client/listers/eventing/v1"
messaginglisters "knative.dev/eventing/pkg/client/listers/messaging/v1"
"knative.dev/eventing/pkg/eventfilter"
"knative.dev/eventing/pkg/eventfilter/attributes"
"knative.dev/eventing/pkg/eventfilter/subscriptionsapi"
"knative.dev/eventing/pkg/kncloudevents"
"knative.dev/eventing/pkg/reconciler/sugar/trigger/path"
"knative.dev/eventing/pkg/tracing"
"knative.dev/pkg/kmeta"
)

const (
Expand All @@ -72,14 +76,16 @@ type Handler struct {

eventDispatcher *kncloudevents.Dispatcher

triggerLister eventinglisters.TriggerLister
logger *zap.Logger
withContext func(ctx context.Context) context.Context
filtersMap *subscriptionsapi.FiltersMap
triggerLister eventinglisters.TriggerLister
logger *zap.Logger
withContext func(ctx context.Context) context.Context
filtersMap *subscriptionsapi.FiltersMap
tokenVerifier *auth.OIDCTokenVerifier
subscriptionLister messaginglisters.SubscriptionLister
}

// NewHandler creates a new Handler and its associated EventReceiver.
func NewHandler(logger *zap.Logger, oidcTokenProvider *auth.OIDCTokenProvider, triggerInformer v1.TriggerInformer, reporter StatsReporter, wc func(ctx context.Context) context.Context) (*Handler, error) {
func NewHandler(logger *zap.Logger, tokenVerifier *auth.OIDCTokenVerifier, oidcTokenProvider *auth.OIDCTokenProvider, triggerInformer v1.TriggerInformer, subscriptionInformer messaginginformersv1.SubscriptionInformer, reporter StatsReporter, wc func(ctx context.Context) context.Context) (*Handler, error) {
kncloudevents.ConfigureConnectionArgs(&kncloudevents.ConnectionArgs{
MaxIdleConns: defaultMaxIdleConnections,
MaxIdleConnsPerHost: defaultMaxIdleConnectionsPerHost,
Expand Down Expand Up @@ -127,12 +133,14 @@ func NewHandler(logger *zap.Logger, oidcTokenProvider *auth.OIDCTokenProvider, t
})

return &Handler{
reporter: reporter,
eventDispatcher: kncloudevents.NewDispatcher(oidcTokenProvider),
triggerLister: triggerInformer.Lister(),
logger: logger,
withContext: wc,
filtersMap: fm,
reporter: reporter,
eventDispatcher: kncloudevents.NewDispatcher(oidcTokenProvider),
triggerLister: triggerInformer.Lister(),
subscriptionLister: subscriptionInformer.Lister(),
logger: logger,
tokenVerifier: tokenVerifier,
withContext: wc,
filtersMap: fm,
}, nil
}

Expand Down Expand Up @@ -203,6 +211,38 @@ func (h *Handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
return
}

features := feature.FromContext(ctx)
if features.IsOIDCAuthentication() {
h.logger.Debug("OIDC authentication is enabled")
subs, err := h.getSubscription(t)
if err != nil {
h.logger.Info("Unable to get the subscription", zap.Error(err), zap.Any("triggerRef", triggerRef))
writer.WriteHeader(http.StatusBadRequest)
return
}

if subs.Spec.Subscriber.Audience == nil {
h.logger.Warn(fmt.Sprintf("Audience of subscription %s/%s must not be nil, while feature %s is enabled", subs.Name, subs.Namespace, feature.OIDCAuthentication))
writer.WriteHeader(http.StatusInternalServerError)
return
}

token := auth.GetJWTFromHeader(request.Header)
if token == "" {
h.logger.Warn(fmt.Sprintf("No JWT in %s header provided while feature %s is enabled", auth.AuthHeaderKey, feature.OIDCAuthentication))
writer.WriteHeader(http.StatusUnauthorized)
return
}

if _, err := h.tokenVerifier.VerifyJWT(ctx, token, *subs.Spec.Subscriber.Audience); err != nil {
h.logger.Warn("no valid JWT provided", zap.Error(err))
writer.WriteHeader(http.StatusUnauthorized)
return
}

h.logger.Debug("Request contained a valid JWT. Continuing...")
}

reportArgs := &ReportArgs{
ns: t.Namespace,
trigger: t.Name,
Expand Down Expand Up @@ -367,6 +407,14 @@ func (h *Handler) getTrigger(ref path.NamespacedNameUID) (*eventingv1.Trigger, e
return t, nil
}

func (h *Handler) getSubscription(t *eventingv1.Trigger) (*messagingv1.Subscription, error) {
sub, err := h.subscriptionLister.Subscriptions(t.Namespace).Get(kmeta.ChildName(fmt.Sprintf("%s-%s-", t.Spec.Broker, t.Name), string(t.GetUID())))
if err != nil {
return nil, err
}
return sub, nil
}

func (h *Handler) filterEvent(ctx context.Context, trigger *eventingv1.Trigger, event cloudevents.Event) eventfilter.FilterResult {
switch {
case feature.FromContext(ctx).IsEnabled(feature.NewTriggerFilters) && len(trigger.Spec.Filters) > 0:
Expand Down
8 changes: 8 additions & 0 deletions pkg/broker/filter/filter_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ import (

triggerinformerfake "knative.dev/eventing/pkg/client/injection/informers/eventing/v1/trigger/fake"

subscriptioninformerfake "knative.dev/eventing/pkg/client/injection/informers/messaging/v1/subscription/fake"

// Fake injection client
_ "knative.dev/pkg/client/injection/kube/client/fake"
)
Expand Down Expand Up @@ -431,6 +433,7 @@ func TestReceiver(t *testing.T) {

logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.AddCaller()))
oidcTokenProvider := auth.NewOIDCTokenProvider(ctx)
oidcTokenVerifier := auth.NewOIDCTokenVerifier(ctx)

// Replace the SubscriberURI to point at our fake server.
for _, trig := range tc.triggers {
Expand All @@ -447,8 +450,10 @@ func TestReceiver(t *testing.T) {
reporter := &mockReporter{}
r, err := NewHandler(
logger,
oidcTokenVerifier,
oidcTokenProvider,
triggerinformerfake.Get(ctx),
subscriptioninformerfake.Get(ctx),
reporter,
func(ctx context.Context) context.Context {
return ctx
Expand Down Expand Up @@ -616,6 +621,7 @@ func TestReceiver_WithSubscriptionsAPI(t *testing.T) {

logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.AddCaller()))
oidcTokenProvider := auth.NewOIDCTokenProvider(ctx)
oidcTokenVerifier := auth.NewOIDCTokenVerifier(ctx)

// Replace the SubscriberURI to point at our fake server.
for _, trig := range tc.triggers {
Expand All @@ -633,8 +639,10 @@ func TestReceiver_WithSubscriptionsAPI(t *testing.T) {
reporter := &mockReporter{}
r, err := NewHandler(
logger,
oidcTokenVerifier,
oidcTokenProvider,
triggerinformerfake.Get(ctx),
subscriptioninformerfake.Get(ctx),
reporter,
func(ctx context.Context) context.Context {
return feature.ToContext(context.TODO(), feature.Flags{
Expand Down

0 comments on commit 13bd9a5

Please sign in to comment.