Skip to content
This repository has been archived by the owner on Nov 19, 2024. It is now read-only.

Commit

Permalink
update: TCP routine is now using go's net standard library for more…
Browse files Browse the repository at this point in the history
… portability
  • Loading branch information
ggmolly committed Jun 15, 2024
1 parent 66352a7 commit 7a133bc
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 163 deletions.
17 changes: 2 additions & 15 deletions connection/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"math/rand"
"net"
"syscall"
"time"

"github.com/ggmolly/belfast/logger"
Expand All @@ -18,13 +17,11 @@ var (
)

type Client struct {
SockAddr syscall.Sockaddr
ProxyFD int // only used in proxy strategy, contains the fd of the proxy client
IP net.IP
Port int
FD int
State int
PacketIndex int
Connection *net.Conn
Commander *orm.Commander
Buffer bytes.Buffer
Server *Server
Expand Down Expand Up @@ -104,19 +101,9 @@ func (client *Client) GetCommander(accountId uint32) error {
return err
}

func (client *Client) Kill() {
if err := syscall.EpollCtl(client.Server.EpollFD, syscall.EPOLL_CTL_DEL, client.FD, nil); err != nil {
logger.LogEvent("Client", "Kill()", fmt.Sprintf("%s:%d -> %v", client.IP, client.Port, err), logger.LOG_LEVEL_ERROR)
}
if err := syscall.Close(client.FD); err != nil {
logger.LogEvent("Client", "Kill()", fmt.Sprintf("%s:%d -> %v", client.IP, client.Port, err), logger.LOG_LEVEL_ERROR)
return
}
}

// Sends the content of the buffer to the client via TCP
func (client *Client) Flush() {
_, err := syscall.Write(client.FD, client.Buffer.Bytes())
_, err := (*client.Connection).Write(client.Buffer.Bytes())
if err != nil {
logger.LogEvent("Client", "Flush()", fmt.Sprintf("%s:%d -> %v", client.IP, client.Port, err), logger.LOG_LEVEL_ERROR)
}
Expand Down
217 changes: 72 additions & 145 deletions connection/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ package connection
import (
"bytes"
"fmt"
"log"
"io"
"net"
"os"
"reflect"
"syscall"
"sync"

"github.com/ggmolly/belfast/debug"
"github.com/ggmolly/belfast/logger"
Expand All @@ -16,14 +16,13 @@ import (
"google.golang.org/protobuf/proto"
)

type ServerDispatcher func(*[]byte, *Client)
type ServerDispatcher func(*[]byte, *Client, int)

type Server struct {
BindAddress string
Port int
SocketFD int
EpollFD int
Clients map[int]*Client
Dispatcher ServerDispatcher
rooms map[uint32][]*Client
Region string
Expand All @@ -33,180 +32,109 @@ var (
BelfastInstance *Server
)

func (server *Server) GetClient(fd int) (*Client, error) {
func (server *Server) GetClient(conn *net.Conn) (*Client, error) {
var client Client
var err error
client.SockAddr, err = syscall.Getpeername(fd)
if err != nil {
return &client, err
}
client.IP = client.SockAddr.(*syscall.SockaddrInet4).Addr[:]
client.Port = client.SockAddr.(*syscall.SockaddrInet4).Port
client.FD = fd
client.IP = (*conn).RemoteAddr().(*net.TCPAddr).IP
client.Port = (*conn).RemoteAddr().(*net.TCPAddr).Port
client.Connection = conn
client.Server = server
return &client, nil
}

func (server *Server) GetConnectedClient(fd int) (*Client, error) {
if client, ok := server.Clients[fd]; ok {
return client, nil
}
return nil, fmt.Errorf("client not found")
return &client, err
}

func (server *Server) AddClient(client *Client) {
logger.LogEvent("Server", "hewwo", fmt.Sprintf("new connection from %s:%d (fd=%d)", client.IP, client.Port, client.FD), logger.LOG_LEVEL_DEBUG)
server.Clients[client.FD] = client
logger.LogEvent("Server", "hewwo", fmt.Sprintf("new connection from %s:%d", client.IP, client.Port), logger.LOG_LEVEL_DEBUG)
}

func (server *Server) RemoveClient(client *Client) {
logger.LogEvent("Server", "cya", fmt.Sprintf("%s:%d (fd=%d)", client.IP, client.Port, client.FD), logger.LOG_LEVEL_DEBUG)
client.Kill()
delete(server.Clients, client.FD)
logger.LogEvent("Server", "cya", fmt.Sprintf("%s:%d", client.IP, client.Port), logger.LOG_LEVEL_DEBUG)
(*client.Connection).Close()
}

func (server *Server) Run() {
var err error
BelfastInstance = server
if server.SocketFD, err = syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM|syscall.O_NONBLOCK, 0); err != nil {
log.Fatalf("failed to create socket : %v", err)
}
defer syscall.Close(server.SocketFD)
logger.LogEvent("Server", "Listen", fmt.Sprintf("Listening on %s:%d", server.BindAddress, server.Port), logger.LOG_LEVEL_AUTO)

if err = syscall.SetsockoptInt(server.SocketFD, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
log.Fatalf("setsockopt error: %v", err)
}
func handleConnection(conn net.Conn, wg *sync.WaitGroup, server *Server) {
defer wg.Done()
defer conn.Close()

if err = syscall.SetNonblock(server.SocketFD, true); err != nil {
log.Fatalf("setnonblock error: %v", err)
}

var ip [4]byte
copy(ip[:], net.ParseIP(server.BindAddress).To4())
addr := syscall.SockaddrInet4{
Port: server.Port,
Addr: ip,
}
// Add the client to the list
client, err := server.GetClient(&conn)

if err = syscall.Bind(server.SocketFD, &addr); err != nil {
log.Fatalf("bind error: %v", err)
if err != nil {
logger.LogEvent("Server", "Handler", fmt.Sprintf("client %s -- error: %v", conn.RemoteAddr(), err), logger.LOG_LEVEL_ERROR)
conn.Close()
server.RemoveClient(client)
return
}

if err = syscall.Listen(server.SocketFD, syscall.SOMAXCONN); err != nil {
log.Fatalf("listen error: %v", err)
if !client.IP.IsPrivate() {
logger.LogEvent("Server", "Handler", fmt.Sprintf("client %s -- not in a private range", conn.RemoteAddr()), logger.LOG_LEVEL_ERROR)
conn.Close()
server.RemoveClient(client)
return
}

if server.EpollFD, err = syscall.EpollCreate1(0); err != nil {
panic(err)
}
server.AddClient(client)

// Prepare epoll (I/O multiplexing)
var event syscall.EpollEvent
event.Events = syscall.EPOLLIN
event.Fd = int32(server.SocketFD)
if err = syscall.EpollCtl(server.EpollFD, syscall.EPOLL_CTL_ADD, server.SocketFD, &event); err != nil {
panic(err)
}
// Buffer for unpacking received data
totalBytes := 0
packerBuffer := make([]byte, 16384)

// Create epoll event buffer
var events [128]syscall.EpollEvent
// Temporary buffer for reading
buffer := make([]byte, 1024)
for {
// Check for events
var nevents int
if nevents, err = syscall.EpollWait(server.EpollFD, events[:], -1); err != nil {
if err == syscall.EINTR {
continue
}
panic(err)
}
var treatedEvents int
for ev := 0; ev < nevents; ev++ {
treatedEvents++
if int(events[ev].Fd) == server.SocketFD {
// Accept new connections
connFd, _, err := syscall.Accept(server.SocketFD)
if err != nil {
logger.LogEvent("Server", "Accept", fmt.Sprintf("accept error: %v", err), logger.LOG_LEVEL_ERROR)
continue
}

// Make the connection non-blocking
if err = syscall.SetNonblock(connFd, true); err != nil {
logger.LogEvent("Server", "SetNonblock", fmt.Sprintf("setnonblock error: %v", err), logger.LOG_LEVEL_ERROR)
syscall.Close(connFd)
continue
}
event.Events = syscall.EPOLLIN
event.Fd = int32(connFd)
if err := syscall.EpollCtl(server.EpollFD, syscall.EPOLL_CTL_ADD, connFd, &event); err != nil {
logger.LogEvent("Server", "EpollCtl", fmt.Sprintf("epoll_ctl error: %v", err), logger.LOG_LEVEL_ERROR)
syscall.Close(connFd)
continue
}
// Add the client to the list
client, err := server.GetClient(connFd)
if err != nil {
logger.LogEvent("Server", "GetClient", fmt.Sprintf("getclient error: %v", err), logger.LOG_LEVEL_ERROR)
continue
}
if !client.IP.IsPrivate() {
logger.LogEvent("Server", "GetClient", fmt.Sprintf("client %s:%d is not in a private range", client.IP, client.Port), logger.LOG_LEVEL_ERROR)
syscall.EpollCtl(server.EpollFD, syscall.EPOLL_CTL_DEL, connFd, &event)
syscall.Close(connFd)
continue
}
server.AddClient(client)
} else {
// Handle data
var buffer = make([]byte, 8192)
clientFd := int(events[ev].Fd)
client, err := server.GetConnectedClient(clientFd)
if err != nil {
logger.LogEvent("Server", "GetConnectedClient", fmt.Sprintf("%v", err), logger.LOG_LEVEL_ERROR)
server.RemoveClient(client)
continue
}
n, err := syscall.Read(clientFd, buffer)
if err != nil { // the client probably closed the connection
logger.LogEvent("Server", "Read", fmt.Sprintf("%v", err), logger.LOG_LEVEL_ERROR)
} else if n > 0 {
buffer = buffer[:n]
if len(buffer) >= 7 {
server.Dispatcher(&buffer, client)
}
} else {
// EOF, delete from epoll
server.RemoveClient(client)
}
}
n, err := conn.Read(buffer)
if err == io.EOF || err != nil {
conn.Close()
server.RemoveClient(client)
break
}
if treatedEvents != nevents {
panic(fmt.Errorf("treated %d events out of %d", treatedEvents, nevents))
// copy the buffer to the packerBuffer
copy(packerBuffer[totalBytes:], buffer[:n])
totalBytes += n

// To know if we have atleast a full message, check first 2 bytes of the packerBuffer
// these two bytes are the length of a message
size := int(packerBuffer[0])<<8 | int(packerBuffer[1]) + 2 // take into account the 2 bytes for the size
if totalBytes >= size {
// We have a full message, slice it and send it to the dispatcher
message := packerBuffer[:size]
server.Dispatcher(&message, client, size)
// Remove the message from the packerBuffer and shift the rest of the buffer
packerBuffer = packerBuffer[size:]
totalBytes -= size
} else {
// Otherwise, wait for more data
continue
}
}
}

func (server *Server) Kill() {
logger.LogEvent("Server", "Kill()", "Closing server", logger.LOG_LEVEL_INFO)
if err := syscall.Close(server.SocketFD); err != nil {
logger.LogEvent("Server", "Kill()", fmt.Sprintf("error closing socket: %v", err), logger.LOG_LEVEL_ERROR)
}
if err := syscall.Close(server.EpollFD); err != nil {
logger.LogEvent("Server", "Kill()", fmt.Sprintf("error closing epoll: %v", err), logger.LOG_LEVEL_ERROR)
func (server *Server) Run() {
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", server.BindAddress, server.Port))
if err != nil {
logger.LogEvent("Server", "Run", fmt.Sprintf("error listening: %v", err), logger.LOG_LEVEL_ERROR)
return
}
// Close all clients
for _, client := range server.Clients {
client.Kill()
defer listener.Close()
logger.LogEvent("Server", "Run", fmt.Sprintf("listening on %s:%d", server.BindAddress, server.Port), logger.LOG_LEVEL_INFO)

var wg sync.WaitGroup
for {
conn, err := listener.Accept()
if err != nil {
logger.LogEvent("Server", "Run", fmt.Sprintf("error accepting: %v", err), logger.LOG_LEVEL_ERROR)
continue
}
wg.Add(1)
go handleConnection(conn, &wg, server)
}
wg.Wait()
}

func NewServer(bindAddress string, port int, dispatcher ServerDispatcher) *Server {
return &Server{
BindAddress: bindAddress,
Port: port,
Dispatcher: dispatcher,
Clients: make(map[int]*Client),
Region: os.Getenv("AL_REGION"),
rooms: make(map[uint32][]*Client),
}
Expand Down Expand Up @@ -251,7 +179,6 @@ func (server *Server) SendMessage(sender *Client, message orm.Message) {
}
}

// TODO: Expose publicly these functions, and delete the package `packets`
func GeneratePacketHeader(packetId int, payload *[]byte, packetIndex int) []byte {
var buffer bytes.Buffer

Expand Down
1 change: 0 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ func main() {
go func() {
<-sigChannel
fmt.Printf("\r") // trick to avoid ^C in the terminal, could use low-level RawMode() but why bother
server.Kill()
os.Exit(0)
}()
// Prepare web server
Expand Down
4 changes: 2 additions & 2 deletions packets/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ func RegisterLocalizedPacketHandler(packetId int, localizedHandler LocalizedHand
}

// Find each packet in the buffer and dispatch it to the appropriate handler.
func Dispatch(buffer *[]byte, client *connection.Client) {
func Dispatch(buffer *[]byte, client *connection.Client, n int) {
offset := 0
for offset < len(*buffer) {
for offset < n {
packetId := GetPacketId(offset, buffer)
packetSize := GetPacketSize(offset, buffer) + 2
client.PacketIndex = GetPacketIndex(offset, buffer)
Expand Down

0 comments on commit 7a133bc

Please sign in to comment.