Skip to content

Commit

Permalink
Improve log output, make stats rather than printing all ips
Browse files Browse the repository at this point in the history
  • Loading branch information
NHAS committed Oct 24, 2024
1 parent 36065f5 commit f5d2a6c
Showing 1 changed file with 130 additions and 54 deletions.
184 changes: 130 additions & 54 deletions internal/client/handlers/tun.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"runtime"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"

Expand Down Expand Up @@ -47,6 +48,53 @@ var (
nicIdsLck sync.Mutex
)

type stat struct {
NICID tcpip.NICID

closed bool

udp struct {
active atomic.Int64
failures atomic.Int64
}

tcp struct {
active atomic.Int64
failures atomic.Int64
}
}

func (s *stat) statsPrinter(l logger.Logger) {

pastTcpActive := s.tcp.active.Load()
pastTcpFail := s.tcp.failures.Load()

pastUdpActive := s.udp.active.Load()
pastUdpFail := s.udp.failures.Load()

for !s.closed {

currentTcpActive := s.tcp.active.Load()
currentTcpFail := s.tcp.failures.Load()

currentUdpActive := s.udp.active.Load()
currentUdpFail := s.udp.failures.Load()

if currentUdpActive != pastUdpActive || currentUdpFail != pastUdpFail || currentTcpActive != pastTcpActive || currentTcpFail != pastTcpFail {
l.Info("TUN NIC %d Stats: TCP streams: %d, TCP failures: %d, UDP connections: %d, UDP failures: %d", uint32(s.NICID), currentTcpActive, currentTcpFail, currentUdpActive, currentUdpFail)

pastTcpActive = currentTcpActive
pastTcpFail = currentTcpFail

pastUdpActive = currentUdpActive
pastUdpFail = currentUdpFail
}

time.Sleep(1 * time.Second)
}

}

