Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ngrok: add WithAppProtocol #141

Merged
merged 5 commits into from
Nov 29, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions config/app_protocol.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package config

type appProtocol string

func (ap appProtocol) ApplyHTTP(cfg *httpOptions) {
cfg.commonOpts.ForwardsProto = string(ap)
}

func (ap appProtocol) ApplyLabeled(cfg *labeledOptions) {
cfg.commonOpts.ForwardsProto = string(ap)
}

// WithAppProtocol declares the protocol that the upstream service speaks.
// This may be used by the ngrok edge to make decisions regarding protocol
// upgrades or downgrades.
//
// Currently, `http2` is the only valid string, and will cause connections
bobzilladev marked this conversation as resolved.
Show resolved Hide resolved
// received from HTTP endpoints to always use HTTP/2.
jrobsonchase marked this conversation as resolved.
Show resolved Hide resolved
func WithAppProtocol(proto string) interface {
HTTPEndpointOption
LabeledTunnelOption
} {
return appProtocol(proto)
}
5 changes: 5 additions & 0 deletions config/common.go
Original file line number Diff line number Diff line change
@@ -12,6 +12,11 @@ type commonOpts struct {
// bearing on tunnel behavior.
// If not set, defaults to a URI in the format `app://hostname/path/to/executable?pid=12345`
ForwardsTo string

// The protocol that's forwarded from the ngrok edge.
// Currently only relevant for HTTP/1.1 vs HTTP/2, since there's a potential
// change-of-protocol happening at our edge.
ForwardsProto string
}

type CommonOptionsFunc func(cfg *commonOpts)
4 changes: 4 additions & 0 deletions config/http.go
Original file line number Diff line number Diff line change
@@ -128,6 +128,10 @@ func (cfg *httpOptions) toProtoConfig() *proto.HTTPEndpoint {
return opts
}

func (cfg httpOptions) ForwardsProto() string {
return cfg.commonOpts.ForwardsProto
}

func (cfg httpOptions) ForwardsTo() string {
return cfg.commonOpts.getForwardsTo()
}
4 changes: 4 additions & 0 deletions config/labeled.go
Original file line number Diff line number Diff line change
@@ -52,6 +52,10 @@ func WithLabel(label, value string) LabeledTunnelOption {
})
}

func (cfg labeledOptions) ForwardsProto() string {
return cfg.commonOpts.ForwardsProto
}

func (cfg labeledOptions) ForwardsTo() string {
return cfg.commonOpts.getForwardsTo()
}
4 changes: 4 additions & 0 deletions config/tcp.go
Original file line number Diff line number Diff line change
@@ -58,6 +58,10 @@ func (cfg tcpOptions) ForwardsTo() string {
return cfg.commonOpts.getForwardsTo()
}

func (cfg tcpOptions) ForwardsProto() string {
return "" // Not supported for TCP
}

func (cfg *tcpOptions) WithForwardsTo(url *url.URL) {
cfg.commonOpts.ForwardsTo = url.Host
}
4 changes: 4 additions & 0 deletions config/tls.go
Original file line number Diff line number Diff line change
@@ -84,6 +84,10 @@ func (cfg *tlsOptions) toProtoConfig() *proto.TLSEndpoint {
return opts
}

func (cfg tlsOptions) ForwardsProto() string {
return "" // Not supported for TLS
}

