From e28784657ff7c4d427e79a7c60a10680032e2931 Mon Sep 17 00:00:00 2001 From: ggmolly Date: Sat, 15 Jun 2024 14:21:14 +0200 Subject: [PATCH 1/6] update: TCP routine is now using go's `net` standard library for more portability --- answer/servers.go | 4 +- connection/client.go | 17 +--- connection/server.go | 216 +++++++++++++++---------------------------- main.go | 2 - packets/handler.go | 4 +- 5 files changed, 77 insertions(+), 166 deletions(-) diff --git a/answer/servers.go b/answer/servers.go index 2c40381..64d0f70 100644 --- a/answer/servers.go +++ b/answer/servers.go @@ -3,7 +3,6 @@ package answer import ( "bytes" "encoding/json" - "syscall" "github.com/ggmolly/belfast/connection" "github.com/ggmolly/belfast/protobuf" @@ -36,8 +35,7 @@ func Forge_SC8239(buffer *[]byte, client *connection.Client) (int, int, error) { } answerBuffer.Write(jsonData) - // Write buffer to fd - n, err := syscall.Write(client.FD, answerBuffer.Bytes()) + n, err := (*client.Connection).Write(answerBuffer.Bytes()) if err != nil { return 0, packetId, err } diff --git a/connection/client.go b/connection/client.go index 792e7ce..ce57375 100644 --- a/connection/client.go +++ b/connection/client.go @@ -5,7 +5,6 @@ import ( "fmt" "math/rand" "net" - "syscall" "time" "github.com/ggmolly/belfast/logger" @@ -18,13 +17,11 @@ var ( ) type Client struct { - SockAddr syscall.Sockaddr - ProxyFD int // only used in proxy strategy, contains the fd of the proxy client IP net.IP Port int - FD int State int PacketIndex int + Connection *net.Conn Commander *orm.Commander Buffer bytes.Buffer Server *Server @@ -104,19 +101,9 @@ func (client *Client) GetCommander(accountId uint32) error { return err } -func (client *Client) Kill() { - if err := syscall.EpollCtl(client.Server.EpollFD, syscall.EPOLL_CTL_DEL, client.FD, nil); err != nil { - logger.LogEvent("Client", "Kill()", fmt.Sprintf("%s:%d -> %v", client.IP, client.Port, err), logger.LOG_LEVEL_ERROR) - } - if err := syscall.Close(client.FD); err != nil { - logger.LogEvent("Client", "Kill()", fmt.Sprintf("%s:%d -> %v", client.IP, client.Port, err), logger.LOG_LEVEL_ERROR) - return - } -} - // Sends the content of the buffer to the client via TCP func (client *Client) Flush() { - _, err := syscall.Write(client.FD, client.Buffer.Bytes()) + _, err := (*client.Connection).Write(client.Buffer.Bytes()) if err != nil { logger.LogEvent("Client", "Flush()", fmt.Sprintf("%s:%d -> %v", client.IP, client.Port, err), logger.LOG_LEVEL_ERROR) } diff --git a/connection/server.go b/connection/server.go index dfdc722..04f51f5 100644 --- a/connection/server.go +++ b/connection/server.go @@ -3,10 +3,11 @@ package connection import ( "bytes" "fmt" + "io" "net" "os" "reflect" - "syscall" + "sync" "github.com/ggmolly/belfast/debug" "github.com/ggmolly/belfast/logger" @@ -15,14 +16,13 @@ import ( "google.golang.org/protobuf/proto" ) -type ServerDispatcher func(*[]byte, *Client) +type ServerDispatcher func(*[]byte, *Client, int) type Server struct { BindAddress string Port int SocketFD int EpollFD int - Clients map[int]*Client Dispatcher ServerDispatcher rooms map[uint32][]*Client Region string @@ -32,172 +32,102 @@ var ( BelfastInstance *Server ) -func (server *Server) GetClient(fd int) (*Client, error) { +func (server *Server) GetClient(conn *net.Conn) (*Client, error) { var client Client var err error - client.SockAddr, err = syscall.Getpeername(fd) - if err != nil { - return &client, err - } - client.IP = client.SockAddr.(*syscall.SockaddrInet4).Addr[:] - client.Port = client.SockAddr.(*syscall.SockaddrInet4).Port - client.FD = fd + client.IP = (*conn).RemoteAddr().(*net.TCPAddr).IP + client.Port = (*conn).RemoteAddr().(*net.TCPAddr).Port + client.Connection = conn client.Server = server - return &client, nil -} - -func (server *Server) GetConnectedClient(fd int) (*Client, error) { - if client, ok := server.Clients[fd]; ok { - return client, nil - } - return nil, fmt.Errorf("client not found") + return &client, err } func (server *Server) AddClient(client *Client) { - logger.LogEvent("Server", "hewwo", fmt.Sprintf("new connection from %s:%d (fd=%d)", client.IP, client.Port, client.FD), logger.LOG_LEVEL_DEBUG) - server.Clients[client.FD] = client + logger.LogEvent("Server", "hewwo", fmt.Sprintf("new connection from %s:%d", client.IP, client.Port), logger.LOG_LEVEL_DEBUG) } func (server *Server) RemoveClient(client *Client) { - logger.LogEvent("Server", "cya", fmt.Sprintf("%s:%d (fd=%d)", client.IP, client.Port, client.FD), logger.LOG_LEVEL_DEBUG) - client.Kill() - delete(server.Clients, client.FD) + logger.LogEvent("Server", "cya", fmt.Sprintf("%s:%d", client.IP, client.Port), logger.LOG_LEVEL_DEBUG) + (*client.Connection).Close() } -func (server *Server) Run() error { - var err error - BelfastInstance = server - if server.SocketFD, err = syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM|syscall.O_NONBLOCK, 0); err != nil { - return fmt.Errorf("failed to create socket : %v", err) - } - defer syscall.Close(server.SocketFD) - logger.LogEvent("Server", "Listen", fmt.Sprintf("Listening on %s:%d", server.BindAddress, server.Port), logger.LOG_LEVEL_AUTO) - - if err = syscall.SetsockoptInt(server.SocketFD, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil { - return fmt.Errorf("setsockopt error: %v", err) - } - - if err = syscall.SetNonblock(server.SocketFD, true); err != nil { - return fmt.Errorf("setnonblock error: %v", err) - } +func handleConnection(conn net.Conn, wg *sync.WaitGroup, server *Server) { + defer wg.Done() + defer conn.Close() - var ip [4]byte - copy(ip[:], net.ParseIP(server.BindAddress).To4()) - addr := syscall.SockaddrInet4{ - Port: server.Port, - Addr: ip, - } + // Add the client to the list + client, err := server.GetClient(&conn) - if err = syscall.Bind(server.SocketFD, &addr); err != nil { - return fmt.Errorf("bind error: %v", err) + if err != nil { + logger.LogEvent("Server", "Handler", fmt.Sprintf("client %s -- error: %v", conn.RemoteAddr(), err), logger.LOG_LEVEL_ERROR) + conn.Close() + server.RemoveClient(client) + return } - if err = syscall.Listen(server.SocketFD, syscall.SOMAXCONN); err != nil { - return fmt.Errorf("listen error: %v", err) + if !client.IP.IsPrivate() { + logger.LogEvent("Server", "Handler", fmt.Sprintf("client %s -- not in a private range", conn.RemoteAddr()), logger.LOG_LEVEL_ERROR) + conn.Close() + server.RemoveClient(client) + return } - if server.EpollFD, err = syscall.EpollCreate1(0); err != nil { - panic(err) - } + server.AddClient(client) - // Prepare epoll (I/O multiplexing) - var event syscall.EpollEvent - event.Events = syscall.EPOLLIN - event.Fd = int32(server.SocketFD) - if err = syscall.EpollCtl(server.EpollFD, syscall.EPOLL_CTL_ADD, server.SocketFD, &event); err != nil { - panic(err) - } + // Buffer for unpacking received data + totalBytes := 0 + packerBuffer := make([]byte, 16384) - // Create epoll event buffer - var events [128]syscall.EpollEvent + // Temporary buffer for reading + buffer := make([]byte, 1024) for { - // Check for events - var nevents int - if nevents, err = syscall.EpollWait(server.EpollFD, events[:], -1); err != nil { - if err == syscall.EINTR { - continue - } - panic(err) - } - var treatedEvents int - for ev := 0; ev < nevents; ev++ { - treatedEvents++ - if int(events[ev].Fd) == server.SocketFD { - // Accept new connections - connFd, _, err := syscall.Accept(server.SocketFD) - if err != nil { - logger.LogEvent("Server", "Accept", fmt.Sprintf("accept error: %v", err), logger.LOG_LEVEL_ERROR) - continue - } - - // Make the connection non-blocking - if err = syscall.SetNonblock(connFd, true); err != nil { - logger.LogEvent("Server", "SetNonblock", fmt.Sprintf("setnonblock error: %v", err), logger.LOG_LEVEL_ERROR) - syscall.Close(connFd) - continue - } - event.Events = syscall.EPOLLIN - event.Fd = int32(connFd) - if err := syscall.EpollCtl(server.EpollFD, syscall.EPOLL_CTL_ADD, connFd, &event); err != nil { - logger.LogEvent("Server", "EpollCtl", fmt.Sprintf("epoll_ctl error: %v", err), logger.LOG_LEVEL_ERROR) - syscall.Close(connFd) - continue - } - // Add the client to the list - client, err := server.GetClient(connFd) - if err != nil { - logger.LogEvent("Server", "GetClient", fmt.Sprintf("getclient error: %v", err), logger.LOG_LEVEL_ERROR) - continue - } - if !client.IP.IsPrivate() { - logger.LogEvent("Server", "GetClient", fmt.Sprintf("client %s:%d is not in a private range", client.IP, client.Port), logger.LOG_LEVEL_ERROR) - syscall.EpollCtl(server.EpollFD, syscall.EPOLL_CTL_DEL, connFd, &event) - syscall.Close(connFd) - continue - } - server.AddClient(client) - } else { - // Handle data - var buffer = make([]byte, 8192) - clientFd := int(events[ev].Fd) - client, err := server.GetConnectedClient(clientFd) - if err != nil { - logger.LogEvent("Server", "GetConnectedClient", fmt.Sprintf("%v", err), logger.LOG_LEVEL_ERROR) - server.RemoveClient(client) - continue - } - n, err := syscall.Read(clientFd, buffer) - if err != nil { // the client probably closed the connection - logger.LogEvent("Server", "Read", fmt.Sprintf("%v", err), logger.LOG_LEVEL_ERROR) - } else if n > 0 { - buffer = buffer[:n] - if len(buffer) >= 7 { - server.Dispatcher(&buffer, client) - } - } else { - // EOF, delete from epoll - server.RemoveClient(client) - } - } + n, err := conn.Read(buffer) + if err == io.EOF || err != nil { + conn.Close() + server.RemoveClient(client) + break } - if treatedEvents != nevents { - panic(fmt.Errorf("treated %d events out of %d", treatedEvents, nevents)) + // copy the buffer to the packerBuffer + copy(packerBuffer[totalBytes:], buffer[:n]) + totalBytes += n + + // To know if we have atleast a full message, check first 2 bytes of the packerBuffer + // these two bytes are the length of a message + size := int(packerBuffer[0])<<8 | int(packerBuffer[1]) + 2 // take into account the 2 bytes for the size + if totalBytes >= size { + // We have a full message, slice it and send it to the dispatcher + message := packerBuffer[:size] + server.Dispatcher(&message, client, size) + // Remove the message from the packerBuffer and shift the rest of the buffer + packerBuffer = packerBuffer[size:] + totalBytes -= size + } else { + // Otherwise, wait for more data + continue } } } -func (server *Server) Kill() { - logger.LogEvent("Server", "Kill()", "Closing server", logger.LOG_LEVEL_INFO) - if err := syscall.Close(server.SocketFD); err != nil { - logger.LogEvent("Server", "Kill()", fmt.Sprintf("error closing socket: %v", err), logger.LOG_LEVEL_ERROR) - } - if err := syscall.Close(server.EpollFD); err != nil { - logger.LogEvent("Server", "Kill()", fmt.Sprintf("error closing epoll: %v", err), logger.LOG_LEVEL_ERROR) +func (server *Server) Run() { + listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", server.BindAddress, server.Port)) + if err != nil { + logger.LogEvent("Server", "Run", fmt.Sprintf("error listening: %v", err), logger.LOG_LEVEL_ERROR) + return } - // Close all clients - for _, client := range server.Clients { - client.Kill() + defer listener.Close() + logger.LogEvent("Server", "Run", fmt.Sprintf("listening on %s:%d", server.BindAddress, server.Port), logger.LOG_LEVEL_INFO) + + var wg sync.WaitGroup + for { + conn, err := listener.Accept() + if err != nil { + logger.LogEvent("Server", "Run", fmt.Sprintf("error accepting: %v", err), logger.LOG_LEVEL_ERROR) + continue + } + wg.Add(1) + go handleConnection(conn, &wg, server) } + wg.Wait() } func NewServer(bindAddress string, port int, dispatcher ServerDispatcher) *Server { @@ -205,7 +135,6 @@ func NewServer(bindAddress string, port int, dispatcher ServerDispatcher) *Serve BindAddress: bindAddress, Port: port, Dispatcher: dispatcher, - Clients: make(map[int]*Client), Region: os.Getenv("AL_REGION"), rooms: make(map[uint32][]*Client), } @@ -250,7 +179,6 @@ func (server *Server) SendMessage(sender *Client, message orm.Message) { } } -// TODO: Expose publicly these functions, and delete the package `packets` func GeneratePacketHeader(packetId int, payload *[]byte, packetIndex int) []byte { var buffer bytes.Buffer diff --git a/main.go b/main.go index 9016ef9..ba6ad07 100644 --- a/main.go +++ b/main.go @@ -69,8 +69,6 @@ func main() { go func() { <-sigChannel fmt.Printf("\r") // trick to avoid ^C in the terminal, could use low-level RawMode() but why bother - server.Kill() - tty.Close() os.Exit(0) }() // Prepare web server diff --git a/packets/handler.go b/packets/handler.go index 44efa01..4f0cefa 100644 --- a/packets/handler.go +++ b/packets/handler.go @@ -63,9 +63,9 @@ func RegisterLocalizedPacketHandler(packetId int, localizedHandler LocalizedHand } // Find each packet in the buffer and dispatch it to the appropriate handler. -func Dispatch(buffer *[]byte, client *connection.Client) { +func Dispatch(buffer *[]byte, client *connection.Client, n int) { offset := 0 - for offset < len(*buffer) { + for offset < n { packetId := GetPacketId(offset, buffer) packetSize := GetPacketSize(offset, buffer) + 2 client.PacketIndex = GetPacketIndex(offset, buffer) From 4d1442a216b8b398ba31ea4f2767a1d28b867edc Mon Sep 17 00:00:00 2001 From: ggmolly Date: Sat, 15 Jun 2024 14:23:23 +0200 Subject: [PATCH 2/6] fix: missing error handling --- connection/server.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/connection/server.go b/connection/server.go index 04f51f5..6dfe502 100644 --- a/connection/server.go +++ b/connection/server.go @@ -108,11 +108,11 @@ func handleConnection(conn net.Conn, wg *sync.WaitGroup, server *Server) { } } -func (server *Server) Run() { +func (server *Server) Run() error { listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", server.BindAddress, server.Port)) if err != nil { logger.LogEvent("Server", "Run", fmt.Sprintf("error listening: %v", err), logger.LOG_LEVEL_ERROR) - return + return err } defer listener.Close() logger.LogEvent("Server", "Run", fmt.Sprintf("listening on %s:%d", server.BindAddress, server.Port), logger.LOG_LEVEL_INFO) @@ -128,6 +128,7 @@ func (server *Server) Run() { go handleConnection(conn, &wg, server) } wg.Wait() + return nil } func NewServer(bindAddress string, port int, dispatcher ServerDispatcher) *Server { From 37be9bb5627eb5391a176b4c85e1cac4dafb3bcf Mon Sep 17 00:00:00 2001 From: ggmolly Date: Sat, 15 Jun 2024 16:56:35 +0200 Subject: [PATCH 3/6] add: `Hash` method to clients for mapping --- connection/client.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/connection/client.go b/connection/client.go index ce57375..0283d27 100644 --- a/connection/client.go +++ b/connection/client.go @@ -9,6 +9,7 @@ import ( "github.com/ggmolly/belfast/logger" "github.com/ggmolly/belfast/orm" + "github.com/ggmolly/belfast/protobuf" "google.golang.org/protobuf/proto" ) @@ -21,6 +22,7 @@ type Client struct { Port int State int PacketIndex int + Hash uint32 Connection *net.Conn Commander *orm.Commander Buffer bytes.Buffer @@ -101,6 +103,14 @@ func (client *Client) GetCommander(accountId uint32) error { return err } +// Sends SC_10999 (disconnected from server) message to the Client, reasons are defined in consts/disconnect_reasons.go +func (client *Client) Disconnect(reason uint8) error { + _, _, err := SendProtoMessage(10999, client, &protobuf.SC_10999{ + Reason: proto.Uint32(uint32(reason)), + }) + return err +} + // Sends the content of the buffer to the client via TCP func (client *Client) Flush() { _, err := (*client.Connection).Write(client.Buffer.Bytes()) From f8548f4edd6b9532b07f4ab2c035411a947b9c5b Mon Sep 17 00:00:00 2001 From: ggmolly Date: Sat, 15 Jun 2024 16:56:50 +0200 Subject: [PATCH 4/6] update: disconnect all clients when process receives a SIGINT --- main.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/main.go b/main.go index ba6ad07..449de87 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "github.com/akamensky/argparse" "github.com/ggmolly/belfast/answer" "github.com/ggmolly/belfast/connection" + "github.com/ggmolly/belfast/consts" "github.com/ggmolly/belfast/debug" "github.com/ggmolly/belfast/logger" "github.com/ggmolly/belfast/misc" @@ -69,6 +70,8 @@ func main() { go func() { <-sigChannel fmt.Printf("\r") // trick to avoid ^C in the terminal, could use low-level RawMode() but why bother + // disconnect all clients from the server + server.DisconnectAll(consts.DR_CONNECTION_TO_SERVER_LOST) os.Exit(0) }() // Prepare web server From 9e01b75fedafdf7104c19dde80c09f29b4cca427 Mon Sep 17 00:00:00 2001 From: ggmolly Date: Sat, 15 Jun 2024 16:57:00 +0200 Subject: [PATCH 5/6] add: disconnection reason --- consts/disconnect_reasons.go | 37 ++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 consts/disconnect_reasons.go diff --git a/consts/disconnect_reasons.go b/consts/disconnect_reasons.go new file mode 100644 index 0000000..b18d7f9 --- /dev/null +++ b/consts/disconnect_reasons.go @@ -0,0 +1,37 @@ +package consts + +import "fmt" + +const ( + DR_LOGGED_IN_ON_ANOTHER_DEVICE = 1 + DR_SERVER_MAINTENANCE = 2 + DR_GAME_UPDATE = 3 + DR_OFFLINE_TOO_LONG = 4 + DR_CONNECTION_LOST = 5 + DR_CONNECTION_TO_SERVER_LOST = 6 + DR_DATA_VALIDATION_FAILED = 7 + DR_LOGIN_DATA_EXPIRED = 199 +) + +func ResolveReason(reason uint8) string { + switch reason { + case DR_LOGGED_IN_ON_ANOTHER_DEVICE: + return "logged in on another device" + case DR_SERVER_MAINTENANCE: + return "server maintenance" + case DR_GAME_UPDATE: + return "game update" + case DR_OFFLINE_TOO_LONG: + return "offline too long" + case DR_CONNECTION_LOST: + return "connection lost" + case DR_CONNECTION_TO_SERVER_LOST: + return "connection to server lost" + case DR_DATA_VALIDATION_FAILED: + return "data validation failed" + case DR_LOGIN_DATA_EXPIRED: + return "login data expired" + default: + return fmt.Sprintf("unknown reason %d", reason) + } +} From 24875febc74111e2461634b93d9347c0af0bf019 Mon Sep 17 00:00:00 2001 From: ggmolly Date: Sat, 15 Jun 2024 16:57:30 +0200 Subject: [PATCH 6/6] add: mutexes to avoid concurrents r/w --- connection/server.go | 59 +++++++++++++++++++++++++++++++++----------- 1 file changed, 45 insertions(+), 14 deletions(-) diff --git a/connection/server.go b/connection/server.go index 6dfe502..dcf016a 100644 --- a/connection/server.go +++ b/connection/server.go @@ -9,6 +9,7 @@ import ( "reflect" "sync" + "github.com/ggmolly/belfast/consts" "github.com/ggmolly/belfast/debug" "github.com/ggmolly/belfast/logger" "github.com/ggmolly/belfast/orm" @@ -24,8 +25,13 @@ type Server struct { SocketFD int EpollFD int Dispatcher ServerDispatcher - rooms map[uint32][]*Client Region string + + // Maps & mutexes + roomsMutex sync.RWMutex + rooms map[uint32][]*Client // Game chat rooms + clientsMutex sync.RWMutex + clients map[uint32]*Client // Socket hash -> Client } var ( @@ -39,36 +45,43 @@ func (server *Server) GetClient(conn *net.Conn) (*Client, error) { client.Port = (*conn).RemoteAddr().(*net.TCPAddr).Port client.Connection = conn client.Server = server + for _, c := range fmt.Sprintf("%s:%d", client.IP, client.Port) { + client.Hash += uint32(c) + } return &client, err } func (server *Server) AddClient(client *Client) { - logger.LogEvent("Server", "hewwo", fmt.Sprintf("new connection from %s:%d", client.IP, client.Port), logger.LOG_LEVEL_DEBUG) + logger.LogEvent("Server", "Hello", fmt.Sprintf("new connection from %s:%d", client.IP, client.Port), logger.LOG_LEVEL_DEBUG) + client.Server.clientsMutex.Lock() + defer client.Server.clientsMutex.Unlock() + server.clients[client.Hash] = client } func (server *Server) RemoveClient(client *Client) { - logger.LogEvent("Server", "cya", fmt.Sprintf("%s:%d", client.IP, client.Port), logger.LOG_LEVEL_DEBUG) + client.Server.clientsMutex.Lock() + defer client.Server.clientsMutex.Unlock() + logger.LogEvent("Server", "Goodbye", fmt.Sprintf("%s:%d", client.IP, client.Port), logger.LOG_LEVEL_DEBUG) (*client.Connection).Close() + delete(server.clients, client.Hash) } -func handleConnection(conn net.Conn, wg *sync.WaitGroup, server *Server) { - defer wg.Done() +func handleConnection(conn net.Conn, server *Server) { + logger.LogEvent("Server", "TEST", "Goroutine started", logger.LOG_LEVEL_WARN) defer conn.Close() - + defer logger.LogEvent("Server", "TEST", "Goroutine ended", logger.LOG_LEVEL_WARN) // Add the client to the list client, err := server.GetClient(&conn) if err != nil { logger.LogEvent("Server", "Handler", fmt.Sprintf("client %s -- error: %v", conn.RemoteAddr(), err), logger.LOG_LEVEL_ERROR) conn.Close() - server.RemoveClient(client) return } if !client.IP.IsPrivate() { logger.LogEvent("Server", "Handler", fmt.Sprintf("client %s -- not in a private range", conn.RemoteAddr()), logger.LOG_LEVEL_ERROR) conn.Close() - server.RemoveClient(client) return } @@ -95,7 +108,7 @@ func handleConnection(conn net.Conn, wg *sync.WaitGroup, server *Server) { // these two bytes are the length of a message size := int(packerBuffer[0])<<8 | int(packerBuffer[1]) + 2 // take into account the 2 bytes for the size if totalBytes >= size { - // We have a full message, slice it and send it to the dispatcher + // Slice the packerBuffer to get the message and send it to the dispatcher message := packerBuffer[:size] server.Dispatcher(&message, client, size) // Remove the message from the packerBuffer and shift the rest of the buffer @@ -117,18 +130,14 @@ func (server *Server) Run() error { defer listener.Close() logger.LogEvent("Server", "Run", fmt.Sprintf("listening on %s:%d", server.BindAddress, server.Port), logger.LOG_LEVEL_INFO) - var wg sync.WaitGroup for { conn, err := listener.Accept() if err != nil { logger.LogEvent("Server", "Run", fmt.Sprintf("error accepting: %v", err), logger.LOG_LEVEL_ERROR) continue } - wg.Add(1) - go handleConnection(conn, &wg, server) + go handleConnection(conn, server) } - wg.Wait() - return nil } func NewServer(bindAddress string, port int, dispatcher ServerDispatcher) *Server { @@ -137,16 +146,34 @@ func NewServer(bindAddress string, port int, dispatcher ServerDispatcher) *Serve Port: port, Dispatcher: dispatcher, Region: os.Getenv("AL_REGION"), + clients: make(map[uint32]*Client), rooms: make(map[uint32][]*Client), } } +// Sends SC_10999 (disconnected from server) message to every connected clients, reasons are defined in consts/disconnect_reasons.go +func (server *Server) DisconnectAll(reason uint8) { + server.clientsMutex.Lock() + defer server.clientsMutex.Unlock() + for _, client := range server.clients { + logger.LogEvent("Server", "Disconnect", fmt.Sprintf("disconnecting %s:%d -> %s", client.IP, client.Port, consts.ResolveReason(reason)), logger.LOG_LEVEL_DEBUG) + client.Disconnect(reason) + client.Flush() + (*client.Connection).Close() + delete(server.clients, client.Hash) + } +} + // Chat room management func (server *Server) JoinRoom(roomID uint32, client *Client) { + server.roomsMutex.Lock() + defer server.roomsMutex.Unlock() server.rooms[roomID] = append(server.rooms[roomID], client) } func (server *Server) LeaveRoom(roomID uint32, client *Client) { + server.roomsMutex.Lock() + defer server.roomsMutex.Unlock() for i, c := range server.rooms[roomID] { if c == client { server.rooms[roomID] = append(server.rooms[roomID][:i], server.rooms[roomID][i+1:]...) @@ -156,6 +183,8 @@ func (server *Server) LeaveRoom(roomID uint32, client *Client) { } func (server *Server) ChangeRoom(oldRoomID uint32, newRoomID uint32, client *Client) { + server.roomsMutex.Lock() + defer server.roomsMutex.Unlock() for i, c := range server.rooms[oldRoomID] { if c == client { server.rooms[oldRoomID] = append(server.rooms[oldRoomID][:i], server.rooms[oldRoomID][i+1:]...) @@ -175,6 +204,8 @@ func (server *Server) SendMessage(sender *Client, message orm.Message) { Type: proto.Uint32(orm.MSG_TYPE_NORMAL), Content: proto.String(message.Content), } + server.roomsMutex.RLock() + defer server.roomsMutex.RUnlock() for _, client := range server.rooms[message.RoomID] { client.SendMessage(50101, &msgPacket) }