From 7477d63af3aeb6407d4e97a1e9e700f41bc88652 Mon Sep 17 00:00:00 2001 From: nick Date: Mon, 29 Jul 2024 17:35:33 +0900 Subject: [PATCH] feat: implement ip limit --- node/pkg/dal/api/hub.go | 102 ++++++++++++++++++++++++-------------- node/pkg/dal/api/types.go | 3 ++ 2 files changed, 69 insertions(+), 36 deletions(-) diff --git a/node/pkg/dal/api/hub.go b/node/pkg/dal/api/hub.go index ac1c35f5c..ed115cb14 100644 --- a/node/pkg/dal/api/hub.go +++ b/node/pkg/dal/api/hub.go @@ -29,47 +29,77 @@ func NewHub(configs map[string]types.Config) *Hub { register: make(chan *websocket.Conn), unregister: make(chan *websocket.Conn), broadcast: make(map[string]chan dalcommon.OutgoingSubmissionData), + connPerIP: make(map[string][]*websocket.Conn), } } -func (c *Hub) Start(ctx context.Context, collector *collector.Collector) { - go c.handleClientRegistration() +func (h *Hub) Start(ctx context.Context, collector *collector.Collector) { + go h.handleClientRegistration() - c.initializeBroadcastChannels(collector) + h.initializeBroadcastChannels(collector) - for symbol := range c.configs { - go c.broadcastDataForSymbol(symbol) + for symbol := range h.configs { + go h.broadcastDataForSymbol(symbol) } } -func (c *Hub) handleClientRegistration() { +func (h *Hub) handleClientRegistration() { for { select { - case conn := <-c.register: - c.addClient(conn) - case conn := <-c.unregister: - c.removeClient(conn) + case conn := <-h.register: + h.addClient(conn) + case conn := <-h.unregister: + h.removeClient(conn) } } } -func (c *Hub) addClient(conn *websocket.Conn) { - c.mu.Lock() - defer c.mu.Unlock() - if _, ok := c.clients[conn]; ok { +func (h *Hub) addClient(conn *websocket.Conn) { + h.mu.Lock() + defer h.mu.Unlock() + if _, ok := h.clients[conn]; ok { return } - c.clients[conn] = make(map[string]bool) + h.clients[conn] = make(map[string]bool) + if _, ok := h.connPerIP[conn.IP()]; !ok { + h.connPerIP[conn.IP()] = make([]*websocket.Conn, 0) + } + + h.connPerIP[conn.IP()] = append(h.connPerIP[conn.IP()], conn) + + if len(h.connPerIP) > MAX_CONNECTIONS { + oldConn := h.connPerIP[conn.IP()][0] + if subs, ok := h.clients[oldConn]; ok { + for k := range subs { + delete(h.clients[oldConn], k) + } + } + delete(h.clients, oldConn) + oldConn.WriteControl( + websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "too many connections"), + time.Now().Add(time.Second), + ) + oldConn.Close() + } } -func (c *Hub) removeClient(conn *websocket.Conn) { - c.mu.Lock() - defer c.mu.Unlock() - if _, ok := c.clients[conn]; ok { - for symbol := range c.clients[conn] { - delete(c.clients[conn], symbol) +func (h *Hub) removeClient(conn *websocket.Conn) { + h.mu.Lock() + defer h.mu.Unlock() + if _, ok := h.clients[conn]; ok { + for symbol := range h.clients[conn] { + delete(h.clients[conn], symbol) + } + delete(h.clients, conn) + } + for i, c := range h.connPerIP[conn.IP()] { + if c == conn { + h.connPerIP[conn.IP()] = append(h.connPerIP[conn.IP()][:i], h.connPerIP[conn.IP()][i+1:]...) + if len(h.connPerIP) == 0 { + delete(h.connPerIP, conn.IP()) + } } - delete(c.clients, conn) } if err := conn.WriteControl( websocket.CloseMessage, @@ -81,15 +111,15 @@ func (c *Hub) removeClient(conn *websocket.Conn) { conn.Close() } -func (c *Hub) initializeBroadcastChannels(collector *collector.Collector) { +func (h *Hub) initializeBroadcastChannels(collector *collector.Collector) { for configId, stream := range collector.OutgoingStream { - symbol := c.configIdToSymbol(configId) - c.broadcast[symbol] = stream + symbol := h.configIdToSymbol(configId) + h.broadcast[symbol] = stream } } -func (c *Hub) configIdToSymbol(id int32) string { - for symbol, config := range c.configs { +func (h *Hub) configIdToSymbol(id int32) string { + for symbol, config := range h.configs { if config.ID == id { return symbol } @@ -97,21 +127,21 @@ func (c *Hub) configIdToSymbol(id int32) string { return "" } -func (c *Hub) broadcastDataForSymbol(symbol string) { - for data := range c.broadcast[symbol] { - c.castSubmissionData(&data, &symbol) +func (h *Hub) broadcastDataForSymbol(symbol string) { + for data := range h.broadcast[symbol] { + h.castSubmissionData(&data, &symbol) } } // pass by pointer to reduce memory copy time -func (c *Hub) castSubmissionData(data *dalcommon.OutgoingSubmissionData, symbol *string) { - c.mu.Lock() - defer c.mu.Unlock() - for conn := range c.clients { - if _, ok := c.clients[conn][*symbol]; ok { +func (h *Hub) castSubmissionData(data *dalcommon.OutgoingSubmissionData, symbol *string) { + h.mu.Lock() + defer h.mu.Unlock() + for conn := range h.clients { + if _, ok := h.clients[conn][*symbol]; ok { if err := conn.WriteJSON(*data); err != nil { log.Error().Err(err).Msg("failed to write message") - c.unregister <- conn + h.unregister <- conn } } } diff --git a/node/pkg/dal/api/types.go b/node/pkg/dal/api/types.go index eb7128576..1c49a86ac 100644 --- a/node/pkg/dal/api/types.go +++ b/node/pkg/dal/api/types.go @@ -8,6 +8,8 @@ import ( "github.com/gofiber/contrib/websocket" ) +const MAX_CONNECTIONS = 10 + type Subscription struct { Method string `json:"method"` Params []string `json:"params"` @@ -19,6 +21,7 @@ type Hub struct { register chan *websocket.Conn unregister chan *websocket.Conn broadcast map[string]chan dalcommon.OutgoingSubmissionData + connPerIP map[string][]*websocket.Conn mu sync.RWMutex }