From 7a133bc399d000757e07a6a6edb66640edcef774 Mon Sep 17 00:00:00 2001 From: ggmolly Date: Sat, 15 Jun 2024 14:20:02 +0200 Subject: [PATCH] update: TCP routine is now using go's `net` standard library for more portability --- connection/client.go | 17 +--- connection/server.go | 217 ++++++++++++++----------------------------- main.go | 1 - packets/handler.go | 4 +- 4 files changed, 76 insertions(+), 163 deletions(-) diff --git a/connection/client.go b/connection/client.go index 792e7ce..ce57375 100644 --- a/connection/client.go +++ b/connection/client.go @@ -5,7 +5,6 @@ import ( "fmt" "math/rand" "net" - "syscall" "time" "github.com/ggmolly/belfast/logger" @@ -18,13 +17,11 @@ var ( ) type Client struct { - SockAddr syscall.Sockaddr - ProxyFD int // only used in proxy strategy, contains the fd of the proxy client IP net.IP Port int - FD int State int PacketIndex int + Connection *net.Conn Commander *orm.Commander Buffer bytes.Buffer Server *Server @@ -104,19 +101,9 @@ func (client *Client) GetCommander(accountId uint32) error { return err } -func (client *Client) Kill() { - if err := syscall.EpollCtl(client.Server.EpollFD, syscall.EPOLL_CTL_DEL, client.FD, nil); err != nil { - logger.LogEvent("Client", "Kill()", fmt.Sprintf("%s:%d -> %v", client.IP, client.Port, err), logger.LOG_LEVEL_ERROR) - } - if err := syscall.Close(client.FD); err != nil { - logger.LogEvent("Client", "Kill()", fmt.Sprintf("%s:%d -> %v", client.IP, client.Port, err), logger.LOG_LEVEL_ERROR) - return - } -} - // Sends the content of the buffer to the client via TCP func (client *Client) Flush() { - _, err := syscall.Write(client.FD, client.Buffer.Bytes()) + _, err := (*client.Connection).Write(client.Buffer.Bytes()) if err != nil { logger.LogEvent("Client", "Flush()", fmt.Sprintf("%s:%d -> %v", client.IP, client.Port, err), logger.LOG_LEVEL_ERROR) } diff --git a/connection/server.go b/connection/server.go index 3540ca2..04f51f5 100644 --- a/connection/server.go +++ b/connection/server.go @@ -3,11 +3,11 @@ package connection import ( "bytes" "fmt" - "log" + "io" "net" "os" "reflect" - "syscall" + "sync" "github.com/ggmolly/belfast/debug" "github.com/ggmolly/belfast/logger" @@ -16,14 +16,13 @@ import ( "google.golang.org/protobuf/proto" ) -type ServerDispatcher func(*[]byte, *Client) +type ServerDispatcher func(*[]byte, *Client, int) type Server struct { BindAddress string Port int SocketFD int EpollFD int - Clients map[int]*Client Dispatcher ServerDispatcher rooms map[uint32][]*Client Region string @@ -33,172 +32,102 @@ var ( BelfastInstance *Server ) -func (server *Server) GetClient(fd int) (*Client, error) { +func (server *Server) GetClient(conn *net.Conn) (*Client, error) { var client Client var err error - client.SockAddr, err = syscall.Getpeername(fd) - if err != nil { - return &client, err - } - client.IP = client.SockAddr.(*syscall.SockaddrInet4).Addr[:] - client.Port = client.SockAddr.(*syscall.SockaddrInet4).Port - client.FD = fd + client.IP = (*conn).RemoteAddr().(*net.TCPAddr).IP + client.Port = (*conn).RemoteAddr().(*net.TCPAddr).Port + client.Connection = conn client.Server = server - return &client, nil -} - -func (server *Server) GetConnectedClient(fd int) (*Client, error) { - if client, ok := server.Clients[fd]; ok { - return client, nil - } - return nil, fmt.Errorf("client not found") + return &client, err } func (server *Server) AddClient(client *Client) { - logger.LogEvent("Server", "hewwo", fmt.Sprintf("new connection from %s:%d (fd=%d)", client.IP, client.Port, client.FD), logger.LOG_LEVEL_DEBUG) - server.Clients[client.FD] = client + logger.LogEvent("Server", "hewwo", fmt.Sprintf("new connection from %s:%d", client.IP, client.Port), logger.LOG_LEVEL_DEBUG) } func (server *Server) RemoveClient(client *Client) { - logger.LogEvent("Server", "cya", fmt.Sprintf("%s:%d (fd=%d)", client.IP, client.Port, client.FD), logger.LOG_LEVEL_DEBUG) - client.Kill() - delete(server.Clients, client.FD) + logger.LogEvent("Server", "cya", fmt.Sprintf("%s:%d", client.IP, client.Port), logger.LOG_LEVEL_DEBUG) + (*client.Connection).Close() } -func (server *Server) Run() { - var err error - BelfastInstance = server - if server.SocketFD, err = syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM|syscall.O_NONBLOCK, 0); err != nil { - log.Fatalf("failed to create socket : %v", err) - } - defer syscall.Close(server.SocketFD) - logger.LogEvent("Server", "Listen", fmt.Sprintf("Listening on %s:%d", server.BindAddress, server.Port), logger.LOG_LEVEL_AUTO) - - if err = syscall.SetsockoptInt(server.SocketFD, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil { - log.Fatalf("setsockopt error: %v", err) - } +func handleConnection(conn net.Conn, wg *sync.WaitGroup, server *Server) { + defer wg.Done() + defer conn.Close() - if err = syscall.SetNonblock(server.SocketFD, true); err != nil { - log.Fatalf("setnonblock error: %v", err) - } - - var ip [4]byte - copy(ip[:], net.ParseIP(server.BindAddress).To4()) - addr := syscall.SockaddrInet4{ - Port: server.Port, - Addr: ip, - } + // Add the client to the list + client, err := server.GetClient(&conn) - if err = syscall.Bind(server.SocketFD, &addr); err != nil { - log.Fatalf("bind error: %v", err) + if err != nil { + logger.LogEvent("Server", "Handler", fmt.Sprintf("client %s -- error: %v", conn.RemoteAddr(), err), logger.LOG_LEVEL_ERROR) + conn.Close() + server.RemoveClient(client) + return } - if err = syscall.Listen(server.SocketFD, syscall.SOMAXCONN); err != nil { - log.Fatalf("listen error: %v", err) + if !client.IP.IsPrivate() { + logger.LogEvent("Server", "Handler", fmt.Sprintf("client %s -- not in a private range", conn.RemoteAddr()), logger.LOG_LEVEL_ERROR) + conn.Close() + server.RemoveClient(client) + return } - if server.EpollFD, err = syscall.EpollCreate1(0); err != nil { - panic(err) - } + server.AddClient(client) - // Prepare epoll (I/O multiplexing) - var event syscall.EpollEvent - event.Events = syscall.EPOLLIN - event.Fd = int32(server.SocketFD) - if err = syscall.EpollCtl(server.EpollFD, syscall.EPOLL_CTL_ADD, server.SocketFD, &event); err != nil { - panic(err) - } + // Buffer for unpacking received data + totalBytes := 0 + packerBuffer := make([]byte, 16384) - // Create epoll event buffer - var events [128]syscall.EpollEvent + // Temporary buffer for reading + buffer := make([]byte, 1024) for { - // Check for events - var nevents int - if nevents, err = syscall.EpollWait(server.EpollFD, events[:], -1); err != nil { - if err == syscall.EINTR { - continue - } - panic(err) - } - var treatedEvents int - for ev := 0; ev < nevents; ev++ { - treatedEvents++ - if int(events[ev].Fd) == server.SocketFD { - // Accept new connections - connFd, _, err := syscall.Accept(server.SocketFD) - if err != nil { - logger.LogEvent("Server", "Accept", fmt.Sprintf("accept error: %v", err), logger.LOG_LEVEL_ERROR) - continue - } - - // Make the connection non-blocking - if err = syscall.SetNonblock(connFd, true); err != nil { - logger.LogEvent("Server", "SetNonblock", fmt.Sprintf("setnonblock error: %v", err), logger.LOG_LEVEL_ERROR) - syscall.Close(connFd) - continue - } - event.Events = syscall.EPOLLIN - event.Fd = int32(connFd) - if err := syscall.EpollCtl(server.EpollFD, syscall.EPOLL_CTL_ADD, connFd, &event); err != nil { - logger.LogEvent("Server", "EpollCtl", fmt.Sprintf("epoll_ctl error: %v", err), logger.LOG_LEVEL_ERROR) - syscall.Close(connFd) - continue - } - // Add the client to the list - client, err := server.GetClient(connFd) - if err != nil { - logger.LogEvent("Server", "GetClient", fmt.Sprintf("getclient error: %v", err), logger.LOG_LEVEL_ERROR) - continue - } - if !client.IP.IsPrivate() { - logger.LogEvent("Server", "GetClient", fmt.Sprintf("client %s:%d is not in a private range", client.IP, client.Port), logger.LOG_LEVEL_ERROR) - syscall.EpollCtl(server.EpollFD, syscall.EPOLL_CTL_DEL, connFd, &event) - syscall.Close(connFd) - continue - } - server.AddClient(client) - } else { - // Handle data - var buffer = make([]byte, 8192) - clientFd := int(events[ev].Fd) - client, err := server.GetConnectedClient(clientFd) - if err != nil { - logger.LogEvent("Server", "GetConnectedClient", fmt.Sprintf("%v", err), logger.LOG_LEVEL_ERROR) - server.RemoveClient(client) - continue - } - n, err := syscall.Read(clientFd, buffer) - if err != nil { // the client probably closed the connection - logger.LogEvent("Server", "Read", fmt.Sprintf("%v", err), logger.LOG_LEVEL_ERROR) - } else if n > 0 { - buffer = buffer[:n] - if len(buffer) >= 7 { - server.Dispatcher(&buffer, client) - } - } else { - // EOF, delete from epoll - server.RemoveClient(client) - } - } + n, err := conn.Read(buffer) + if err == io.EOF || err != nil { + conn.Close() + server.RemoveClient(client) + break } - if treatedEvents != nevents { - panic(fmt.Errorf("treated %d events out of %d", treatedEvents, nevents)) + // copy the buffer to the packerBuffer + copy(packerBuffer[totalBytes:], buffer[:n]) + totalBytes += n + + // To know if we have atleast a full message, check first 2 bytes of the packerBuffer + // these two bytes are the length of a message + size := int(packerBuffer[0])<<8 | int(packerBuffer[1]) + 2 // take into account the 2 bytes for the size + if totalBytes >= size { + // We have a full message, slice it and send it to the dispatcher + message := packerBuffer[:size] + server.Dispatcher(&message, client, size) + // Remove the message from the packerBuffer and shift the rest of the buffer + packerBuffer = packerBuffer[size:] + totalBytes -= size + } else { + // Otherwise, wait for more data + continue } } } -func (server *Server) Kill() { - logger.LogEvent("Server", "Kill()", "Closing server", logger.LOG_LEVEL_INFO) - if err := syscall.Close(server.SocketFD); err != nil { - logger.LogEvent("Server", "Kill()", fmt.Sprintf("error closing socket: %v", err), logger.LOG_LEVEL_ERROR) - } - if err := syscall.Close(server.EpollFD); err != nil { - logger.LogEvent("Server", "Kill()", fmt.Sprintf("error closing epoll: %v", err), logger.LOG_LEVEL_ERROR) +func (server *Server) Run() { + listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", server.BindAddress, server.Port)) + if err != nil { + logger.LogEvent("Server", "Run", fmt.Sprintf("error listening: %v", err), logger.LOG_LEVEL_ERROR) + return } - // Close all clients - for _, client := range server.Clients { - client.Kill() + defer listener.Close() + logger.LogEvent("Server", "Run", fmt.Sprintf("listening on %s:%d", server.BindAddress, server.Port), logger.LOG_LEVEL_INFO) + + var wg sync.WaitGroup + for { + conn, err := listener.Accept() + if err != nil { + logger.LogEvent("Server", "Run", fmt.Sprintf("error accepting: %v", err), logger.LOG_LEVEL_ERROR) + continue + } + wg.Add(1) + go handleConnection(conn, &wg, server) } + wg.Wait() } func NewServer(bindAddress string, port int, dispatcher ServerDispatcher) *Server { @@ -206,7 +135,6 @@ func NewServer(bindAddress string, port int, dispatcher ServerDispatcher) *Serve BindAddress: bindAddress, Port: port, Dispatcher: dispatcher, - Clients: make(map[int]*Client), Region: os.Getenv("AL_REGION"), rooms: make(map[uint32][]*Client), } @@ -251,7 +179,6 @@ func (server *Server) SendMessage(sender *Client, message orm.Message) { } } -// TODO: Expose publicly these functions, and delete the package `packets` func GeneratePacketHeader(packetId int, payload *[]byte, packetIndex int) []byte { var buffer bytes.Buffer diff --git a/main.go b/main.go index bd5b6fc..37a8d1d 100644 --- a/main.go +++ b/main.go @@ -49,7 +49,6 @@ func main() { go func() { <-sigChannel fmt.Printf("\r") // trick to avoid ^C in the terminal, could use low-level RawMode() but why bother - server.Kill() os.Exit(0) }() // Prepare web server diff --git a/packets/handler.go b/packets/handler.go index 44efa01..4f0cefa 100644 --- a/packets/handler.go +++ b/packets/handler.go @@ -63,9 +63,9 @@ func RegisterLocalizedPacketHandler(packetId int, localizedHandler LocalizedHand } // Find each packet in the buffer and dispatch it to the appropriate handler. -func Dispatch(buffer *[]byte, client *connection.Client) { +func Dispatch(buffer *[]byte, client *connection.Client, n int) { offset := 0 - for offset < len(*buffer) { + for offset < n { packetId := GetPacketId(offset, buffer) packetSize := GetPacketSize(offset, buffer) + 2 client.PacketIndex = GetPacketIndex(offset, buffer)