diff --git a/context.go b/context.go index 21b1384e..d9017b48 100644 --- a/context.go +++ b/context.go @@ -34,10 +34,23 @@ func NewRouteContext() *Context { return &Context{} } -var ( - // RouteCtxKey is the context.Context key to store the request context. - RouteCtxKey = &contextKey{"RouteContext"} -) +// WithRouteContext returns the list of methods allowed for the current +// request, based on the current routing context. +func AllowedMethods(ctx context.Context) []string { + if rctx := RouteContext(ctx); rctx != nil { + result := make([]string, 0, len(rctx.methodsAllowed)) + for _, method := range rctx.methodsAllowed { + if method := methodTypString(method); method != "" { + result = append(result, method) + } + } + return result + } + return nil +} + +// RouteCtxKey is the context.Context key to store the request context. +var RouteCtxKey = &contextKey{"RouteContext"} // Context is the default routing context set on the root node of a // request context to track route patterns, URL parameters and diff --git a/context_test.go b/context_test.go index 4731c709..9261fda8 100644 --- a/context_test.go +++ b/context_test.go @@ -1,6 +1,10 @@ package chi -import "testing" +import ( + "context" + "strings" + "testing" +) // TestRoutePattern tests correct in-the-middle wildcard removals. // If user organizes a router like this: @@ -85,3 +89,32 @@ func TestRoutePattern(t *testing.T) { t.Fatal("unexpected route pattern for root: " + p) } } + +func TestAllowedMethods(t *testing.T) { + t.Run("no chi context", func(t *testing.T) { + got := AllowedMethods(context.Background()) + if got != nil { + t.Errorf("Unexpected allowed methods: %v", got) + } + }) + t.Run("expected methods", func(t *testing.T) { + want := "GET HEAD" + ctx := context.WithValue(context.Background(), RouteCtxKey, &Context{ + methodsAllowed: []methodTyp{mGET, mHEAD}, + }) + got := strings.Join(AllowedMethods(ctx), " ") + if want != got { + t.Errorf("Unexpected allowed methods: %s, want: %s", got, want) + } + }) + t.Run("unexpected methods", func(t *testing.T) { + want := "GET HEAD" + ctx := context.WithValue(context.Background(), RouteCtxKey, &Context{ + methodsAllowed: []methodTyp{mGET, mHEAD, 9000}, + }) + got := strings.Join(AllowedMethods(ctx), " ") + if want != got { + t.Errorf("Unexpected allowed methods: %s, want: %s", got, want) + } + }) +}