diff --git a/internal/socks/client.go b/internal/socks/client.go index 3d6f516a5..e96c6e8f5 100644 --- a/internal/socks/client.go +++ b/internal/socks/client.go @@ -18,14 +18,18 @@ var ( aLongTimeAgo = time.Unix(1, 0) ) -func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net.Addr, ctxErr error) { - host, port, err := splitHostPort(address) +func (d *Dialer) connect(ctx context.Context, c net.Conn, req Request) (conn net.Conn, _ net.Addr, ctxErr error) { + var udpHeader []byte + + host, port, err := splitHostPort(req.DstAddress) if err != nil { - return nil, err + return c, nil, err } if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { c.SetDeadline(deadline) - defer c.SetDeadline(noDeadline) + if req.Cmd != CmdUDPAssociate { + defer c.SetDeadline(noDeadline) + } } if ctx != context.Background() { errCh := make(chan error, 1) @@ -47,6 +51,7 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net }() } + conn = c b := make([]byte, 0, 6+len(host)) // the size here is just an estimate b = append(b, Version5) if len(d.AuthMethods) == 0 || d.Authenticate == nil { @@ -54,7 +59,7 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net } else { ams := d.AuthMethods if len(ams) > 255 { - return nil, errors.New("too many authentication methods") + return c, nil, errors.New("too many authentication methods") } b = append(b, byte(len(ams))) for _, am := range ams { @@ -69,11 +74,11 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net return } if b[0] != Version5 { - return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) + return c, nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) } am := AuthMethod(b[1]) if am == AuthMethodNoAcceptableMethods { - return nil, errors.New("no acceptable authentication methods") + return c, nil, errors.New("no acceptable authentication methods") } if d.Authenticate != nil { if ctxErr = d.Authenticate(ctx, c, am); ctxErr != nil { @@ -82,7 +87,7 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net } b = b[:0] - b = append(b, Version5, byte(d.cmd), 0) + b = append(b, Version5, byte(req.Cmd), 0) if ip := net.ParseIP(host); ip != nil { if ip4 := ip.To4(); ip4 != nil { b = append(b, AddrTypeIPv4) @@ -91,17 +96,23 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net b = append(b, AddrTypeIPv6) b = append(b, ip6...) } else { - return nil, errors.New("unknown address type") + return c, nil, errors.New("unknown address type") } } else { if len(host) > 255 { - return nil, errors.New("FQDN too long") + return c, nil, errors.New("FQDN too long") } b = append(b, AddrTypeFQDN) b = append(b, byte(len(host))) b = append(b, host...) } b = append(b, byte(port>>8), byte(port)) + + if req.Cmd == CmdUDPAssociate { + udpHeader = make([]byte, len(b)) + copy(udpHeader[3:], b[3:]) + } + if _, ctxErr = c.Write(b); ctxErr != nil { return } @@ -110,17 +121,18 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net return } if b[0] != Version5 { - return nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) + return c, nil, errors.New("unexpected protocol version " + strconv.Itoa(int(b[0]))) } if cmdErr := Reply(b[1]); cmdErr != StatusSucceeded { - return nil, errors.New("unknown error " + cmdErr.String()) + return c, nil, errors.New("unknown error " + cmdErr.String()) } if b[2] != 0 { - return nil, errors.New("non-zero reserved field") + return c, nil, errors.New("non-zero reserved field") } l := 2 + addrType := b[3] var a Addr - switch b[3] { + switch addrType { case AddrTypeIPv4: l += net.IPv4len a.IP = make(net.IP, net.IPv4len) @@ -129,12 +141,13 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net a.IP = make(net.IP, net.IPv6len) case AddrTypeFQDN: if _, err := io.ReadFull(c, b[:1]); err != nil { - return nil, err + return c, nil, err } l += int(b[0]) default: - return nil, errors.New("unknown address type " + strconv.Itoa(int(b[3]))) + return c, nil, errors.New("unknown address type " + strconv.Itoa(int(b[3]))) } + if cap(b) < l { b = make([]byte, l) } else { @@ -149,20 +162,19 @@ func (d *Dialer) connect(ctx context.Context, c net.Conn, address string) (_ net a.Name = string(b[:len(b)-2]) } a.Port = int(b[len(b)-2])<<8 | int(b[len(b)-1]) - return &a, nil -} -func splitHostPort(address string) (string, int, error) { - host, port, err := net.SplitHostPort(address) - if err != nil { - return "", 0, err - } - portnum, err := strconv.Atoi(port) - if err != nil { - return "", 0, err - } - if 1 > portnum || portnum > 0xffff { - return "", 0, errors.New("port number out of range " + port) + if req.Cmd == CmdUDPAssociate { + var uc net.Conn + if uc, err = d.proxyDial(ctx, req.UDPNetwork, a.String()); err != nil { + return c, &a, err + } + c.SetDeadline(noDeadline) + go func() { + defer uc.Close() + io.Copy(io.Discard, c) + }() + return udpConn{Conn: uc, socksConn: c, header: udpHeader}, &a, nil } - return host, portnum, nil + + return c, &a, nil } diff --git a/internal/socks/dial_test.go b/internal/socks/dial_test.go index 3a7a31bde..5cd73d0b9 100644 --- a/internal/socks/dial_test.go +++ b/internal/socks/dial_test.go @@ -6,6 +6,7 @@ package socks_test import ( "context" + "errors" "io" "math/rand" "net" @@ -15,6 +16,7 @@ import ( "golang.org/x/net/internal/socks" "golang.org/x/net/internal/sockstest" + "golang.org/x/net/nettest" ) func TestDial(t *testing.T) { @@ -33,7 +35,7 @@ func TestDial(t *testing.T) { Username: "username", Password: "password", }).Authenticate - c, err := d.DialContext(context.Background(), ss.TargetAddr().Network(), ss.TargetAddr().String()) + c, err := d.DialContext(context.Background(), "tcp", ss.TargetAddrPort().String()) if err != nil { t.Fatal(err) } @@ -60,7 +62,7 @@ func TestDial(t *testing.T) { Username: "username", Password: "password", }).Authenticate - a, err := d.DialWithConn(context.Background(), c, ss.TargetAddr().Network(), ss.TargetAddr().String()) + a, err := d.DialWithConn(context.Background(), c, "tcp", ss.TargetAddrPort().String()) if err != nil { t.Fatal(err) } @@ -79,7 +81,7 @@ func TestDial(t *testing.T) { defer cancel() dialErr := make(chan error) go func() { - c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String()) + c, err := d.DialContext(ctx, "tcp", ss.TargetAddrPort().String()) if err == nil { c.Close() } @@ -101,7 +103,7 @@ func TestDial(t *testing.T) { d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String()) ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond)) defer cancel() - c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String()) + c, err := d.DialContext(ctx, "tcp", ss.TargetAddrPort().String()) if err == nil { c.Close() } @@ -119,7 +121,7 @@ func TestDial(t *testing.T) { for i := 0; i < 2*len(rogueCmdList); i++ { ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond)) defer cancel() - c, err := d.DialContext(ctx, ss.TargetAddr().Network(), ss.TargetAddr().String()) + c, err := d.DialContext(ctx, "tcp", ss.TargetAddrPort().String()) if err == nil { t.Log(c.(*socks.Conn).BoundAddr()) c.Close() @@ -127,6 +129,80 @@ func TestDial(t *testing.T) { } } }) + t.Run("UDPAssociate", func(t *testing.T) { + ss, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String()) + c, err := d.DialContext(context.Background(), "udp", ss.TargetAddrPort().String()) + if err != nil { + t.Fatal(err) + } + c.Close() + if network := c.RemoteAddr().Network(); network != "udp" { + t.Errorf("RemoteAddr().Network(): expected \"udp\" got %q", network) + } + expected := "127.0.0.1:5964" + if remoteAddr := c.RemoteAddr().String(); remoteAddr != expected { + t.Errorf("RemoteAddr(): expected %q got %q", expected, remoteAddr) + } + if boundAddr := c.(*socks.Conn).BoundAddr().String(); boundAddr != expected { + t.Errorf("BoundAddr(): expected %q got %q", expected, boundAddr) + } + }) + t.Run("UDPAssociateWithReadAndWrite", func(t *testing.T) { + rc, cmdFunc, err := packetListenerCmdFunc() + if err != nil { + t.Fatal(err) + } + defer rc.Close() + ss, err := sockstest.NewServer(sockstest.NoAuthRequired, cmdFunc) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + d := socks.NewDialer(ss.Addr().Network(), ss.Addr().String()) + c, err := d.DialContext(context.Background(), "udp", ss.TargetAddrPort().String()) + if err != nil { + t.Fatal(err) + } + defer c.Close() + buf := make([]byte, 32) + expected := "HELLO OUTBOUND" + n, err := c.Write([]byte(expected)) + if err != nil { + t.Fatal(err) + } + if len(expected) != n { + t.Errorf("Write(): expected %v bytes got %v", len(expected), n) + } + n, addr, err := rc.ReadFrom(buf) + if err != nil { + t.Fatal(err) + } + data, err := socks.SkipUDPHeader(buf[:n]) + if err != nil { + t.Fatal(err) + } + if actual := string(data); expected != actual { + t.Errorf("ReadFrom(): expected %q got %q", expected, actual) + } + udpHeader := []byte{0x00, 0x00, 0x00, 0x01, 0x7f, 0x00, 0x00, 0x01, 0x17, 0x4b} + expected = "HELLO INBOUND" + _, err = rc.WriteTo(append(udpHeader, []byte(expected)...), addr) + if err != nil { + t.Fatal(err) + } + n, err = c.Read(buf) + if err != nil { + t.Fatal(err) + } + if actual := string(buf[:n]); expected != actual { + t.Errorf("Read(): expected %q got %q", expected, actual) + } + }) } func blackholeCmdFunc(rw io.ReadWriter, b []byte) error { @@ -168,3 +244,33 @@ func parseDialError(err error) (perr, nerr error) { perr = err return } + +func packetListenerCmdFunc() (net.PacketConn, func(io.ReadWriter, []byte) error, error) { + conn, err := nettest.NewLocalPacketListener("udp") + if err != nil { + return nil, nil, err + } + localAddr := conn.LocalAddr().(*net.UDPAddr) + return conn, func(rw io.ReadWriter, b []byte) error { + req, err := sockstest.ParseCmdRequest(b) + if err != nil { + return err + } + if req.Cmd != socks.CmdUDPAssociate { + return errors.New("unexpected command") + } + b, err = sockstest.MarshalCmdReply(socks.Version5, socks.StatusSucceeded, &socks.Addr{IP: localAddr.IP, Port: localAddr.Port}) + if err != nil { + return err + } + n, err := rw.Write(b) + if err != nil { + return err + } + if n != len(b) { + return errors.New("short write") + } + _, err = io.Copy(io.Discard, rw) + return err + }, nil +} diff --git a/internal/socks/socks.go b/internal/socks/socks.go index 84fcc32b6..59a907b47 100644 --- a/internal/socks/socks.go +++ b/internal/socks/socks.go @@ -26,6 +26,8 @@ func (cmd Command) String() string { return "socks connect" case cmdBind: return "socks bind" + case CmdUDPAssociate: + return "socks udp associate" default: return "socks " + strconv.Itoa(int(cmd)) } @@ -70,8 +72,9 @@ const ( AddrTypeFQDN = 0x03 AddrTypeIPv6 = 0x04 - CmdConnect Command = 0x01 // establishes an active-open forward proxy connection - cmdBind Command = 0x02 // establishes a passive-open forward proxy connection + CmdConnect Command = 0x01 // establishes an active-open forward proxy connection + cmdBind Command = 0x02 // establishes a passive-open forward proxy connection + CmdUDPAssociate Command = 0x03 // establishes an active-open forward proxy UDP socket AuthMethodNotRequired AuthMethod = 0x00 // no authentication required AuthMethodUsernamePassword AuthMethod = 0x02 // use username/password @@ -101,6 +104,13 @@ func (a *Addr) String() string { return net.JoinHostPort(a.IP.String(), port) } +// Request represents a SOCKS request. +type Request struct { + Cmd Command + DstAddress string + UDPNetwork string +} + // A Conn represents a forward proxy connection. type Conn struct { net.Conn @@ -119,9 +129,8 @@ func (c *Conn) BoundAddr() net.Addr { // A Dialer holds SOCKS-specific options. type Dialer struct { - cmd Command // either CmdConnect or cmdBind - proxyNetwork string // network between a proxy server and a client - proxyAddress string // proxy server address + proxyNetwork string // network between a proxy server and a client + proxyAddress string // proxy server address // ProxyDial specifies the optional dial function for // establishing the transport connection. @@ -149,31 +158,25 @@ type Dialer struct { // See func Dial of the net package of standard library for a // description of the network and address parameters. func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - if err := d.validateTarget(network, address); err != nil { + req, err := d.newRequest(network, address) + if err != nil { proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + return nil, &net.OpError{Op: req.Cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} } if ctx == nil { proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} - } - var err error - var c net.Conn - if d.ProxyDial != nil { - c, err = d.ProxyDial(ctx, d.proxyNetwork, d.proxyAddress) - } else { - var dd net.Dialer - c, err = dd.DialContext(ctx, d.proxyNetwork, d.proxyAddress) + return nil, &net.OpError{Op: req.Cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} } + c, err := d.proxyDial(ctx, d.proxyNetwork, d.proxyAddress) if err != nil { proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + return nil, &net.OpError{Op: req.Cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} } - a, err := d.connect(ctx, c, address) + c, a, err := d.connect(ctx, c, req) if err != nil { c.Close() proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + return nil, &net.OpError{Op: req.Cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} } return &Conn{Conn: c, boundAddr: a}, nil } @@ -185,18 +188,19 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. // It returns the connection's local address assigned by the SOCKS // server. func (d *Dialer) DialWithConn(ctx context.Context, c net.Conn, network, address string) (net.Addr, error) { - if err := d.validateTarget(network, address); err != nil { + req, err := d.newRequest(network, address) + if err != nil { proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + return nil, &net.OpError{Op: req.Cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} } if ctx == nil { proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} + return nil, &net.OpError{Op: req.Cmd.String(), Net: network, Source: proxy, Addr: dst, Err: errors.New("nil context")} } - a, err := d.connect(ctx, c, address) + _, a, err := d.connect(ctx, c, req) if err != nil { proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + return nil, &net.OpError{Op: req.Cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} } return a, nil } @@ -208,40 +212,33 @@ func (d *Dialer) DialWithConn(ctx context.Context, c net.Conn, network, address // // Deprecated: Use DialContext or DialWithConn instead. func (d *Dialer) Dial(network, address string) (net.Conn, error) { - if err := d.validateTarget(network, address); err != nil { + req, err := d.newRequest(network, address) + if err != nil { proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} - } - var err error - var c net.Conn - if d.ProxyDial != nil { - c, err = d.ProxyDial(context.Background(), d.proxyNetwork, d.proxyAddress) - } else { - c, err = net.Dial(d.proxyNetwork, d.proxyAddress) + return nil, &net.OpError{Op: req.Cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} } + c, err := d.proxyDial(context.Background(), d.proxyNetwork, d.proxyAddress) if err != nil { proxy, dst, _ := d.pathAddrs(address) - return nil, &net.OpError{Op: d.cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} + return nil, &net.OpError{Op: req.Cmd.String(), Net: network, Source: proxy, Addr: dst, Err: err} } - if _, err := d.DialWithConn(context.Background(), c, network, address); err != nil { + c, _, err = d.connect(context.Background(), c, req) + if err != nil { c.Close() return nil, err } return c, nil } -func (d *Dialer) validateTarget(network, address string) error { +func (d *Dialer) newRequest(network, address string) (Request, error) { switch network { case "tcp", "tcp6", "tcp4": + return Request{Cmd: CmdConnect, DstAddress: address}, nil + case "udp", "udp6", "udp4": + return Request{Cmd: CmdUDPAssociate, DstAddress: address, UDPNetwork: network}, nil default: - return errors.New("network not implemented") - } - switch d.cmd { - case CmdConnect, cmdBind: - default: - return errors.New("command not implemented") + return Request{Cmd: CmdConnect, DstAddress: address}, errors.New("network not implemented") } - return nil } func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) { @@ -264,10 +261,19 @@ func (d *Dialer) pathAddrs(address string) (proxy, dst net.Addr, err error) { return } +func (d *Dialer) proxyDial(ctx context.Context, network, address string) (net.Conn, error) { + if d.ProxyDial != nil { + return d.ProxyDial(ctx, network, address) + } else { + var dd net.Dialer + return dd.DialContext(ctx, network, address) + } +} + // NewDialer returns a new Dialer that dials through the provided // proxy server's network and address. func NewDialer(network, address string) *Dialer { - return &Dialer{proxyNetwork: network, proxyAddress: address, cmd: CmdConnect} + return &Dialer{proxyNetwork: network, proxyAddress: address} } const ( @@ -315,3 +321,18 @@ func (up *UsernamePassword) Authenticate(ctx context.Context, rw io.ReadWriter, } return errors.New("unsupported authentication method " + strconv.Itoa(int(auth))) } + +func splitHostPort(address string) (string, int, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return "", 0, err + } + portnum, err := strconv.Atoi(port) + if err != nil { + return "", 0, err + } + if 1 > portnum || portnum > 0xffff { + return "", 0, errors.New("port number out of range " + port) + } + return host, portnum, nil +} diff --git a/internal/socks/udpconn.go b/internal/socks/udpconn.go new file mode 100644 index 000000000..6fb92fe5a --- /dev/null +++ b/internal/socks/udpconn.go @@ -0,0 +1,86 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package socks + +import ( + "errors" + "net" + "strconv" +) + +var ( + errHeaderTooSmall = errors.New("packet header is too small") + errFragNotImplemented = errors.New("packet fragmentation is not implemented") +) + +type udpConn struct { + net.Conn + socksConn net.Conn + header []byte +} + +func (udp udpConn) Close() error { + defer udp.socksConn.Close() + return udp.Conn.Close() +} + +func (udp udpConn) Read(b []byte) (int, error) { + buf := make([]byte, 262+len(b)) + n, err := udp.Conn.Read(buf) + buf, hdrErr := SkipUDPHeader(buf[:n]) + if hdrErr != nil { + if err == nil { + err = hdrErr + } + return 0, err + } + n = copy(b, buf) + return n, err +} + +func (udp udpConn) Write(b []byte) (int, error) { + buf := make([]byte, len(udp.header)+len(b)) + n := copy(buf, udp.header) + copy(buf[n:], b) + n, err := udp.Conn.Write(buf) + if n >= len(udp.header) { + n -= len(udp.header) + } else { + n = 0 + } + return n, err +} + +func SkipUDPHeader(buf []byte) ([]byte, error) { + if len(buf) < 4 { + return nil, errHeaderTooSmall + } + frag := buf[2] + addrType := buf[3] + buf = buf[4:] + switch addrType { + case AddrTypeIPv4: + if len(buf) < net.IPv4len+2 { + return nil, errHeaderTooSmall + } + buf = buf[net.IPv4len+2:] + case AddrTypeIPv6: + if len(buf) < net.IPv6len+2 { + return nil, errHeaderTooSmall + } + buf = buf[net.IPv6len+2:] + case AddrTypeFQDN: + if len(buf) == 0 || len(buf) < 1+int(buf[0])+2 { + return nil, errHeaderTooSmall + } + buf = buf[1+int(buf[0])+2:] + default: + return nil, errors.New("unknown address type " + strconv.Itoa(int(addrType))) + } + if frag != 0 { + return nil, errFragNotImplemented + } + return buf, nil +} diff --git a/internal/sockstest/server.go b/internal/sockstest/server.go index c25a82f12..2b55deeaf 100644 --- a/internal/sockstest/server.go +++ b/internal/sockstest/server.go @@ -9,6 +9,7 @@ import ( "errors" "io" "net" + "net/netip" "golang.org/x/net/internal/socks" "golang.org/x/net/nettest" @@ -61,8 +62,11 @@ func ParseCmdRequest(b []byte) (*CmdRequest, error) { if b[0] != socks.Version5 { return nil, errors.New("unexpected protocol version") } - if socks.Command(b[1]) != socks.CmdConnect { + switch socks.Command(b[1]) { + case socks.CmdConnect, socks.CmdUDPAssociate: + default: return nil, errors.New("unexpected command") + } if b[2] != 0 { return nil, errors.New("non-zero reserved field") @@ -130,21 +134,20 @@ func (s *Server) Addr() net.Addr { return s.ln.Addr() } -// TargetAddr returns a fake final destination address. +// TargetAddrPort returns a fake final destination address and port. // // The returned address is only valid for testing with Server. -func (s *Server) TargetAddr() net.Addr { +func (s *Server) TargetAddrPort() netip.AddrPort { a := s.ln.Addr() switch a := a.(type) { case *net.TCPAddr: - if a.IP.To4() != nil { - return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 5963} - } - if a.IP.To16() != nil && a.IP.To4() == nil { - return &net.TCPAddr{IP: net.IPv6loopback, Port: 5963} + addr, ok := netip.AddrFromSlice(a.IP) + if !ok { + return netip.AddrPort{} } + return netip.AddrPortFrom(addr, 5963) } - return nil + return netip.AddrPort{} } // Close closes the server. diff --git a/internal/sockstest/server_test.go b/internal/sockstest/server_test.go index 2b02d8161..946e9eabf 100644 --- a/internal/sockstest/server_test.go +++ b/internal/sockstest/server_test.go @@ -74,6 +74,17 @@ func TestParseCmdRequest(t *testing.T) { }, }, }, + { + []byte{0x05, 0x03, 0x00, 0x01, 192, 0, 2, 1, 0x17, 0x4b}, + &CmdRequest{ + socks.Version5, + socks.CmdUDPAssociate, + socks.Addr{ + IP: net.IP{192, 0, 2, 1}, + Port: 5963, + }, + }, + }, { []byte{0x05, 0x01, 0x00, 0x03, 0x04, 'F', 'Q', 'D', 'N', 0x17, 0x4b}, &CmdRequest{ diff --git a/proxy/dial_test.go b/proxy/dial_test.go index 608835b5c..3ca9b8d3d 100644 --- a/proxy/dial_test.go +++ b/proxy/dial_test.go @@ -73,7 +73,7 @@ func TestDial(t *testing.T) { if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil { t.Fatal(err) } - c, err := Dial(context.Background(), s.TargetAddr().Network(), s.TargetAddr().String()) + c, err := Dial(context.Background(), "tcp", s.TargetAddrPort().String()) if err != nil { t.Fatal(err) } @@ -91,7 +91,7 @@ func TestDial(t *testing.T) { } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - c, err := Dial(ctx, s.TargetAddr().Network(), s.TargetAddr().String()) + c, err := Dial(ctx, "tcp", s.TargetAddrPort().String()) if err != nil { t.Fatal(err) } @@ -110,10 +110,26 @@ func TestDial(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond) time.Sleep(time.Millisecond) defer cancel() - c, err := Dial(ctx, s.TargetAddr().Network(), s.TargetAddr().String()) + c, err := Dial(ctx, "tcp", s.TargetAddrPort().String()) if err == nil { defer c.Close() t.Fatal("failed to timeout") } }) + t.Run("SOCKS5WithUDP", func(t *testing.T) { + defer ResetProxyEnv() + s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired) + if err != nil { + t.Fatal(err) + } + defer s.Close() + if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil { + t.Fatal(err) + } + c, err := Dial(context.Background(), "udp", s.TargetAddrPort().String()) + if err != nil { + t.Fatal(err) + } + c.Close() + }) } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 567fc9c36..f1b471b8e 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -106,7 +106,7 @@ func TestSOCKS5(t *testing.T) { if err != nil { t.Fatal(err) } - c, err := proxy.Dial("tcp", ss.TargetAddr().String()) + c, err := proxy.Dial("tcp", ss.TargetAddrPort().String()) if err != nil { t.Fatal(err) }