From b81126ce0252a570197f7d9a4764d402652fc558 Mon Sep 17 00:00:00 2001 From: ooooo <297872913@qq.com> Date: Tue, 29 Aug 2023 14:01:34 +0800 Subject: [PATCH] 0.0.5 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 重构代码, 修复 smux 2. 提供 `client` 配置 `mode` ,可选 `http` 和 `websocket`, 默认为 `websocket` 3. 提供 `client` 配置 `smux`,可选 `true` 和 `false`, 默认为 `false` --- RELEASE.md | 12 +++- cmd/http-tunnel-client/main.go | 3 +- cmd/http-tunnel-server/main.go | 2 +- example/example.go | 21 ++++++- tunnel/client.go | 66 +++++++++---------- tunnel/config.go | 86 ++++++++++++++++++------- tunnel/connect_http.go | 10 +-- tunnel/connect_websocket.go | 8 +-- tunnel/copy.go | 23 +++++++ tunnel/options.go | 24 ------- tunnel/server.go | 112 ++++++++++++++++++--------------- 11 files changed, 217 insertions(+), 150 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 7c6105f..d165bb0 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -9,4 +9,14 @@ ## 0.0.3 1. 优化代码, WebSocket 抽象为 Conn -2. 增加多路复用, 配置 smux = "true" \ No newline at end of file +2. 增加多路复用, 配置 smux = "true" + +## 0.0.4 + +1. 优化错误信息 + +## 0.0.5 + +1. 重构代码, 修复 smux +2. 提供 `client` 配置 `mode` ,可选 `http` 和 `websocket`, 默认为 `websocket` +3. 提供 `client` 配置 `smux`,可选 `true` 和 `false`, 默认为 `false` \ No newline at end of file diff --git a/cmd/http-tunnel-client/main.go b/cmd/http-tunnel-client/main.go index 684c90d..32b58f3 100644 --- a/cmd/http-tunnel-client/main.go +++ b/cmd/http-tunnel-client/main.go @@ -26,8 +26,7 @@ func main() { } func StartClient(cc *tunnel.ClientConfig) { - server := tunnel.NewClient(cc.LocalAddr, cc.RemoteAddr, cc.TunnelAddr, cc.TunnelUrl, - tunnel.ClientWithToken(cc.Token), tunnel.ClientWithSMux(cc.Smux)) + server := tunnel.NewClient(cc) err := server.ListenAndServe() if err != nil { log.Error(err) diff --git a/cmd/http-tunnel-server/main.go b/cmd/http-tunnel-server/main.go index d8d15fb..7bcbae9 100644 --- a/cmd/http-tunnel-server/main.go +++ b/cmd/http-tunnel-server/main.go @@ -26,7 +26,7 @@ func main() { } func startServer(sc *tunnel.ServerConfig) { - server := tunnel.NewServer(sc.Addr, sc.Url, tunnel.ServerWithToken(sc.Token)) + server := tunnel.NewServer(sc) err := server.ListenAndServe() if err != nil { log.Error(err) diff --git a/example/example.go b/example/example.go index ce2d0ad..620eacb 100644 --- a/example/example.go +++ b/example/example.go @@ -11,13 +11,28 @@ var ( localAddr = ":8111" remoteAddr = ":30001" tunnelAddr = ":8112" + tunnelUrl = "/" ) func main() { - server := tunnel.NewServer(tunnelAddr, "") + sc := &tunnel.ServerConfig{ + TunnelAddr: tunnelAddr, + TunnelUrl: tunnelUrl, + Token: "", + } + server := tunnel.NewServer(sc) go server.ListenAndServe() - client := tunnel.NewClient(localAddr, remoteAddr, tunnelAddr, "", tunnel.ClientWithSMux("true")) + cc := &tunnel.ClientConfig{ + LocalAddr: localAddr, + RemoteAddr: remoteAddr, + TunnelAddr: tunnelAddr, + TunnelUrl: tunnelUrl, + Token: "", + IsSmux: true, + Mode: tunnel.CONNECT_WEBSOCKET, + } + client := tunnel.NewClient(cc) go client.ListenAndServe() go testServer(remoteAddr) @@ -29,7 +44,7 @@ func testServer(addr string) { mux := http.NewServeMux() mux.HandleFunc("/", index) mux.HandleFunc("/hello", hello) - log.Infoln("listen addr ", addr) + log.Infoln("listen test server addr ", addr) err := http.ListenAndServe(addr, mux) if err != nil { log.Fatalln(err) diff --git a/tunnel/client.go b/tunnel/client.go index 7006c1f..8423c23 100644 --- a/tunnel/client.go +++ b/tunnel/client.go @@ -1,59 +1,38 @@ package tunnel import ( + "fmt" log "github.com/sirupsen/logrus" "github.com/xtaci/smux" "net" + "net/http" ) -var ( - CONNECT_HTTP ConnectMode = "http" - CONNECT_WEBSOCKET ConnectMode = "websocket" -) - -type ConnectMode string - type Client struct { - localAddr string - remoteAddr string - tunnelAddr string - tunnelUrl string - token string - mode ConnectMode - // true or false, default is true - isSmux string + config *ClientConfig + // it is not nil if IsSmux is true smuxSession *smux.Session } -func NewClient(localAddr, remoteAddr, tunnelAddr, tunnelUrl string, options ...ClientOption) *Client { - if localAddr == "" || remoteAddr == "" || tunnelAddr == "" { - panic("localAddr or remoteAddr or tunnelAddr is empty") - } - if tunnelUrl == "" { - tunnelUrl = URL_CONNECT - } +func NewClient(cc *ClientConfig, options ...ClientOption) *Client { c := &Client{ - localAddr: localAddr, - remoteAddr: remoteAddr, - tunnelAddr: tunnelAddr, - tunnelUrl: tunnelUrl, - mode: CONNECT_WEBSOCKET, + config: cc, } for _, option := range options { option(c) } - log.Infof("NewClient localAddr[%s], remoteAddr[%s], tunnelAddr[%s], tunnelUrl[%s]", localAddr, remoteAddr, tunnelAddr, tunnelUrl) + log.Infof("NewClient config %s", c.config) return c } func (c *Client) ListenAndServe() error { - l, err := net.Listen("tcp", c.localAddr) + l, err := net.Listen("tcp", c.config.LocalAddr) if err != nil { - log.Errorf("listen localAddr %s, err: %v", c.localAddr, err) + log.Errorf("listen LocalAddr %s, err: %v", c.config.LocalAddr, err) return err } defer l.Close() - log.Infof("listen localAddr %s", c.localAddr) + log.Infof("listen LocalAddr %s", c.config.LocalAddr) for { conn, err := l.Accept() @@ -66,22 +45,25 @@ func (c *Client) ListenAndServe() error { } func (c *Client) handleConn(conn net.Conn) { + log.Infof("handle conn %s", conn.RemoteAddr()) defer conn.Close() setTCPConnOptions(conn) connectFn := func() net.Conn { var tunnelConn net.Conn - switch c.mode { + switch c.config.Mode { case CONNECT_HTTP: tunnelConn = c.connectWithHTTP() case CONNECT_WEBSOCKET: tunnelConn = c.connectWithWebSocket() + default: + panic("mode is empty") } return tunnelConn } - // support isSmux - if c.isSmux == "true" { + // support smux + if c.config.IsSmux { if c.smuxSession == nil || c.smuxSession.IsClosed() { tunnelConn := connectFn() if tunnelConn == nil { @@ -90,7 +72,7 @@ func (c *Client) handleConn(conn net.Conn) { defer tunnelConn.Close() session, err := smux.Client(tunnelConn, smux.DefaultConfig()) if err != nil { - log.Errorf("new smux client, err: %v", err) + log.Errorf("new IsSmux client, err: %v", err) return } defer session.Close() @@ -101,12 +83,26 @@ func (c *Client) handleConn(conn net.Conn) { log.Errorf("mux open stream, err: %v", err) return } + if stream == nil { + log.Errorf("mux open stream is null") + return + } copyDataOnConn(conn, stream) return } // per connection tunnelConn := connectFn() + if tunnelConn == nil { + return + } defer tunnelConn.Close() copyDataOnConn(conn, tunnelConn) } + +func (c *Client) setHeader(header *http.Header) { + header.Set(HEADER_MODE, string(c.config.Mode)) + header.Set(HEADER_REMOTE_ADDR, c.config.RemoteAddr) + header.Set(HEADER_IS_SMUX, fmt.Sprint(c.config.IsSmux)) + header.Set(HEADER_TOKEN, c.config.Token) +} diff --git a/tunnel/config.go b/tunnel/config.go index 2ad35e0..f35fc4c 100644 --- a/tunnel/config.go +++ b/tunnel/config.go @@ -1,11 +1,31 @@ package tunnel import ( + "encoding/json" "fmt" "github.com/spf13/viper" - "strings" ) +const ( + DEFAULT_TUNNEL_URL = "/" + + CONFIG_COMMON = "common" + CONFIG_LOCAL_ADDR = "local_addr" + CONFIG_REMOTE_ADDR = "remote_addr" + CONFIG_TUNNEL_ADDR = "tunnel_addr" + CONFIG_TUNNEL_URL = "tunnel_url" + CONFIG_TOKEN = "Token" + CONFIG_SMUX = "IsSmux" + CONFIG_MODE = "Mode" +) + +const ( + CONNECT_HTTP ConnectMode = "http" + CONNECT_WEBSOCKET ConnectMode = "websocket" +) + +type ConnectMode string + type ClientConfigs []*ClientConfig type ClientConfig struct { @@ -14,7 +34,8 @@ type ClientConfig struct { TunnelAddr string TunnelUrl string Token string - Smux string + IsSmux bool + Mode ConnectMode } func NewClientConfigsFromFile(configFile string) *ClientConfigs { @@ -25,32 +46,49 @@ func NewClientConfigsFromFile(configFile string) *ClientConfigs { ccs := &ClientConfigs{} // parse common - tunnelAddr := viper.GetString("common.tunnel_addr") - tunnelUrl := viper.GetString("common.tunnel_url") + tunnelAddr := viper.GetString(CONFIG_COMMON + "." + CONFIG_TUNNEL_ADDR) + tunnelUrl := viper.GetString(CONFIG_COMMON + "." + CONFIG_TUNNEL_URL) + isSmux := viper.GetBool(CONFIG_COMMON + "." + CONFIG_SMUX) + mode := viper.GetString(CONFIG_COMMON + "." + CONFIG_MODE) + if tunnelUrl == "" { + tunnelUrl = DEFAULT_TUNNEL_URL + } + if mode == "" { + mode = string(CONNECT_WEBSOCKET) + } // parse special for g, m := range viper.AllSettings() { - if g == "common" { + if g == CONFIG_COMMON { continue } cc := &ClientConfig{} - cc.LocalAddr = getString(m, "local_addr", "") - cc.RemoteAddr = getString(m, "remote_addr", "") - cc.TunnelAddr = getString(m, "tunnel_addr", tunnelAddr) - cc.TunnelUrl = getString(m, "tunnel_url", tunnelUrl) - cc.Token = getString(m, "token", "") - cc.Smux = getString(m, "isSmux", "true") + cc.LocalAddr = getValue(m, CONFIG_LOCAL_ADDR, "") + cc.RemoteAddr = getValue(m, CONFIG_REMOTE_ADDR, "") + cc.TunnelAddr = getValue(m, CONFIG_TUNNEL_ADDR, tunnelAddr) + cc.TunnelUrl = getValue(m, CONFIG_TUNNEL_URL, tunnelUrl) + cc.Token = getValue(m, CONFIG_TOKEN, "") + cc.IsSmux = getValue(m, CONFIG_SMUX, isSmux) + cc.Mode = ConnectMode(getValue(m, CONFIG_MODE, mode)) + if cc.LocalAddr == "" || cc.RemoteAddr == "" || cc.TunnelAddr == "" { + panic(fmt.Sprintf("group %s LocalAddr or RemoteAddr or TunnelAddr is empty", g)) + } *ccs = append(*ccs, cc) } return ccs } +func (c *ClientConfig) String() string { + bytes, _ := json.Marshal(c) + return string(bytes) +} + type ServerConfigs []*ServerConfig type ServerConfig struct { - Addr string - Url string - Token string + TunnelAddr string + TunnelUrl string + Token string } func NewServerConfigsFrom(configFile string) *ServerConfigs { @@ -62,26 +100,26 @@ func NewServerConfigsFrom(configFile string) *ServerConfigs { scs := &ServerConfigs{} for g, m := range viper.AllSettings() { sc := &ServerConfig{} - sc.Addr = getString(m, "tunnel_addr", "") - sc.Url = getString(m, "tunnel_url", "") - sc.Token = getString(m, "token", "") - if sc.Addr == "" { - panic(fmt.Sprintf("group %s config 'Addr' is empty", g)) + sc.TunnelAddr = getValue(m, CONFIG_TUNNEL_ADDR, "") + sc.TunnelUrl = getValue(m, CONFIG_TUNNEL_URL, DEFAULT_TUNNEL_URL) + sc.Token = getValue(m, CONFIG_TOKEN, "") + if sc.TunnelAddr == "" { + panic(fmt.Sprintf("group %s TunnelAddr is empty", g)) } *scs = append(*scs, sc) } return scs } -func splitAddr(addr string) (string, string) { - split := strings.Split(addr, ":") - return split[0], split[1] +func (c *ServerConfig) String() string { + bytes, _ := json.Marshal(c) + return string(bytes) } -func getString(m interface{}, key string, defaultValue string) string { +func getValue[T any](m interface{}, key string, defaultValue T) T { mm := m.(map[string]interface{}) if v, ok := mm[key]; ok { - return v.(string) + return v.(T) } return defaultValue } diff --git a/tunnel/connect_http.go b/tunnel/connect_http.go index 914a867..470dde5 100644 --- a/tunnel/connect_http.go +++ b/tunnel/connect_http.go @@ -10,16 +10,16 @@ import ( func (c *Client) connectWithHTTP() net.Conn { // dial tunnel - tunnelConn, err := net.Dial("tcp", c.tunnelAddr) + tunnelConn, err := net.Dial("tcp", c.config.TunnelAddr) if err != nil { - log.Errorf("dial tunnelAddr %s, err: %v", c.tunnelAddr, err) + log.Errorf("dial TunnelAddr %s, err: %v", c.config.TunnelAddr, err) return nil } // send request - request, _ := http.NewRequest(http.MethodConnect, c.tunnelUrl, nil) - request.Host = c.tunnelAddr - request.Header.Set(HEADER_REMOTE_ADDR, c.remoteAddr) + request, _ := http.NewRequest(http.MethodConnect, c.config.TunnelUrl, nil) + request.Host = c.config.TunnelAddr request.Header.Set("HOST", request.Host) + c.setHeader(&request.Header) err = request.Write(tunnelConn) if err != nil { log.Error("send connect request ", err) diff --git a/tunnel/connect_websocket.go b/tunnel/connect_websocket.go index b88d8a2..9d653fd 100644 --- a/tunnel/connect_websocket.go +++ b/tunnel/connect_websocket.go @@ -12,13 +12,11 @@ import ( func (c *Client) connectWithWebSocket() net.Conn { wsurl := url.URL{ Scheme: "ws", - Host: c.tunnelAddr, - Path: c.tunnelUrl, + Host: c.config.TunnelAddr, + Path: c.config.TunnelUrl, } header := http.Header{} - header.Set(HEADER_REMOTE_ADDR, c.remoteAddr) - header.Set(HEADER_TOKEN, c.token) - header.Set(HEADER_IS_SMUX, c.isSmux) + c.setHeader(&header) wsc, _, err := websocket.DefaultDialer.Dial(wsurl.String(), header) if err != nil { log.Errorf("dial websocket addr %s, err: %v", wsurl.String(), err) diff --git a/tunnel/copy.go b/tunnel/copy.go index 91fd5e1..c92d9a3 100644 --- a/tunnel/copy.go +++ b/tunnel/copy.go @@ -30,3 +30,26 @@ func copyConn(conn1 net.Conn, conn2 net.Conn) { return } } + +//func copyConn(conn1 net.Conn, conn2 net.Conn) { +// buf := make([]byte, 1024*1024) +// var ( +// n int +// err error +// nn int +// ) +// for { +// n, err = conn1.Read(buf) +// if err != nil { +// return +// } +// nn = 0 +// for nn < n { +// nnn, err := conn2.Write(buf[nn:n]) +// if err != nil { +// return +// } +// nn += nnn +// } +// } +//} diff --git a/tunnel/options.go b/tunnel/options.go index 5ac637a..9abec97 100644 --- a/tunnel/options.go +++ b/tunnel/options.go @@ -9,30 +9,6 @@ type ServerOption func(server *Server) type ClientOption func(client *Client) -func ServerWithToken(token string) ServerOption { - return func(server *Server) { - server.token = token - } -} - -func ClientWithToken(token string) ClientOption { - return func(client *Client) { - client.token = token - } -} - -func ClientWithMode(mode ConnectMode) ClientOption { - return func(client *Client) { - client.mode = mode - } -} - -func ClientWithSMux(isSmux string) ClientOption { - return func(client *Client) { - client.isSmux = isSmux - } -} - func setTCPConnOptions(conn net.Conn) { tcpConn := conn.(*net.TCPConn) //tcpConn.SetReadDeadline(time.Now().Add(30 * time.Second)) diff --git a/tunnel/server.go b/tunnel/server.go index 9cd6cb0..504e56c 100644 --- a/tunnel/server.go +++ b/tunnel/server.go @@ -1,88 +1,95 @@ package tunnel import ( - "errors" log "github.com/sirupsen/logrus" "github.com/xtaci/smux" "net" "net/http" + "strconv" ) var ( - URL_CONNECT = "/" - HEADER_REMOTE_ADDR = "REMOTE-ADDR" HEADER_TOKEN = "TOKEN" HEADER_IS_SMUX = "IS_SMUX" - - ErrAuthFail = errors.New("auth fail") + HEADER_MODE = "mode" ) type Server struct { - addr string - url string - token string - l net.Listener + config *ServerConfig + l net.Listener } -func NewServer(addr string, url string, options ...ServerOption) *Server { - if url == "" { - url = URL_CONNECT - } +func NewServer(sc *ServerConfig, options ...ServerOption) *Server { s := &Server{ - addr: addr, - url: url, + config: sc, } for _, option := range options { option(s) } - - log.Infof("NewServer addr[%s], url[%s]", addr, url) + log.Infof("NewServer config %s", s.config) return s } func (s *Server) ListenAndServe() error { - l, err := net.Listen("tcp", s.addr) + l, err := net.Listen("tcp", s.config.TunnelAddr) if err != nil { - log.Error("listen localAddr err ", err) + log.Error("listen LocalAddr %s, err: %v ", s.config.TunnelAddr, err) return err } s.l = l - log.Infof("listen localAddr %s", s.addr) + log.Infof("listen LocalAddr %s", s.config.TunnelAddr) mux := http.NewServeMux() - mux.HandleFunc(s.url, s.Connect) + mux.HandleFunc(s.config.TunnelUrl, s.Connect) err = http.Serve(s.l, mux) if err != nil { - log.Error("serveHTTP err", err) + log.Errorf("serveHTTP err: %v", err) } return err } func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { // auth client - remoteConn := s.auth(w, r) - if remoteConn == nil { + remoteAddr := s.auth(w, r) + if remoteAddr == "" { http.NotFound(w, r) return } - defer remoteConn.Close() - var conn net.Conn - upgrade := r.Header.Get("Upgrade") - if upgrade == "websocket" { - conn = s.connectWithWebSocket(w, r) - } else { - conn = s.connectWithHTTP(w, r) + connectFn := func() net.Conn { + var conn net.Conn + mode := ConnectMode(r.Header.Get(HEADER_MODE)) + switch mode { + case CONNECT_HTTP: + conn = s.connectWithHTTP(w, r) + case CONNECT_WEBSOCKET: + conn = s.connectWithWebSocket(w, r) + default: + panic("header mode is empty") + } + return conn + } + + remoteConnectFn := func() net.Conn { + remoteConn, err := net.Dial("tcp", remoteAddr) + if err != nil { + log.Errorf("dial remote addr %s, err: %v", remoteAddr, err) + return nil + } + return remoteConn } + + // connect + conn := connectFn() if conn == nil { return } defer conn.Close() - // support isSmux - isSmux := r.Header.Get(HEADER_IS_SMUX) - if isSmux == "true" { + // support smux + isSmux, _ := strconv.ParseBool(r.Header.Get(HEADER_IS_SMUX)) + if isSmux { session, err := smux.Server(conn, smux.DefaultConfig()) if err != nil { log.Errorf("new smux server, err: %v", err) @@ -96,32 +103,37 @@ func (s *Server) Connect(w http.ResponseWriter, r *http.Request) { log.Errorf("smux accept stream, err: %v", err) return } + if stream == nil { + log.Errorf("smux accept stream is null") + return + } + remoteConn := remoteConnectFn() + if remoteConn == nil { + continue + } go copyDataOnConn(stream, remoteConn) } } // per connection + remoteConn := remoteConnectFn() + if remoteConn == nil { + return + } copyDataOnConn(conn, remoteConn) } -func (s *Server) auth(w http.ResponseWriter, r *http.Request) net.Conn { - // get remote localAddr +func (s *Server) auth(w http.ResponseWriter, r *http.Request) string { + // verify Token + token := r.Header.Get(HEADER_TOKEN) + if token != s.config.Token { + log.Errorf("http header Token '%s' is err", token) + return "" + } + // get remote LocalAddr remoteAddr := r.Header.Get(HEADER_REMOTE_ADDR) if remoteAddr == "" { log.Errorf("http header '%s' not found", HEADER_REMOTE_ADDR) - return nil - } - // verify token - token := r.Header.Get(HEADER_TOKEN) - if token != s.token { - log.Errorf("http header token '%s' is err", token) - return nil - } - // dial remote addr - remoteConn, err := net.Dial("tcp", remoteAddr) - if err != nil { - log.Errorf("dial remote addr %s, err: %v", remoteAddr, err) - return nil } - return remoteConn + return remoteAddr }