Skip to content

Commit

Permalink
Merge pull request #141 from ngrok/joshngrok_add_ForwardsProto_to_Bind
Browse files Browse the repository at this point in the history
ngrok: add WithAppProtocol
  • Loading branch information
jrobsonchase authored Nov 29, 2023
2 parents 03a24d0 + 78bff0a commit b8b5d7f
Show file tree
Hide file tree
Showing 14 changed files with 131 additions and 71 deletions.
24 changes: 24 additions & 0 deletions config/app_protocol.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package config

type appProtocol string

func (ap appProtocol) ApplyHTTP(cfg *httpOptions) {
cfg.commonOpts.ForwardsProto = string(ap)
}

func (ap appProtocol) ApplyLabeled(cfg *labeledOptions) {
cfg.commonOpts.ForwardsProto = string(ap)
}

// WithAppProtocol declares the protocol that the upstream service speaks.
// This may be used by the ngrok edge to make decisions regarding protocol
// upgrades or downgrades.
//
// Currently, `http2` is the only valid string, and will cause connections
// received from HTTP endpoints to always use HTTP/2.
func WithAppProtocol(proto string) interface {
HTTPEndpointOption
LabeledTunnelOption
} {
return appProtocol(proto)
}
5 changes: 5 additions & 0 deletions config/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ type commonOpts struct {
// bearing on tunnel behavior.
// If not set, defaults to a URI in the format `app://hostname/path/to/executable?pid=12345`
ForwardsTo string

// The protocol that's forwarded from the ngrok edge.
// Currently only relevant for HTTP/1.1 vs HTTP/2, since there's a potential
// change-of-protocol happening at our edge.
ForwardsProto string
}

type CommonOptionsFunc func(cfg *commonOpts)
Expand Down
4 changes: 4 additions & 0 deletions config/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ func (cfg *httpOptions) toProtoConfig() *proto.HTTPEndpoint {
return opts
}

func (cfg httpOptions) ForwardsProto() string {
return cfg.commonOpts.ForwardsProto
}

func (cfg httpOptions) ForwardsTo() string {
return cfg.commonOpts.getForwardsTo()
}
Expand Down
4 changes: 4 additions & 0 deletions config/labeled.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ func WithLabel(label, value string) LabeledTunnelOption {
})
}

func (cfg labeledOptions) ForwardsProto() string {
return cfg.commonOpts.ForwardsProto
}

func (cfg labeledOptions) ForwardsTo() string {
return cfg.commonOpts.getForwardsTo()
}
Expand Down
4 changes: 4 additions & 0 deletions config/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ func (cfg tcpOptions) ForwardsTo() string {
return cfg.commonOpts.getForwardsTo()
}

func (cfg tcpOptions) ForwardsProto() string {
return "" // Not supported for TCP
}

func (cfg *tcpOptions) WithForwardsTo(url *url.URL) {
cfg.commonOpts.ForwardsTo = url.Host
}
Expand Down
4 changes: 4 additions & 0 deletions config/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ func (cfg *tlsOptions) toProtoConfig() *proto.TLSEndpoint {
return opts
}

func (cfg tlsOptions) ForwardsProto() string {
return "" // Not supported for TLS
}