func (cfg tlsOptions) ForwardsTo() string {
return cfg.commonOpts.getForwardsTo()
}
1 change: 1 addition & 0 deletions config/tunnel_config.go
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@ type Tunnel interface {
// the public interface with internal details.
type tunnelConfigPrivate interface {
ForwardsTo() string
ForwardsProto() string
Extra() proto.BindExtra
Proto() string
Opts() any
26 changes: 14 additions & 12 deletions internal/tunnel/client/raw_session.go
Original file line number Diff line number Diff line change
@@ -19,8 +19,8 @@ import (

type RawSession interface {
Auth(id string, extra proto.AuthExtra) (proto.AuthResp, error)
Listen(proto string, opts any, extra proto.BindExtra, id string, forwardsTo string) (proto.BindResp, error)
ListenLabel(labels map[string]string, metadata string, forwardsTo string) (proto.StartTunnelWithLabelResp, error)
Listen(proto string, opts any, extra proto.BindExtra, id string, forwardsTo string, forwardsProto string) (proto.BindResp, error)
ListenLabel(labels map[string]string, metadata string, forwardsTo string, forwardsProto string) (proto.StartTunnelWithLabelResp, error)
Unlisten(id string) (proto.UnbindResp, error)
Accept() (netx.LoggedConn, error)

@@ -95,13 +95,14 @@ func (s *rawSession) Auth(id string, extra proto.AuthExtra) (resp proto.AuthResp
// opts are protocol-specific options for listening.
// extra is an opaque struct useful for passing application-specific data.
// id is an session-unique identifier, if empty it will be assigned for you
func (s *rawSession) Listen(protocol string, opts any, extra proto.BindExtra, id string, forwardsTo string) (resp proto.BindResp, err error) {
func (s *rawSession) Listen(protocol string, opts any, extra proto.BindExtra, id string, forwardsTo string, forwardsProto string) (resp proto.BindResp, err error) {
req := proto.Bind{
ClientID: id,
Proto: protocol,
Opts: opts,
Extra: extra,
ForwardsTo: forwardsTo,
ClientID: id,
Proto: protocol,
Opts: opts,
Extra: extra,
ForwardsTo: forwardsTo,
ForwardsProto: forwardsProto,
}
err = s.rpc(proto.BindReq, &req, &resp)
if err != nil {
@@ -115,11 +116,12 @@ func (s *rawSession) Listen(protocol string, opts any, extra proto.BindExtra, id
}

// ListenLabel sends a listen message to the server and returns the server's response
func (s *rawSession) ListenLabel(labels map[string]string, metadata string, forwardsTo string) (resp proto.StartTunnelWithLabelResp, err error) {
func (s *rawSession) ListenLabel(labels map[string]string, metadata string, forwardsTo string, forwardsProto string) (resp proto.StartTunnelWithLabelResp, err error) {
req := proto.StartTunnelWithLabel{
Labels: labels,
Metadata: metadata,
ForwardsTo: forwardsTo,
Labels: labels,
Metadata: metadata,
ForwardsTo: forwardsTo,
ForwardsProto: forwardsProto,
}
err = s.rpc(proto.StartTunnelWithLabelReq, &req, &resp)
return
12 changes: 6 additions & 6 deletions internal/tunnel/client/reconnecting.go
Original file line number Diff line number Diff line change
@@ -39,16 +39,16 @@ func (s *swapRaw) Auth(id string, extra proto.AuthExtra) (resp proto.AuthResp, e
return proto.AuthResp{}, ErrSessionNotReady
}

func (s *swapRaw) Listen(protocol string, opts any, extra proto.BindExtra, id string, forwardsTo string) (resp proto.BindResp, err error) {
func (s *swapRaw) Listen(protocol string, opts any, extra proto.BindExtra, id string, forwardsTo string, forwardsProto string) (resp proto.BindResp, err error) {
if raw := s.get(); raw != nil {
return raw.Listen(protocol, opts, extra, id, forwardsTo)
return raw.Listen(protocol, opts, extra, id, forwardsTo, forwardsProto)
}
return proto.BindResp{}, ErrSessionNotReady
}

func (s *swapRaw) ListenLabel(labels map[string]string, metadata string, forwardsTo string) (resp proto.StartTunnelWithLabelResp, err error) {
func (s *swapRaw) ListenLabel(labels map[string]string, metadata string, forwardsTo string, forwardsProto string) (resp proto.StartTunnelWithLabelResp, err error) {
if raw := s.get(); raw != nil {
return raw.ListenLabel(labels, metadata, forwardsTo)
return raw.ListenLabel(labels, metadata, forwardsTo, forwardsProto)
}
return proto.StartTunnelWithLabelResp{}, ErrSessionNotReady
}
@@ -236,7 +236,7 @@ func (s *reconnectingSession) connect(acceptErr error) error {

var respErr string
if tCfg.Labels != nil {
resp, err := raw.ListenLabel(tCfg.Labels, tCfg.Metadata, t.ForwardsTo())
resp, err := raw.ListenLabel(tCfg.Labels, tCfg.Metadata, t.ForwardsTo(), t.ForwardsProto())
if err != nil {
return err
}
@@ -250,7 +250,7 @@ func (s *reconnectingSession) connect(acceptErr error) error {
newTunnels[oldID] = t
}
} else {
resp, err := raw.Listen(tCfg.ConfigProto, tCfg.Opts, t.bindExtra, t.ID(), t.ForwardsTo())
resp, err := raw.Listen(tCfg.ConfigProto, tCfg.Opts, t.bindExtra, t.ID(), t.ForwardsTo(), t.ForwardsProto())
if err != nil {
return err
}
34 changes: 17 additions & 17 deletions internal/tunnel/client/session.go
Original file line number Diff line number Diff line change
@@ -38,20 +38,20 @@ type Session interface {
//
// Applications will typically prefer to call the protocol-specific methods like
// ListenHTTP, ListenTCP, etc.
Listen(protocol string, opts any, extra proto.BindExtra, forwardsTo string) (Tunnel, error)
Listen(protocol string, opts any, extra proto.BindExtra, forwardsTo string, forwardsProto string) (Tunnel, error)

// Listen negotiates with the server to create a new remote listen for the
// given labels. It returns a *Tunnel on success from which the caller can
// accept new connections over the listen.
ListenLabel(labels map[string]string, metadata string, forwardsTo string) (Tunnel, error)
ListenLabel(labels map[string]string, metadata string, forwardsTo string, forwardsProto string) (Tunnel, error)

// Convenience methods

// ListenHTTP listens on a new HTTP endpoint
ListenHTTP(opts *proto.HTTPEndpoint, extra proto.BindExtra, forwardsTo string) (Tunnel, error)
ListenHTTP(opts *proto.HTTPEndpoint, extra proto.BindExtra, forwardsTo string, forwardsProto string) (Tunnel, error)

// ListenHTTP listens on a new HTTPS endpoint
ListenHTTPS(opts *proto.HTTPEndpoint, extra proto.BindExtra, forwardsTo string) (Tunnel, error)
ListenHTTPS(opts *proto.HTTPEndpoint, extra proto.BindExtra, forwardsTo string, forwardsProto string) (Tunnel, error)

// ListenTCP listens on a remote TCP endpoint
ListenTCP(opts *proto.TCPEndpoint, extra proto.BindExtra, forwardsTo string) (Tunnel, error)
@@ -116,8 +116,8 @@ func (s *session) Heartbeat() (time.Duration, error) {
return s.raw.Heartbeat()
}

func (s *session) Listen(protocol string, opts any, extra proto.BindExtra, forwardsTo string) (Tunnel, error) {
resp, err := s.raw.Listen(protocol, opts, extra, "", forwardsTo)
func (s *session) Listen(protocol string, opts any, extra proto.BindExtra, forwardsTo string, forwardsProto string) (Tunnel, error) {
resp, err := s.raw.Listen(protocol, opts, extra, "", forwardsTo, forwardsProto)
if err != nil {
return nil, err
}
@@ -128,16 +128,16 @@ func (s *session) Listen(protocol string, opts any, extra proto.BindExtra, forwa
}

// make tunnel
t := newTunnel(resp, extra, s, forwardsTo)
t := newTunnel(resp, extra, s, forwardsTo, forwardsProto)

// add to tunnel registry
s.addTunnel(resp.ClientID, t)

return t, nil
}

func (s *session) ListenLabel(labels map[string]string, metadata string, forwardsTo string) (Tunnel, error) {
resp, err := s.raw.ListenLabel(labels, metadata, forwardsTo)
func (s *session) ListenLabel(labels map[string]string, metadata string, forwardsTo string, forwardsProto string) (Tunnel, error) {
resp, err := s.raw.ListenLabel(labels, metadata, forwardsTo, forwardsProto)
if err != nil {
return nil, err
}
@@ -148,32 +148,32 @@ func (s *session) ListenLabel(labels map[string]string, metadata string, forward
}

// make tunnel
t := newTunnelLabel(resp, metadata, labels, s, forwardsTo)
t := newTunnelLabel(resp, metadata, labels, s, forwardsTo, forwardsProto)

// add to tunnel registry
s.addTunnel(resp.ID, t)

return t, nil
}

func (s *session) ListenHTTP(opts *proto.HTTPEndpoint, extra proto.BindExtra, forwardsTo string) (Tunnel, error) {
return s.Listen("http", opts, extra, forwardsTo)
func (s *session) ListenHTTP(opts *proto.HTTPEndpoint, extra proto.BindExtra, forwardsTo string, forwardsProto string) (Tunnel, error) {
return s.Listen("http", opts, extra, forwardsTo, forwardsProto)
}

func (s *session) ListenHTTPS(opts *proto.HTTPEndpoint, extra proto.BindExtra, forwardsTo string) (Tunnel, error) {
return s.Listen("https", opts, extra, forwardsTo)
func (s *session) ListenHTTPS(opts *proto.HTTPEndpoint, extra proto.BindExtra, forwardsTo string, forwardsProto string) (Tunnel, error) {
return s.Listen("https", opts, extra, forwardsTo, forwardsProto)
}

func (s *session) ListenTCP(opts *proto.TCPEndpoint, extra proto.BindExtra, forwardsTo string) (Tunnel, error) {
return s.Listen("tcp", opts, extra, forwardsTo)
return s.Listen("tcp", opts, extra, forwardsTo, "")
}

func (s *session) ListenTLS(opts *proto.TLSEndpoint, extra proto.BindExtra, forwardsTo string) (Tunnel, error) {
return s.Listen("tls", opts, extra, forwardsTo)
return s.Listen("tls", opts, extra, forwardsTo, "")
}

func (s *session) ListenSSH(opts *proto.SSHOptions, extra proto.BindExtra, forwardsTo string) (Tunnel, error) {
return s.Listen("ssh", opts, extra, forwardsTo)
return s.Listen("ssh", opts, extra, forwardsTo, "")
}

func (s *session) SrvInfo() (proto.SrvInfoResp, error) {
59 changes: 34 additions & 25 deletions internal/tunnel/client/tunnel.go
Original file line number Diff line number Diff line change
@@ -26,14 +26,15 @@ type ProxyConn struct {
// A Tunnel is a net.Listener that Accept()'s connections from a
// remote machine.
type tunnel struct {
id atomic.Value
configProto string
url string
opts any
token string
bindExtra proto.BindExtra
labels map[string]string
forwardsTo string
id atomic.Value
configProto string
url string
opts any
token string
bindExtra proto.BindExtra
labels map[string]string
forwardsTo string
forwardsProto string

accept chan *ProxyConn // new connections come on this channel
unlisten func() error // call this function to close the tunnel
@@ -42,36 +43,38 @@ type tunnel struct {
shut shutdown // for clean shutdowns
}

func newTunnel(resp proto.BindResp, extra proto.BindExtra, s *session, forwardsTo string) *tunnel {
func newTunnel(resp proto.BindResp, extra proto.BindExtra, s *session, forwardsTo string, forwardsProto string) *tunnel {
id := atomic.Value{}
id.Store(resp.ClientID)
return &tunnel{
id: id,
configProto: resp.Proto,
url: resp.URL,
opts: resp.Opts,
token: resp.Extra.Token,
bindExtra: extra, // this makes the reconnecting session a little easier
accept: make(chan *ProxyConn),
unlisten: func() error { return s.unlisten(resp.ClientID) },
forwardsTo: forwardsTo,
closeError: errors.New("Listener closed"),
id: id,
configProto: resp.Proto,
url: resp.URL,
opts: resp.Opts,
token: resp.Extra.Token,
bindExtra: extra, // this makes the reconnecting session a little easier
accept: make(chan *ProxyConn),
unlisten: func() error { return s.unlisten(resp.ClientID) },
forwardsTo: forwardsTo,
forwardsProto: forwardsProto,
closeError: errors.New("Listener closed"),
}
}

func newTunnelLabel(resp proto.StartTunnelWithLabelResp, metadata string, labels map[string]string, s *session, forwardsTo string) *tunnel {
func newTunnelLabel(resp proto.StartTunnelWithLabelResp, metadata string, labels map[string]string, s *session, forwardsTo string, forwardsProto string) *tunnel {
id := atomic.Value{}
id.Store(resp.ID)
return &tunnel{
id: id,
bindExtra: proto.BindExtra{
Metadata: metadata,
}, // this makes the reconnecting session a little easier
labels: labels,
accept: make(chan *ProxyConn),
unlisten: func() error { return s.unlisten(resp.ID) },
forwardsTo: forwardsTo,
closeError: errors.New("Listener closed"),
labels: labels,
accept: make(chan *ProxyConn),
unlisten: func() error { return s.unlisten(resp.ID) },
forwardsTo: forwardsTo,
forwardsProto: forwardsProto,
closeError: errors.New("Listener closed"),
}
}

@@ -115,6 +118,12 @@ func (t *tunnel) Addr() net.Addr {
return t.RemoteBindConfig()
}

// ForwardsProto returns the protocol of the upstream that the ngrok agent
// adverstises to the edge.
func (t *tunnel) ForwardsProto() string {
return t.forwardsProto
}

// ForwardsTo returns the address of the upstream the ngrok agent will
// forward proxied connections to
func (t *tunnel) ForwardsTo() string {
20 changes: 11 additions & 9 deletions internal/tunnel/proto/msg.go
Original file line number Diff line number Diff line change
@@ -200,12 +200,13 @@ type AuthRespExtra struct {
// A client sends this message to the server over a new stream
// to request the server bind a remote port/hostname on the client's behalf.
type Bind struct {
ID string `json:"-"`
ClientID string `json:"Id"` // a session-unique bind ID generated by the client, if empty, one is generated
Proto string // the protocol to bind (one of 'http', 'https', 'tcp', 'tls', 'ssh')
ForwardsTo string // the address of the upstream service the ngrok agent will forward to
Opts any // options for the bind - protocol dependent
Extra BindExtra // anything extra the application wants to send
ID string `json:"-"`
ClientID string `json:"Id"` // a session-unique bind ID generated by the client, if empty, one is generated
Proto string // the protocol to bind (one of 'http', 'https', 'tcp', 'tls', 'ssh')
ForwardsTo string // the address of the upstream service the ngrok agent will forward to
ForwardsProto string // the L7 protocol the upstream service is expecting (one of '', 'http1', 'http2')
Opts any // options for the bind - protocol dependent
Extra BindExtra // anything extra the application wants to send
}

type BindExtra struct {
@@ -233,9 +234,10 @@ type BindRespExtra struct {
// to request the server start a new tunnel with the given labels on the client's behalf.
type StartTunnelWithLabel struct {
// ID string `json:"-"` // a session-unique bind ID generated by the client, if empty, one is generated
Labels map[string]string // labels for tunnel group membership
ForwardsTo string // the address of the upstream service the ngrok agent will forward to
Metadata string
Labels map[string]string // labels for tunnel group membership
ForwardsTo string // the address of the upstream service the ngrok agent will forward to
ForwardsProto string // the L7 protocol the upstream service is expecting (one of '', 'http1', 'http2')
Metadata string
}

// The server responds with a StartTunnelWithLabelResp message to notify the client of the
4 changes: 2 additions & 2 deletions session.go
Original file line number Diff line number Diff line change
@@ -800,9 +800,9 @@ func (s *sessionImpl) Listen(ctx context.Context, cfg config.Tunnel) (Tunnel, er

extra := tunnelCfg.Extra()
if tunnelCfg.Proto() != "" {
tunnel, err = s.inner().Listen(tunnelCfg.Proto(), tunnelCfg.Opts(), extra, tunnelCfg.ForwardsTo())
tunnel, err = s.inner().Listen(tunnelCfg.Proto(), tunnelCfg.Opts(), extra, tunnelCfg.ForwardsTo(), tunnelCfg.ForwardsProto())
} else {
tunnel, err = s.inner().ListenLabel(tunnelCfg.Labels(), extra.Metadata, tunnelCfg.ForwardsTo())
tunnel, err = s.inner().ListenLabel(tunnelCfg.Labels(), extra.Metadata, tunnelCfg.ForwardsTo(), tunnelCfg.ForwardsProto())
}

impl := &tunnelImpl{
1 change: 1 addition & 0 deletions tunnel_config.go
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@ import (
// Duplicated from config/tunnel_config.go
type tunnelConfigPrivate interface {
ForwardsTo() string
ForwardsProto() string
Extra() proto.BindExtra
Proto() string
Opts() any