diff --git a/common/control/bind.go b/common/control/bind.go index 4a791853..b8451db6 100644 --- a/common/control/bind.go +++ b/common/control/bind.go @@ -1,10 +1,9 @@ package control import ( - "os" - "runtime" "syscall" + E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) @@ -25,38 +24,12 @@ func BindToInterfaceFunc(finder InterfaceFinder, block func(network string, addr } } -const useInterfaceName = runtime.GOOS == "linux" || runtime.GOOS == "android" - func BindToInterface0(finder InterfaceFinder, conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { if interfaceName == "" && interfaceIndex == -1 { - return nil + return E.New("interface not found: ", interfaceName) } if addr := M.ParseSocksaddr(address).Addr; addr.IsValid() && N.IsVirtual(addr) { return nil } - if interfaceName != "" && useInterfaceName || interfaceIndex != -1 && !useInterfaceName { - return bindToInterface(conn, network, address, interfaceName, interfaceIndex) - } - if finder == nil { - return os.ErrInvalid - } - var err error - if useInterfaceName { - interfaceName, err = finder.InterfaceNameByIndex(interfaceIndex) - } else { - interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) - } - if err != nil { - return err - } - if useInterfaceName { - if interfaceName == "" { - return nil - } - } else { - if interfaceIndex == -1 { - return nil - } - } - return bindToInterface(conn, network, address, interfaceName, interfaceIndex) + return bindToInterface(conn, network, address, finder, interfaceName, interfaceIndex) } diff --git a/common/control/bind_darwin.go b/common/control/bind_darwin.go index 8262ac7f..f5be42d5 100644 --- a/common/control/bind_darwin.go +++ b/common/control/bind_darwin.go @@ -1,16 +1,24 @@ package control import ( + "os" "syscall" "golang.org/x/sys/unix" ) -func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { - if interfaceIndex == -1 { - return nil - } +func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int) error { return Raw(conn, func(fd uintptr) error { + var err error + if interfaceIndex == -1 { + if finder == nil { + return os.ErrInvalid + } + interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) + if err != nil { + return err + } + } switch network { case "tcp6", "udp6": return unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, interfaceIndex) diff --git a/common/control/bind_linux.go b/common/control/bind_linux.go index 6ebca49d..51529a00 100644 --- a/common/control/bind_linux.go +++ b/common/control/bind_linux.go @@ -1,13 +1,48 @@ package control import ( + "os" "syscall" + "github.com/sagernet/sing/common/atomic" + E "github.com/sagernet/sing/common/exceptions" + "golang.org/x/sys/unix" ) -func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { +var ifIndexDisabled atomic.Bool + +func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int) error { return Raw(conn, func(fd uintptr) error { + var err error + if !ifIndexDisabled.Load() { + if interfaceIndex == -1 { + if finder == nil { + return os.ErrInvalid + } + interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) + if err != nil { + return err + } + } + err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex) + if err == nil { + return nil + } else if E.IsMulti(err, unix.ENOPROTOOPT, unix.EINVAL) { + ifIndexDisabled.Store(true) + } else { + return err + } + } + if interfaceName == "" { + if finder == nil { + return os.ErrInvalid + } + interfaceName, err = finder.InterfaceNameByIndex(interfaceIndex) + if err != nil { + return err + } + } return unix.BindToDevice(int(fd), interfaceName) }) } diff --git a/common/control/bind_other.go b/common/control/bind_other.go index 27d0497b..539ef1cb 100644 --- a/common/control/bind_other.go +++ b/common/control/bind_other.go @@ -4,6 +4,6 @@ package control import "syscall" -func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { +func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int) error { return nil } diff --git a/common/control/bind_windows.go b/common/control/bind_windows.go index 5e23bf16..7029c80c 100644 --- a/common/control/bind_windows.go +++ b/common/control/bind_windows.go @@ -2,17 +2,28 @@ package control import ( "encoding/binary" + "os" "syscall" "unsafe" M "github.com/sagernet/sing/common/metadata" ) -func bindToInterface(conn syscall.RawConn, network string, address string, interfaceName string, interfaceIndex int) error { +func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int) error { return Raw(conn, func(fd uintptr) error { + var err error + if interfaceIndex == -1 { + if finder == nil { + return os.ErrInvalid + } + interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) + if err != nil { + return err + } + } handle := syscall.Handle(fd) if M.ParseSocksaddr(address).AddrString() == "" { - err := bind4(handle, interfaceIndex) + err = bind4(handle, interfaceIndex) if err != nil { return err } diff --git a/protocol/http/client.go b/protocol/http/client.go index 25351fdf..d2dc5f22 100644 --- a/protocol/http/client.go +++ b/protocol/http/client.go @@ -4,10 +4,11 @@ import ( std_bufio "bufio" "context" "encoding/base64" + "fmt" "net" "net/http" - "net/url" "os" + "strings" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" @@ -65,42 +66,52 @@ func (c *Client) DialContext(ctx context.Context, network string, destination M. if err != nil { return nil, err } - request := &http.Request{ - Method: http.MethodConnect, - URL: &url.URL{ - Host: destination.String(), - }, - Header: http.Header{ - "Proxy-Connection": []string{"Keep-Alive"}, - }, - } - if c.path != "" { - err = URLSetPath(request.URL, c.path) - if err != nil { - return nil, err - } + URL := destination.String() + HeaderString := "CONNECT " + URL + " HTTP/1.1\r\n" + tempHeaders := map[string][]string{ + "Host": {destination.String()}, + "User-Agent": {"Go-http-client/1.1"}, + "Proxy-Connection": {"Keep-Alive"}, } + for key, valueList := range c.headers { - request.Header.Set(key, valueList[0]) - for _, value := range valueList[1:] { - request.Header.Add(key, value) - } + tempHeaders[key] = valueList } + + if c.path != "" { + tempHeaders["Path"] = []string{c.path} + } + if c.username != "" { auth := c.username + ":" + c.password - request.Header.Add("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth))) + if _, ok := tempHeaders["Proxy-Authorization"]; ok { + tempHeaders["Proxy-Authorization"][len(tempHeaders["Proxy-Authorization"])] = "Basic " + base64.StdEncoding.EncodeToString([]byte(auth)) + } else { + tempHeaders["Proxy-Authorization"] = []string{"Basic " + base64.StdEncoding.EncodeToString([]byte(auth))} + } + } + for key, valueList := range tempHeaders { + HeaderString += key + ": " + strings.Join(valueList, "; ") + "\r\n" } - err = request.Write(conn) + + HeaderString += "\r\n" + + _, err = fmt.Fprintf(conn, "%s", HeaderString) + if err != nil { conn.Close() return nil, err } + reader := std_bufio.NewReader(conn) - response, err := http.ReadResponse(reader, request) + + response, err := http.ReadResponse(reader, nil) + if err != nil { conn.Close() return nil, err } + if response.StatusCode == http.StatusOK { if reader.Buffered() > 0 { buffer := buf.NewSize(reader.Buffered()) diff --git a/protocol/http/link.go b/protocol/http/link.go index 19cb6cfa..554c9c97 100644 --- a/protocol/http/link.go +++ b/protocol/http/link.go @@ -12,3 +12,6 @@ func ReadRequest(b *bufio.Reader) (req *http.Request, err error) //go:linkname URLSetPath net/url.(*URL).setPath func URLSetPath(u *url.URL, p string) error + +//go:linkname ParseBasicAuth net/http.parseBasicAuth +func ParseBasicAuth(auth string) (username, password string, ok bool)