Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

http2/h2c: allow to disable http upgrade in h2c #219

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions http2/h2c/h2c.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@ func init() {
}
}

// Option is used to specify the behavior for the h2c handler
type Option = func(*option)

type option struct {
disableUpgrade bool
}

// DisableH2CUpgrade disable the Upgrade mechanism mentioned in RFC7203 topic 6.7 by the h2c server.
func DisableH2CUpgrade() Option {
return func(o *option) {
o.disableUpgrade = true
}
}

// h2cHandler is a Handler which implements h2c by hijacking the HTTP/1 traffic
// that should be h2c traffic. There are two ways to begin a h2c connection
// (RFC 7540 Section 3.2 and 3.4): (1) Starting with Prior Knowledge - this
Expand All @@ -48,6 +62,7 @@ func init() {
type h2cHandler struct {
Handler http.Handler
s *http2.Server
opt option
}

// NewHandler returns an http.Handler that wraps h, intercepting any h2c
Expand All @@ -63,10 +78,16 @@ type h2cHandler struct {
// The first request on an h2c connection is read entirely into memory before
// the Handler is called. To limit the memory consumed by this request, wrap
// the result of NewHandler in an http.MaxBytesHandler.
func NewHandler(h http.Handler, s *http2.Server) http.Handler {
func NewHandler(h http.Handler, s *http2.Server, opts ...Option) http.Handler {
var o option
for _, opt := range opts {
opt(&o)
}

return &h2cHandler{
Handler: h,
s: s,
opt: o,
}
}

Expand Down Expand Up @@ -103,7 +124,7 @@ func (s h2cHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// Handle Upgrade to h2c (RFC 7540 Section 3.2)
if isH2CUpgrade(r.Header) {
if !s.opt.disableUpgrade && isH2CUpgrade(r.Header) {
conn, settings, err := h2cUpgrade(w, r)
if err != nil {
if http2VerboseLogs {
Expand Down
66 changes: 66 additions & 0 deletions http2/h2c/h2c_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,69 @@ func TestMaxBytesHandler(t *testing.T) {
t.Errorf("resp.StatusCode = %v, want %v", got, want)
}
}

func TestH2CProtocolSwitch(t *testing.T) {
const bodyLimit = 10
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

})

h2s := &http2.Server{}
h1s := httptest.NewUnstartedServer(http.MaxBytesHandler(NewHandler(handler, h2s), bodyLimit))
h1s.Start()
defer h1s.Close()

req, err := http.NewRequest("POST", h1s.URL, nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Http2-Settings", "")
req.Header.Set("Upgrade", "h2c")
req.Header.Set("Connection", "Upgrade, HTTP2-Settings")

resp, err := h1s.Client().Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
_, err = io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if got, want := resp.StatusCode, http.StatusSwitchingProtocols; got != want {
t.Errorf("resp.StatusCode = %v, want %v", got, want)
}
}

func TestDisableH2CUpgrade(t *testing.T) {
const bodyLimit = 10
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

})

h2s := &http2.Server{}
h1s := httptest.NewUnstartedServer(http.MaxBytesHandler(NewHandler(handler, h2s, DisableH2CUpgrade()), bodyLimit))
h1s.Start()
defer h1s.Close()

req, err := http.NewRequest("POST", h1s.URL, nil)
if err != nil {
t.Fatal(err)
}
req.Header.Set("Http2-Settings", "")
req.Header.Set("Upgrade", "h2c")
req.Header.Set("Connection", "Upgrade, HTTP2-Settings")

resp, err := h1s.Client().Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
_, err = io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if got, want := resp.StatusCode, http.StatusOK; got != want {
t.Errorf("resp.StatusCode = %v, want %v", got, want)
}
}