Skip to content

Commit

Permalink
(DAL) Refactor package structure (#1925)
Browse files Browse the repository at this point in the history
* refactor: refactor package structure

* fix: update based on feedback

* Update node/pkg/dal/utils/initializer/initializer.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Update node/pkg/dal/utils/initializer/initializer.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* fix: remove redundant code

* Apply suggestions from code review

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Update node/pkg/dal/api/controller.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* feat: delete also subscription on unregister (#1928)

* feat: closeHandler update (#1927)

* (DAL) writeControl (#1926)

* feat: writeControl

* Update node/pkg/dal/api/hub.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* fix: remove unnecessary codes

* feat: modularize, update readability

* fix: remove unnecessary code

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
nick-bisonai and coderabbitai[bot] authored Jul 29, 2024
1 parent a278af1 commit 0bc41a5
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 148 deletions.
153 changes: 31 additions & 122 deletions node/pkg/dal/api/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,124 +6,34 @@ import (
"fmt"
"strings"

"bisonai.com/orakl/node/pkg/common/types"
"bisonai.com/orakl/node/pkg/dal/collector"
dalcommon "bisonai.com/orakl/node/pkg/dal/common"
"bisonai.com/orakl/node/pkg/dal/utils/stats"
"bisonai.com/orakl/node/pkg/utils/request"
"github.com/gofiber/contrib/websocket"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)

func Setup(ctx context.Context, adminEndpoint string) (*Controller, error) {
configs, err := request.Request[[]types.Config](request.WithEndpoint(adminEndpoint + "/config"))
if err != nil {
log.Error().Err(err).Msg("failed to get configs")
return nil, err
}

configMap := make(map[string]types.Config)
for _, config := range configs {
configMap[config.Name] = config
}
collector, err := collector.NewCollector(ctx, configs)
if err != nil {
log.Error().Err(err).Msg("failed to create collector")
return nil, err
}

ApiController := NewController(configMap, collector)
return ApiController, nil
}

func NewController(configs map[string]types.Config, internalCollector *collector.Collector) *Controller {
return &Controller{
Collector: internalCollector,
configs: configs,

clients: make(map[*websocket.Conn]map[string]bool),
register: make(chan *websocket.Conn),
unregister: make(chan *websocket.Conn),
broadcast: make(map[string]chan dalcommon.OutgoingSubmissionData),
}
}

func (c *Controller) Start(ctx context.Context) {
go c.Collector.Start(ctx)
log.Info().Msg("api collector started")
go func() {
for {
select {
case conn := <-c.register:
c.mu.Lock()
c.clients[conn] = make(map[string]bool)
c.mu.Unlock()
case conn := <-c.unregister:
c.mu.Lock()
delete(c.clients, conn)
conn.Close()
c.mu.Unlock()
}
}
}()

for configId, stream := range c.Collector.OutgoingStream {
symbol := c.configIdToSymbol(configId)
c.broadcast[symbol] = make(chan dalcommon.OutgoingSubmissionData)
c.broadcast[symbol] = stream
}

for symbol := range c.configs {
go c.broadcastDataForSymbol(symbol)
}
}

func (c *Controller) configIdToSymbol(id int32) string {
for symbol, config := range c.configs {
if config.ID == id {
return symbol
}
}
return ""
}

func (c *Controller) broadcastDataForSymbol(symbol string) {
for data := range c.broadcast[symbol] {

go c.castSubmissionData(&data, &symbol)
}
}

// pass by pointer to reduce memory copy time
func (c *Controller) castSubmissionData(data *dalcommon.OutgoingSubmissionData, symbol *string) {
c.mu.Lock()
defer c.mu.Unlock()
for conn := range c.clients {
if _, ok := c.clients[conn][*symbol]; ok {
if err := conn.WriteJSON(*data); err != nil {
log.Error().Err(err).Msg("failed to write message")
delete(c.clients, conn)
conn.Close()
}
}
}
}

func HandleWebsocket(conn *websocket.Conn) {
c, ok := conn.Locals("apiController").(*Controller)
h, ok := conn.Locals("hub").(*Hub)
if !ok {
log.Error().Msg("api controller not found")
log.Error().Msg("hub not found")
return
}

closeHandler := conn.CloseHandler()
conn.SetCloseHandler(func(code int, text string) error {
h.unregister <- conn
return closeHandler(code, text)
})

ctx, ok := conn.Locals("context").(*context.Context)
if !ok {
log.Error().Msg("ctx not found")
return
}

c.register <- conn
h.register <- conn
apiKey := conn.Headers("X-Api-Key")

id, err := stats.InsertWebsocketConnection(*ctx, apiKey)
Expand All @@ -134,8 +44,7 @@ func HandleWebsocket(conn *websocket.Conn) {
log.Info().Int32("id", id).Msg("inserted websocket connection")

defer func() {
c.unregister <- conn
conn.Close()
h.unregister <- conn
err = stats.UpdateWebsocketConnection(*ctx, id)
if err != nil {
log.Error().Err(err).Msg("failed to update websocket connection")
Expand All @@ -152,53 +61,53 @@ func HandleWebsocket(conn *websocket.Conn) {
}

if msg.Method == "SUBSCRIBE" {
c.mu.Lock()
if c.clients[conn] == nil {
c.clients[conn] = make(map[string]bool)
h.mu.Lock()
if h.clients[conn] == nil {
h.clients[conn] = make(map[string]bool)
}
for _, param := range msg.Params {
symbol := strings.TrimPrefix(param, "submission@")
if _, ok := c.configs[symbol]; !ok {
if _, ok := h.configs[symbol]; !ok {
continue
}
c.clients[conn][symbol] = true
h.clients[conn][symbol] = true
err = stats.InsertWebsocketSubscription(*ctx, id, param)
if err != nil {
log.Error().Err(err).Msg("failed to insert websocket subscription")
}
}
c.mu.Unlock()
h.mu.Unlock()
}
}
}

func getSymbols(c *fiber.Ctx) error {
controller, ok := c.Locals("apiController").(*Controller)
hub, ok := c.Locals("hub").(*Hub)
if !ok {
return errors.New("api controller not found")
return errors.New("hub not found")
}

result := []string{}
for key := range controller.configs {
for key := range hub.configs {
result = append(result, key)
}
return c.JSON(result)
}

func getAllLatestFeeds(c *fiber.Ctx) error {
controller, ok := c.Locals("apiController").(*Controller)
collector, ok := c.Locals("collector").(*collector.Collector)
if !ok {
return errors.New("api controller not found")
return errors.New("collector not found")
}

result := controller.Collector.GetAllLatestData()
result := collector.GetAllLatestData()
return c.JSON(result)
}

func getLatestFeeds(c *fiber.Ctx) error {
controller, ok := c.Locals("apiController").(*Controller)
collector, ok := c.Locals("collector").(*collector.Collector)
if !ok {
return errors.New("api controller not found")
return errors.New("collector not found")
}

symbolsStr := c.Params("symbols")
Expand All @@ -222,7 +131,7 @@ func getLatestFeeds(c *fiber.Ctx) error {
symbol = strings.ToUpper(symbol)
}

result, err := controller.Collector.GetLatestData(symbol)
result, err := collector.GetLatestData(symbol)
if err != nil {
return err
}
Expand All @@ -234,9 +143,9 @@ func getLatestFeeds(c *fiber.Ctx) error {
}

func getLatestFeedsTransposed(c *fiber.Ctx) error {
controller, ok := c.Locals("apiController").(*Controller)
collector, ok := c.Locals("collector").(*collector.Collector)
if !ok {
return errors.New("api controller not found")
return errors.New("collector not found")
}

symbolsStr := c.Params("symbols")
Expand All @@ -260,7 +169,7 @@ func getLatestFeedsTransposed(c *fiber.Ctx) error {
symbol = strings.ToUpper(symbol)
}

result, err := controller.Collector.GetLatestData(symbol)
result, err := collector.GetLatestData(symbol)
if err != nil {
return err
}
Expand All @@ -276,12 +185,12 @@ func getLatestFeedsTransposed(c *fiber.Ctx) error {
}

func getAllLatestFeedsTransposed(c *fiber.Ctx) error {
controller, ok := c.Locals("apiController").(*Controller)
collector, ok := c.Locals("collector").(*collector.Collector)
if !ok {
return errors.New("api controller not found")
return errors.New("collector not found")
}

result := controller.Collector.GetAllLatestData()
result := collector.GetAllLatestData()
bulk := BulkResponse{}
for _, data := range result {
bulk.Symbols = append(bulk.Symbols, data.Symbol)
Expand Down
118 changes: 118 additions & 0 deletions node/pkg/dal/api/hub.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package api

import (
"context"
"time"

"bisonai.com/orakl/node/pkg/common/types"
"bisonai.com/orakl/node/pkg/dal/collector"
dalcommon "bisonai.com/orakl/node/pkg/dal/common"
"github.com/gofiber/contrib/websocket"
"github.com/rs/zerolog/log"
)

func HubSetup(ctx context.Context, configs []types.Config) *Hub {
configMap := make(map[string]types.Config)
for _, config := range configs {
configMap[config.Name] = config
}

hub := NewHub(configMap)
return hub
}

func NewHub(configs map[string]types.Config) *Hub {
return &Hub{
configs: configs,

clients: make(map[*websocket.Conn]map[string]bool),
register: make(chan *websocket.Conn),
unregister: make(chan *websocket.Conn),
broadcast: make(map[string]chan dalcommon.OutgoingSubmissionData),
}
}

func (c *Hub) Start(ctx context.Context, collector *collector.Collector) {
go c.handleClientRegistration()

c.initializeBroadcastChannels(collector)

for symbol := range c.configs {
go c.broadcastDataForSymbol(symbol)
}
}

func (c *Hub) handleClientRegistration() {
for {
select {
case conn := <-c.register:
c.addClient(conn)
case conn := <-c.unregister:
c.removeClient(conn)
}
}
}

func (c *Hub) addClient(conn *websocket.Conn) {
c.mu.Lock()
defer c.mu.Unlock()
if _, ok := c.clients[conn]; ok {
return
}
c.clients[conn] = make(map[string]bool)
}

func (c *Hub) removeClient(conn *websocket.Conn) {
c.mu.Lock()
defer c.mu.Unlock()
if _, ok := c.clients[conn]; ok {
for symbol := range c.clients[conn] {
delete(c.clients[conn], symbol)
}
delete(c.clients, conn)
}
if err := conn.WriteControl(
websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
time.Now().Add(time.Second),
); err != nil {
log.Error().Err(err).Msg("failed to send close message")
}
conn.Close()
}

func (c *Hub) initializeBroadcastChannels(collector *collector.Collector) {
for configId, stream := range collector.OutgoingStream {
symbol := c.configIdToSymbol(configId)
c.broadcast[symbol] = stream
}
}

func (c *Hub) configIdToSymbol(id int32) string {
for symbol, config := range c.configs {
if config.ID == id {
return symbol
}
}
return ""
}

func (c *Hub) broadcastDataForSymbol(symbol string) {
for data := range c.broadcast[symbol] {
c.castSubmissionData(&data, &symbol)
}
}

// pass by pointer to reduce memory copy time
func (c *Hub) castSubmissionData(data *dalcommon.OutgoingSubmissionData, symbol *string) {
c.mu.Lock()
defer c.mu.Unlock()
for conn := range c.clients {
if _, ok := c.clients[conn][*symbol]; ok {
if err := conn.WriteJSON(*data); err != nil {
log.Error().Err(err).Msg("failed to write message")
c.unregister <- conn
}
}
}
}
5 changes: 1 addition & 4 deletions node/pkg/dal/api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"sync"

"bisonai.com/orakl/node/pkg/common/types"
"bisonai.com/orakl/node/pkg/dal/collector"
dalcommon "bisonai.com/orakl/node/pkg/dal/common"
"github.com/gofiber/contrib/websocket"
)
Expand All @@ -14,9 +13,7 @@ type Subscription struct {
Params []string `json:"params"`
}

type Controller struct {
Collector *collector.Collector

type Hub struct {
configs map[string]types.Config
clients map[*websocket.Conn]map[string]bool
register chan *websocket.Conn
Expand Down
Loading

0 comments on commit 0bc41a5

Please sign in to comment.