func (cfg tlsOptions) ForwardsTo() string {
return cfg.commonOpts.getForwardsTo()
}
Expand Down
1 change: 1 addition & 0 deletions config/tunnel_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type Tunnel interface {
// the public interface with internal details.
type tunnelConfigPrivate interface {
ForwardsTo() string
ForwardsProto() string
Extra() proto.BindExtra
Proto() string
Opts() any
Expand Down
26 changes: 14 additions & 12 deletions internal/tunnel/client/raw_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import (

type RawSession interface {
Auth(id string, extra proto.AuthExtra) (proto.AuthResp, error)
Listen(proto string, opts any, extra proto.BindExtra, id string, forwardsTo string) (proto.BindResp, error)
ListenLabel(labels map[string]string, metadata string, forwardsTo string) (proto.StartTunnelWithLabelResp, error)
Listen(proto string, opts any, extra proto.BindExtra, id string, forwardsTo string, forwardsProto string) (proto.BindResp, error)
ListenLabel(labels map[string]string, metadata string, forwardsTo string, forwardsProto string) (proto.StartTunnelWithLabelResp, error)
Unlisten(id string) (proto.UnbindResp, error)
Accept() (netx.LoggedConn, error)

Expand Down Expand Up @@ -95,13 +95,14 @@ func (s *rawSession) Auth(id string, extra proto.AuthExtra) (resp proto.AuthResp
// opts are protocol-specific options for listening.
// extra is an opaque struct useful for passing application-specific data.
// id is an session-unique identifier, if empty it will be assigned for you
func (s *rawSession) Listen(protocol string, opts any, extra proto.BindExtra, id string, forwardsTo string) (resp proto.BindResp, err error) {
func (s *rawSession) Listen(protocol string, opts any, extra proto.BindExtra, id string, forwardsTo string, forwardsProto string) (resp proto.BindResp, err error) {
req := proto.Bind{
ClientID: id,
Proto: protocol,
Opts: opts,
Extra: extra,
ForwardsTo: forwardsTo,
ClientID: id,
Proto: protocol,
Opts: opts,
Extra: extra,
ForwardsTo: forwardsTo,
ForwardsProto: forwardsProto,
}
err = s.rpc(proto.BindReq, &req, &resp)
if err != nil {
Expand All @@ -115,11 +116,12 @@ func (s *rawSession) Listen(protocol string, opts any, extra proto.BindExtra, id
}

// ListenLabel sends a listen message to the server and returns the server's response
func (s *rawSession) ListenLabel(labels map[string]string, metadata string, forwardsTo string) (resp proto.StartTunnelWithLabelResp, err error) {
func (s *rawSession) ListenLabel(labels map[string]string, metadata string, forwardsTo string, forwardsProto string) (resp proto.StartTunnelWithLabelResp, err error) {
req := proto.StartTunnelWithLabel{
Labels: labels,
Metadata: metadata,
ForwardsTo: forwardsTo,
Labels: labels,
Metadata: metadata,
ForwardsTo: forwardsTo,
ForwardsProto: forwardsProto,
}
err = s.rpc(proto.StartTunnelWithLabelReq, &req, &resp)
return
Expand Down
12 changes: 6 additions & 6 deletions internal/tunnel/client/reconnecting.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,16 @@ func (s *swapRaw) Auth(id string, extra proto.AuthExtra) (resp proto.AuthResp, e
return proto.AuthResp{}, ErrSessionNotReady
}

func (s *swapRaw) Listen(protocol string, opts any, extra proto.BindExtra, id string, forwardsTo string) (resp proto.BindResp, err error) {
func (s *swapRaw) Listen(protocol string, opts any, extra proto.BindExtra, id string, forwardsTo string, forwardsProto string) (resp proto.BindResp, err error) {
if raw := s.get(); raw != nil {
return raw.Listen(protocol, opts, extra, id, forwardsTo)
return raw.Listen(protocol, opts, extra, id, forwardsTo, forwardsProto)
}
return proto.BindResp{}, ErrSessionNotReady
}

func (s *swapRaw) ListenLabel(labels map[string]string, metadata string, forwardsTo string) (resp proto.StartTunnelWithLabelResp, err error) {
func (s *swapRaw) ListenLabel(labels map[string]string, metadata string, forwardsTo string, forwardsProto string) (resp proto.StartTunnelWithLabelResp, err error) {
if raw := s.get(); raw != nil {
return raw.ListenLabel(labels, metadata, forwardsTo)
return raw.ListenLabel(labels, metadata, forwardsTo, forwardsProto)
}
return proto.StartTunnelWithLabelResp{}, ErrSessionNotReady
}
Expand Down Expand Up @@ -236,7 +236,7 @@ func (s *reconnectingSession) connect(acceptErr error) error {

var respErr string
if tCfg.Labels != nil {
resp, err := raw.ListenLabel(tCfg.Labels, tCfg.Metadata, t.ForwardsTo())
resp, err := raw.ListenLabel(tCfg.Labels, tCfg.Metadata, t.ForwardsTo(), t.ForwardsProto())
if err != nil {
return err
}
Expand All @@ -250,7 +250,7 @@ func (s *reconnectingSession) connect(acceptErr error) error {
newTunnels[oldID] = t
}
} else {
resp, err := raw.Listen(tCfg.ConfigProto, tCfg.Opts, t.bindExtra, t.ID(), t.ForwardsTo())
resp, err := raw.Listen(tCfg.ConfigProto, tCfg.Opts, t.bindExtra, t.ID(), t.ForwardsTo(), t.ForwardsProto())
if err != nil {
return err
}
Expand Down
34 changes: 17 additions & 17 deletions internal/tunnel/client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,20 @@ type Session interface {
//
// Applications will typically prefer to call the protocol-specific methods like
// ListenHTTP, ListenTCP, etc.
Listen(protocol string, opts any, extra proto.BindExtra, forwardsTo string) (Tunnel, error)
Listen(protocol string, opts any, extra proto.BindExtra, forwardsTo string, forwardsProto string) (Tunnel, error)

// Listen negotiates with the server to create a new remote listen for the
// given labels. It returns a *Tunnel on success from which the caller can
// accept new connections over the listen.
ListenLabel(labels map[string]string, metadata string, forwardsTo string) (Tunnel, error)
ListenLabel(labels map[string]string, metadata string, forwardsTo string, forwardsProto string) (Tunnel, error)

// Convenience methods

// ListenHTTP listens on a new HTTP endpoint
ListenHTTP(opts *proto.HTTPEndpoint, extra proto.BindExtra, forwardsTo string) (Tunnel, error)
ListenHTTP(opts *proto.HTTPEndpoint, extra proto.BindExtra, forwardsTo string, forwardsProto string) (Tunnel, error)

// ListenHTTP listens on a new HTTPS endpoint
ListenHTTPS(opts *proto.HTTPEndpoint, extra proto.BindExtra, forwardsTo string) (Tunnel, error)
ListenHTTPS(opts *proto.HTTPEndpoint, extra proto.BindExtra, forwardsTo string, forwardsProto string) (Tunnel, error)

// ListenTCP listens on a remote TCP endpoint
ListenTCP(opts *proto.TCPEndpoint, extra proto.BindExtra, forwardsTo string) (Tunnel, error)
Expand Down Expand Up @@ -116,8 +116,8 @@ func (s *session) Heartbeat() (time.Duration, error) {
return s.raw.Heartbeat()
}

func (s *session) Listen(protocol string, opts any, extra proto.BindExtra, forwardsTo string) (Tunnel, error) {
resp, err := s.raw.Listen(protocol, opts, extra, "", forwardsTo)
func (s *session) Listen(protocol string, opts any, extra proto.BindExtra, forwardsTo string, forwardsProto string) (Tunnel, error) {
resp, err := s.raw.Listen(protocol, opts, extra, "", forwardsTo, forwardsProto)
if err != nil {
return nil, err
}
Expand All @@ -128,16 +128,16 @@ func (s *session) Listen(protocol string, opts any, extra proto.BindExtra, forwa
}

// make tunnel
t := newTunnel(resp, extra, s, forwardsTo)
t := newTunnel(resp, extra, s, forwardsTo, forwardsProto)

// add to tunnel registry
s.addTunnel(resp.ClientID, t)

return t, nil
}

func (s *session) ListenLabel(labels map[string]string, metadata string, forwardsTo string) (Tunnel, error) {
resp, err := s.raw.ListenLabel(labels, metadata, forwardsTo)
func (s *session) ListenLabel(labels map[string]string, metadata string, forwardsTo string, forwardsProto string) (Tunnel, error) {
resp, err := s.raw.ListenLabel(labels, metadata, forwardsTo, forwardsProto)
if err != nil {
return nil, err
}
Expand All @@ -148,32 +148,32 @@ func (s *session) ListenLabel(labels map[string]string, metadata string, forward
}

// make tunnel
t := newTunnelLabel(resp, metadata, labels, s, forwardsTo)
t := newTunnelLabel(resp, metadata, labels, s, forwardsTo, forwardsProto)

// add to tunnel registry
s.addTunnel(resp.ID, t)

return t, nil
}

func (s *session) ListenHTTP(opts *proto.HTTPEndpoint, extra proto.BindExtra, forwardsTo string) (Tunnel, error) {
return s.Listen("http", opts, extra, forwardsTo)
func (s *session) ListenHTTP(opts *proto.HTTPEndpoint, extra proto.BindExtra, forwardsTo string, forwardsProto string) (Tunnel, error) {
return s.Listen("http", opts, extra, forwardsTo, forwardsProto)
}

func (s *session) ListenHTTPS(opts *proto.HTTPEndpoint, extra proto.BindExtra, forwardsTo string) (Tunnel, error) {
return s.Listen("https", opts, extra, forwardsTo)
func (s *session) ListenHTTPS(opts *proto.HTTPEndpoint, extra proto.BindExtra, forwardsTo string, forwardsProto string) (Tunnel, error) {
return s.Listen("https", opts, extra, forwardsTo, forwardsProto)
}

func (s *session) ListenTCP(opts *proto.TCPEndpoint, extra proto.BindExtra, forwardsTo string) (Tunnel, error) {
return s.Listen("tcp", opts, extra, forwardsTo)
return s.Listen("tcp", opts, extra, forwardsTo, "")
}

func (s *session) ListenTLS(opts *proto.TLSEndpoint, extra proto.BindExtra, forwardsTo string) (Tunnel, error) {
return s.Listen("tls", opts, extra, forwardsTo)
return s.Listen("tls", opts, extra, forwardsTo, "")
}

func (s *session) ListenSSH(opts *proto.SSHOptions, extra proto.BindExtra, forwardsTo string) (Tunnel, error) {
return s.Listen("ssh", opts, extra, forwardsTo)
return s.Listen("ssh", opts, extra, forwardsTo, "")
}

func (s *session) SrvInfo() (proto.SrvInfoResp, error) {
Expand Down
59 changes: 34 additions & 25 deletions internal/tunnel/client/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ type ProxyConn struct {
// A Tunnel is a net.Listener that Accept()'s connections from a
// remote machine.
type tunnel struct {
id atomic.Value
configProto string
url string
opts any
token string
bindExtra proto.BindExtra
labels map[string]string
forwardsTo string
id atomic.Value
configProto string
url string
opts any
token string
bindExtra proto.BindExtra
labels map[string]string
forwardsTo string
forwardsProto string

accept chan *ProxyConn // new connections come on this channel
unlisten func() error // call this function to close the tunnel
Expand All @@ -42,36 +43,38 @@ type tunnel struct {
shut shutdown // for clean shutdowns
}

func newTunnel(resp proto.BindResp, extra proto.BindExtra, s *session, forwardsTo string) *tunnel {
func newTunnel(resp proto.BindResp, extra proto.BindExtra, s *session, forwardsTo string, forwardsProto string) *tunnel {
id := atomic.Value{}
id.Store(resp.ClientID)
return &tunnel{
id: id,
configProto: resp.Proto,
url: resp.URL,
opts: resp.Opts,
token: resp.Extra.Token,
bindExtra: extra, // this makes the reconnecting session a little easier
accept: make(chan *ProxyConn),
unlisten: func() error { return s.unlisten(resp.ClientID) },
forwardsTo: forwardsTo,
closeError: errors.New("Listener closed"),
id: id,
configProto: resp.Proto,
url: resp.URL,
opts: resp.Opts,
token: resp.Extra.Token,
bindExtra: extra, // this makes the reconnecting session a little easier
accept: make(chan *ProxyConn),
unlisten: func() error { return s.unlisten(resp.ClientID) },
forwardsTo: forwardsTo,
forwardsProto: forwardsProto,
closeError: errors.New("Listener closed"),
}
}

func newTunnelLabel(resp proto.StartTunnelWithLabelResp, metadata string, labels map[string]string, s *session, forwardsTo string) *tunnel {
func newTunnelLabel(resp proto.StartTunnelWithLabelResp, metadata string, labels map[string]string, s *session, forwardsTo string, forwardsProto string) *tunnel {
id := atomic.Value{}
id.Store(resp.ID)
return &tunnel{
id: id,
bindExtra: proto.BindExtra{
Metadata: metadata,
}, // this makes the reconnecting session a little easier
labels: labels,
accept: make(chan *ProxyConn),
unlisten: func() error { return s.unlisten(resp.ID) },
forwardsTo: forwardsTo,
closeError: errors.New("Listener closed"),
labels: labels,
accept: make(chan *ProxyConn),
unlisten: func() error { return s.unlisten(resp.ID) },
forwardsTo: forwardsTo,
forwardsProto: forwardsProto,
closeError: errors.New("Listener closed"),
}
}

Expand Down Expand Up @@ -115,6 +118,12 @@ func (t *tunnel) Addr() net.Addr {
return t.RemoteBindConfig()
}

// ForwardsProto returns the protocol of the upstream that the ngrok agent
// adverstises to the edge.
func (t *tunnel) ForwardsProto() string {
return t.forwardsProto
}

// ForwardsTo returns the address of the upstream the ngrok agent will
// forward proxied connections to
func (t *tunnel) ForwardsTo() string {
Expand Down
Loading

0 comments on commit b8b5d7f

Please sign in to comment.