func Tun(newChannel ssh.NewChannel, l logger.Logger) {

defer func() {
Expand Down Expand Up @@ -124,6 +172,8 @@ func Tun(newChannel ssh.NewChannel, l logger.Logger) {
}
defer tunnel.Close()

l.Info("New TUN NIC %d created", uint32(NICID))

// Create a new gvisor userland network stack.
ns := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
Expand All @@ -139,7 +189,7 @@ func Tun(newChannel ssh.NewChannel, l logger.Logger) {
})
defer ns.Close()

linkEP, err := NewSSHEndpoint(tunnel)
linkEP, err := NewSSHEndpoint(tunnel, l)
if err != nil {
l.Error("failed to create new SSH endpoint: %s", err)
return
Expand All @@ -157,11 +207,19 @@ func Tun(newChannel ssh.NewChannel, l logger.Logger) {
return
}

var tunStat stat
tunStat.NICID = NICID

go tunStat.statsPrinter(l)
defer func() {
tunStat.closed = true
}()

// Forward TCP connections
tcpHandler := tcp.NewForwarder(ns, 0, 14000, forwardTCP)
tcpHandler := tcp.NewForwarder(ns, 0, 14000, forwardTCP(&tunStat))

// Forward UDP connections
udpHandler := udp.NewForwarder(ns, forwardUDP)
udpHandler := udp.NewForwarder(ns, forwardUDP(&tunStat))

// Register forwarders
ns.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpHandler.HandlePacket)
Expand Down Expand Up @@ -197,78 +255,89 @@ func Tun(newChannel ssh.NewChannel, l logger.Logger) {

ssh.DiscardRequests(req)

l.Info("Tunnel ended")

l.Info("TUN NIC %d ended", uint32(NICID))
}

func forwardUDP(request *udp.ForwarderRequest) {
func forwardUDP(tunstats *stat) func(request *udp.ForwarderRequest) {
return func(request *udp.ForwarderRequest) {
id := request.ID()

id := request.ID()
var wq waiter.Queue
ep, iperr := request.CreateEndpoint(&wq)
if iperr != nil {
tunstats.udp.failures.Add(1)

var wq waiter.Queue
ep, iperr := request.CreateEndpoint(&wq)
if iperr != nil {
log.Println("[+] failed to create endpoint for udp: ", iperr)
log.Println("[+] failed to create endpoint for udp: ", iperr)
return
}

return
}
p, _ := NewUDPProxy(&autoStoppingListener{underlying: gonet.NewUDPConn(&wq, ep)}, func() (net.Conn, error) {
return net.Dial("udp", fmt.Sprintf("%s:%d", id.LocalAddress, id.LocalPort))
})
go func() {

log.Printf("tun [+] %s -> %s:%d/udp\n", id.RemoteAddress, id.LocalAddress, id.LocalPort)
tunstats.udp.active.Add(1)
defer tunstats.udp.active.Add(-1)

p, _ := NewUDPProxy(&autoStoppingListener{underlying: gonet.NewUDPConn(&wq, ep)}, func() (net.Conn, error) {
return net.Dial("udp", fmt.Sprintf("%s:%d", id.LocalAddress, id.LocalPort))
})
go func() {
p.Run()
p.Run()

// note that at this point packets that are sent to the current forwarder session
// will be dropped. We will start processing the packets again when we get a new
// forwarder request.
ep.Close()
}()
// note that at this point packets that are sent to the current forwarder session
// will be dropped. We will start processing the packets again when we get a new
// forwarder request.
ep.Close()
}()
}

}

func forwardTCP(request *tcp.ForwarderRequest) {
func forwardTCP(tunstats *stat) func(request *tcp.ForwarderRequest) {
return func(request *tcp.ForwarderRequest) {
id := request.ID()

id := request.ID()
fwdDst := net.TCPAddr{
IP: net.ParseIP(id.LocalAddress.String()),
Port: int(id.LocalPort),
}

fwdDst := net.TCPAddr{
IP: net.ParseIP(id.LocalAddress.String()),
Port: int(id.LocalPort),
}
outbound, err := net.DialTimeout("tcp", fwdDst.String(), 5*time.Second)
if err != nil {

log.Printf("[+] %s -> %s:%d/tcp\n", id.RemoteAddress, id.LocalAddress, id.LocalPort)
tunstats.tcp.failures.Add(1)

outbound, err := net.DialTimeout("tcp", fwdDst.String(), 5*time.Second)
if err != nil {
log.Printf("failed to dial: %s:%d/tcp", id.LocalAddress, id.LocalPort)
request.Complete(true)
return
}
request.Complete(true)
return
}

var wq waiter.Queue
ep, errTcp := request.CreateEndpoint(&wq)
var wq waiter.Queue
ep, errTcp := request.CreateEndpoint(&wq)

request.Complete(false)
request.Complete(false)

if errTcp != nil {
// ErrConnectionRefused is a transient error
if _, ok := errTcp.(*tcpip.ErrConnectionRefused); !ok {
log.Printf("could not create endpoint: %s", errTcp)
}
tunstats.tcp.failures.Add(1)

if errTcp != nil {
// ErrConnectionRefused is a transient error
if _, ok := errTcp.(*tcpip.ErrConnectionRefused); !ok {
log.Printf("could not create endpoint: %s", errTcp)
return
}
return
}

remote := tcpproxy.DialProxy{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return outbound, nil
},
tunstats.tcp.active.Add(1)
defer tunstats.tcp.active.Add(-1)

remote := tcpproxy.DialProxy{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return outbound, nil
},
}
remote.HandleConn(gonet.NewTCPConn(&wq, ep))
}
remote.HandleConn(gonet.NewTCPConn(&wq, ep))
}

type SSHEndpoint struct {
l logger.Logger

dispatcher stack.NetworkDispatcher
tunnel ssh.Channel

Expand All @@ -284,10 +353,11 @@ type SSHEndpoint struct {
//go:linkname adjustWindow golang.org/x/crypto/ssh.(*channel).adjustWindow
func adjustWindow(c unsafe.Pointer, n uint32) error

func NewSSHEndpoint(dev ssh.Channel) (*SSHEndpoint, error) {
func NewSSHEndpoint(dev ssh.Channel, l logger.Logger) (*SSHEndpoint, error) {

r := &SSHEndpoint{
tunnel: dev,
l: l,
}

const bufferName = "pending"
Expand Down Expand Up @@ -430,7 +500,9 @@ func (m *SSHEndpoint) dispatchLoop() {

packet, err := m.ReadSSHPacket()
if err != nil {
log.Println("failed to read from tunnel: ", err)
if err != io.EOF {
m.l.Error("failed to read from tunnel: %s", err)
}
m.tunnel.Close()
return
}
Expand Down Expand Up @@ -512,7 +584,11 @@ func (m *SSHEndpoint) writePacket(pkt *stack.PacketBuffer) tcpip.Error {
packet = append(packet, pktBuf...)

if _, err := m.tunnel.Write(packet); err != nil {
log.Println("failed to write packet to tunnel: ", err)

if err != io.EOF {
m.l.Error("failed to write packet to tunnel: %s", err)
}

return &tcpip.ErrInvalidEndpointState{}
}

Expand Down

0 comments on commit f5d2a6c

Please sign in to comment.