diff --git a/route.go b/route.go index d81b5bf..d68bc9f 100644 --- a/route.go +++ b/route.go @@ -64,10 +64,11 @@ func RequestVars(r *http.Request) http.Header { // and then start serving requests. Using them outside of this use case is // unsupported. type Router struct { - Handle404 http.Handler - Handle405 http.Handler - prefix string - trie *trie + Handle404 http.Handler + Handle405 http.Handler + prefix string + trie *trie + middleware []func(http.Handler) http.Handler } // get404 returns the http.Handler `router` should use when serving a 404 page @@ -103,6 +104,8 @@ type route struct { params map[string][]string // the methods this endpoint can serve methods []string + // middleware to use when serving the handler on this route + middleware []func(http.Handler) http.Handler } // route uses the pieces of the request URL and the method of the request to @@ -173,11 +176,13 @@ func pickNode(nodes []*node, pieces []string, method string) *node { // of pieces. A higher score is a better match. // // paths that have a 1:1 match between pieces and nodes should score higher -// - this should be taken care of by having more nodes to score +// - this should be taken care of by having more nodes to score +// // nodes that are dynamic should score lower than static matches // nodes that are prefixes should score lower than static matches // nodes that are prefixes should score lower than nodes that are dynamic -// - this should be taken care of by having more nodes to score +// - this should be taken care of by having more nodes to score +// // nodes earlier in the path should be worth more than nodes later in the path func scoreNode(node *node, pieces []string, power int) float64 { var score float64 @@ -238,15 +243,25 @@ func (router Router) getHandler(r *http.Request) http.Handler { return router.get405() } + // apply any middleware on the route + handler := route.handler + for i := len(route.middleware) - 1; i >= 0; i-- { + handler = route.middleware[i](handler) + } + // after all that, if we still haven't found a problem, use the handler // we have - return route.handler + return handler } // ServeHTTP finds the best handler for the request, using the 404 or 405 // handlers if necessary, and serves the request. func (router Router) ServeHTTP(w http.ResponseWriter, r *http.Request) { - router.getHandler(r).ServeHTTP(w, r) + handler := router.getHandler(r) + for i := len(router.middleware) - 1; i >= 0; i-- { + handler = router.middleware[i](handler) + } + handler.ServeHTTP(w, r) } // SetPrefix sets a string prefix for the Router that won't be taken into @@ -261,6 +276,17 @@ func (router *Router) SetPrefix(prefix string) { router.prefix = prefix } +// SetMiddleware sets one or more middleware functions that will wrap all +// handlers defined on the router. Middleware will run after routing, but +// before any route-specific middleware or the route handler. +// +// Middleware is applied in the order it appears in the SetMiddleware call. So, +// for example, if router.SetMiddleware(A, B, C) is called, trout will call +// A(B(C(handler))) for any handler defined on the router. +func (router *Router) SetMiddleware(mw ...func(http.Handler) http.Handler) { + router.middleware = mw +} + // Endpoint defines a single URL template that requests can be matched against. // It is only valid to instantiate an Endpoint by calling `Router.Endpoint`. // Endpoints, on their own, are only useful for calling their methods, as they @@ -320,6 +346,20 @@ func (e *Endpoint) Handler(h http.Handler) { (*node)(e).methods[catchAllMethod] = h } +// Middleware sets one or more middleware functions that will wrap the default +// http.Handler for `e`, to be used for all requests that `e` matches that +// don't match a method explicitly set for `e` using the Methods method. +// Middleware will run after routing, after any Router middleware, but before +// the route handler. +// +// Middleware is applied in the order it appears in the Middleware call. So, +// for example, if Endpoint.SetMiddleware(A, B, C) is called, trout will call +// A(B(C(handler))) when calling the Endpoint's handler. +func (e *Endpoint) Middleware(mw ...func(http.Handler) http.Handler) *Endpoint { + (*node)(e).middleware[catchAllMethod] = mw + return e +} + // Prefix defines a URL template that requests can be matched against. It is // only valid to instantiate a prefix by calling `Router.Prefix`. Prefixes, on // their own, are only useful for calling their methods, as they don't do @@ -369,6 +409,20 @@ func (p *Prefix) Handler(h http.Handler) { (*node)(p).methods[catchAllMethod] = h } +// Middleware sets one or more middleware functions that will wrap the default +// http.Handler for `p`, to be used for all requests that `p` matches that +// don't match a method explicitly set for `e` using the Methods method. +// Middleware will run after routing, after any Router middleware, but before +// the route handler. +// +// Middleware is applied in the order it appears in the Middleware call. So, +// for example, if Prefix.SetMiddleware(A, B, C) is called, trout will call +// A(B(C(handler))) when calling the Endpoint's handler. +func (p *Prefix) Middleware(mw ...func(http.Handler) http.Handler) *Prefix { + (*node)(p).middleware[catchAllMethod] = mw + return p +} + // Methods defines a pairing of an Endpoint to HTTP request methods, to map // designate specific http.Handlers for requests matching that Endpoint made // using the specified methods. It is only valid to instantiate Methods by @@ -413,3 +467,19 @@ func (m Methods) Handler(h http.Handler) { m.n.methods[method] = h } } + +// Middleware sets one or more middleware functions that will wrap the +// http.Handler associated with `m`, to be used whenever a request that matches +// the Endpoint also matches one of the Methods associated with m. Middleware +// will run after routing, after any Router middleware, but before the route +// handler. +// +// Middleware is applied in the order it appears in the Middleware call. So, +// for example, if Methods.SetMiddleware(A, B, C) is called, trout will call +// A(B(C(handler))) when calling the Methods' handler. +func (m Methods) Middleware(mw ...func(http.Handler) http.Handler) Methods { + for _, method := range m.m { + m.n.middleware[method] = mw + } + return m +} diff --git a/trie.go b/trie.go index 517e07b..b1a87c1 100644 --- a/trie.go +++ b/trie.go @@ -74,18 +74,20 @@ type node struct { children map[string]*node wildChildren []*node methods map[string]http.Handler + middleware map[string][]func(http.Handler) http.Handler } // newChild inserts a new child node under `n` and // returns the child. func (n *node) newChild(value key, term bool) *node { newNode := &node{ - value: value, - term: term, - depth: n.depth + 1, - children: map[string]*node{}, - methods: map[string]http.Handler{}, - parent: n, + value: value, + term: term, + depth: n.depth + 1, + children: map[string]*node{}, + methods: map[string]http.Handler{}, + middleware: map[string][]func(http.Handler) http.Handler{}, + parent: n, } if value.dynamic { n.wildChildren = append(n.wildChildren, newNode)