From 0fd4b67ca7018e618b68f046bcd4edd78c555a2f Mon Sep 17 00:00:00 2001 From: xieyuschen Date: Fri, 2 Aug 2024 17:53:43 +0800 Subject: [PATCH] http2/h2c: allow to disable http upgrade in h2c --- http2/h2c/h2c.go | 25 ++++++++++++++-- http2/h2c/h2c_test.go | 66 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 2 deletions(-) diff --git a/http2/h2c/h2c.go b/http2/h2c/h2c.go index 2d6bf861b..47a2465de 100644 --- a/http2/h2c/h2c.go +++ b/http2/h2c/h2c.go @@ -37,6 +37,20 @@ func init() { } } +// Option is used to specify the behavior for the h2c handler +type Option = func(*option) + +type option struct { + disableUpgrade bool +} + +// DisableH2CUpgrade disable the Upgrade mechanism mentioned in RFC7203 topic 6.7 by the h2c server. +func DisableH2CUpgrade() Option { + return func(o *option) { + o.disableUpgrade = true + } +} + // h2cHandler is a Handler which implements h2c by hijacking the HTTP/1 traffic // that should be h2c traffic. There are two ways to begin a h2c connection // (RFC 7540 Section 3.2 and 3.4): (1) Starting with Prior Knowledge - this @@ -48,6 +62,7 @@ func init() { type h2cHandler struct { Handler http.Handler s *http2.Server + opt option } // NewHandler returns an http.Handler that wraps h, intercepting any h2c @@ -63,10 +78,16 @@ type h2cHandler struct { // The first request on an h2c connection is read entirely into memory before // the Handler is called. To limit the memory consumed by this request, wrap // the result of NewHandler in an http.MaxBytesHandler. -func NewHandler(h http.Handler, s *http2.Server) http.Handler { +func NewHandler(h http.Handler, s *http2.Server, opts ...Option) http.Handler { + var o option + for _, opt := range opts { + opt(&o) + } + return &h2cHandler{ Handler: h, s: s, + opt: o, } } @@ -103,7 +124,7 @@ func (s h2cHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } // Handle Upgrade to h2c (RFC 7540 Section 3.2) - if isH2CUpgrade(r.Header) { + if !s.opt.disableUpgrade && isH2CUpgrade(r.Header) { conn, settings, err := h2cUpgrade(w, r) if err != nil { if http2VerboseLogs { diff --git a/http2/h2c/h2c_test.go b/http2/h2c/h2c_test.go index 3e78f2913..76a19f0bd 100644 --- a/http2/h2c/h2c_test.go +++ b/http2/h2c/h2c_test.go @@ -169,3 +169,69 @@ func TestMaxBytesHandler(t *testing.T) { t.Errorf("resp.StatusCode = %v, want %v", got, want) } } + +func TestH2CProtocolSwitch(t *testing.T) { + const bodyLimit = 10 + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + }) + + h2s := &http2.Server{} + h1s := httptest.NewUnstartedServer(http.MaxBytesHandler(NewHandler(handler, h2s), bodyLimit)) + h1s.Start() + defer h1s.Close() + + req, err := http.NewRequest("POST", h1s.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Http2-Settings", "") + req.Header.Set("Upgrade", "h2c") + req.Header.Set("Connection", "Upgrade, HTTP2-Settings") + + resp, err := h1s.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if got, want := resp.StatusCode, http.StatusSwitchingProtocols; got != want { + t.Errorf("resp.StatusCode = %v, want %v", got, want) + } +} + +func TestDisableH2CUpgrade(t *testing.T) { + const bodyLimit = 10 + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + }) + + h2s := &http2.Server{} + h1s := httptest.NewUnstartedServer(http.MaxBytesHandler(NewHandler(handler, h2s, DisableH2CUpgrade()), bodyLimit)) + h1s.Start() + defer h1s.Close() + + req, err := http.NewRequest("POST", h1s.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Http2-Settings", "") + req.Header.Set("Upgrade", "h2c") + req.Header.Set("Connection", "Upgrade, HTTP2-Settings") + + resp, err := h1s.Client().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Errorf("resp.StatusCode = %v, want %v", got, want) + } +}