Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(DAL) Refactor package structure #1925

Merged
merged 13 commits into from
Jul 29, 2024
153 changes: 31 additions & 122 deletions node/pkg/dal/api/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,124 +6,34 @@ import (
"fmt"
"strings"

"bisonai.com/orakl/node/pkg/common/types"
"bisonai.com/orakl/node/pkg/dal/collector"
dalcommon "bisonai.com/orakl/node/pkg/dal/common"
"bisonai.com/orakl/node/pkg/dal/utils/stats"
"bisonai.com/orakl/node/pkg/utils/request"
"github.com/gofiber/contrib/websocket"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)

func Setup(ctx context.Context, adminEndpoint string) (*Controller, error) {
configs, err := request.Request[[]types.Config](request.WithEndpoint(adminEndpoint + "/config"))
if err != nil {
log.Error().Err(err).Msg("failed to get configs")
return nil, err
}

configMap := make(map[string]types.Config)
for _, config := range configs {
configMap[config.Name] = config
}
collector, err := collector.NewCollector(ctx, configs)
if err != nil {
log.Error().Err(err).Msg("failed to create collector")
return nil, err
}

ApiController := NewController(configMap, collector)
return ApiController, nil
}

func NewController(configs map[string]types.Config, internalCollector *collector.Collector) *Controller {
return &Controller{
Collector: internalCollector,
configs: configs,

clients: make(map[*websocket.Conn]map[string]bool),
register: make(chan *websocket.Conn),
unregister: make(chan *websocket.Conn),
broadcast: make(map[string]chan dalcommon.OutgoingSubmissionData),
}
}

func (c *Controller) Start(ctx context.Context) {
go c.Collector.Start(ctx)
log.Info().Msg("api collector started")
go func() {
for {
select {
case conn := <-c.register:
c.mu.Lock()
c.clients[conn] = make(map[string]bool)
c.mu.Unlock()
case conn := <-c.unregister:
c.mu.Lock()
delete(c.clients, conn)
conn.Close()
c.mu.Unlock()
}
}
}()

for configId, stream := range c.Collector.OutgoingStream {
symbol := c.configIdToSymbol(configId)
c.broadcast[symbol] = make(chan dalcommon.OutgoingSubmissionData)
c.broadcast[symbol] = stream
}

for symbol := range c.configs {
go c.broadcastDataForSymbol(symbol)
}
}

func (c *Controller) configIdToSymbol(id int32) string {
for symbol, config := range c.configs {
if config.ID == id {
return symbol
}
}
return ""
}

func (c *Controller) broadcastDataForSymbol(symbol string) {
for data := range c.broadcast[symbol] {

go c.castSubmissionData(&data, &symbol)
}
}

// pass by pointer to reduce memory copy time
func (c *Controller) castSubmissionData(data *dalcommon.OutgoingSubmissionData, symbol *string) {
c.mu.Lock()
defer c.mu.Unlock()
for conn := range c.clients {
if _, ok := c.clients[conn][*symbol]; ok {
if err := conn.WriteJSON(*data); err != nil {
log.Error().Err(err).Msg("failed to write message")
delete(c.clients, conn)
conn.Close()
}
}
}
}

func HandleWebsocket(conn *websocket.Conn) {
c, ok := conn.Locals("apiController").(*Controller)
h, ok := conn.Locals("hub").(*Hub)
if !ok {
log.Error().Msg("api controller not found")
log.Error().Msg("hub not found")
return
}

closeHandler := conn.CloseHandler()
conn.SetCloseHandler(func(code int, text string) error {
h.unregister <- conn
return closeHandler(code, text)
})

ctx, ok := conn.Locals("context").(*context.Context)
if !ok {
log.Error().Msg("ctx not found")
return
}

c.register <- conn
h.register <- conn
apiKey := conn.Headers("X-Api-Key")

id, err := stats.InsertWebsocketConnection(*ctx, apiKey)
Expand All @@ -134,8 +44,7 @@ func HandleWebsocket(conn *websocket.Conn) {
log.Info().Int32("id", id).Msg("inserted websocket connection")

defer func() {
c.unregister <- conn
conn.Close()
h.unregister <- conn
err = stats.UpdateWebsocketConnection(*ctx, id)
if err != nil {
log.Error().Err(err).Msg("failed to update websocket connection")
Expand All @@ -152,53 +61,53 @@ func HandleWebsocket(conn *websocket.Conn) {
}

if msg.Method == "SUBSCRIBE" {
c.mu.Lock()
if c.clients[conn] == nil {
c.clients[conn] = make(map[string]bool)
h.mu.Lock()
if h.clients[conn] == nil {
h.clients[conn] = make(map[string]bool)
}
for _, param := range msg.Params {
symbol := strings.TrimPrefix(param, "submission@")
if _, ok := c.configs[symbol]; !ok {
if _, ok := h.configs[symbol]; !ok {
continue
}
c.clients[conn][symbol] = true
h.clients[conn][symbol] = true
err = stats.InsertWebsocketSubscription(*ctx, id, param)
if err != nil {
log.Error().Err(err).Msg("failed to insert websocket subscription")
}
}
c.mu.Unlock()
h.mu.Unlock()
}
}
}

func getSymbols(c *fiber.Ctx) error {
controller, ok := c.Locals("apiController").(*Controller)
hub, ok := c.Locals("hub").(*Hub)
if !ok {
return errors.New("api controller not found")
return errors.New("hub not found")
}

result := []string{}
for key := range controller.configs {
for key := range hub.configs {
result = append(result, key)
}
return c.JSON(result)
}

func getAllLatestFeeds(c *fiber.Ctx) error {
controller, ok := c.Locals("apiController").(*Controller)
collector, ok := c.Locals("collector").(*collector.Collector)
if !ok {
return errors.New("api controller not found")
return errors.New("collector not found")
}

result := controller.Collector.GetAllLatestData()
result := collector.GetAllLatestData()
return c.JSON(result)
}

func getLatestFeeds(c *fiber.Ctx) error {
controller, ok := c.Locals("apiController").(*Controller)
collector, ok := c.Locals("collector").(*collector.Collector)
if !ok {
return errors.New("api controller not found")
return errors.New("collector not found")
}

symbolsStr := c.Params("symbols")
Expand All @@ -222,7 +131,7 @@ func getLatestFeeds(c *fiber.Ctx) error {
symbol = strings.ToUpper(symbol)
}

result, err := controller.Collector.GetLatestData(symbol)
result, err := collector.GetLatestData(symbol)
if err != nil {
return err
}
Expand All @@ -234,9 +143,9 @@ func getLatestFeeds(c *fiber.Ctx) error {
}

func getLatestFeedsTransposed(c *fiber.Ctx) error {
controller, ok := c.Locals("apiController").(*Controller)
collector, ok := c.Locals("collector").(*collector.Collector)
if !ok {
return errors.New("api controller not found")
return errors.New("collector not found")
}

symbolsStr := c.Params("symbols")
Expand All @@ -260,7 +169,7 @@ func getLatestFeedsTransposed(c *fiber.Ctx) error {
symbol = strings.ToUpper(symbol)
}

result, err := controller.Collector.GetLatestData(symbol)
result, err := collector.GetLatestData(symbol)
if err != nil {
return err
}
Expand All @@ -276,12 +185,12 @@ func getLatestFeedsTransposed(c *fiber.Ctx) error {
}

func getAllLatestFeedsTransposed(c *fiber.Ctx) error {
controller, ok := c.Locals("apiController").(*Controller)
collector, ok := c.Locals("collector").(*collector.Collector)
if !ok {
return errors.New("api controller not found")
return errors.New("collector not found")
}

result := controller.Collector.GetAllLatestData()
result := collector.GetAllLatestData()
bulk := BulkResponse{}
for _, data := range result {
bulk.Symbols = append(bulk.Symbols, data.Symbol)
Expand Down
118 changes: 118 additions & 0 deletions node/pkg/dal/api/hub.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package api

import (
"context"
"time"

"bisonai.com/orakl/node/pkg/common/types"
"bisonai.com/orakl/node/pkg/dal/collector"
dalcommon "bisonai.com/orakl/node/pkg/dal/common"
"github.com/gofiber/contrib/websocket"
"github.com/rs/zerolog/log"
)

func HubSetup(ctx context.Context, configs []types.Config) *Hub {
configMap := make(map[string]types.Config)
for _, config := range configs {
configMap[config.Name] = config
}

hub := NewHub(configMap)
return hub
}

func NewHub(configs map[string]types.Config) *Hub {
return &Hub{
configs: configs,

clients: make(map[*websocket.Conn]map[string]bool),
register: make(chan *websocket.Conn),
unregister: make(chan *websocket.Conn),
broadcast: make(map[string]chan dalcommon.OutgoingSubmissionData),
}
}

func (c *Hub) Start(ctx context.Context, collector *collector.Collector) {
go c.handleClientRegistration()

c.initializeBroadcastChannels(collector)

for symbol := range c.configs {
go c.broadcastDataForSymbol(symbol)
}
}

func (c *Hub) handleClientRegistration() {
for {
select {
case conn := <-c.register:
c.addClient(conn)
case conn := <-c.unregister:
c.removeClient(conn)
}
}
}

func (c *Hub) addClient(conn *websocket.Conn) {
c.mu.Lock()
defer c.mu.Unlock()
if _, ok := c.clients[conn]; ok {
return
}
c.clients[conn] = make(map[string]bool)
}

func (c *Hub) removeClient(conn *websocket.Conn) {
c.mu.Lock()
defer c.mu.Unlock()
if _, ok := c.clients[conn]; ok {
for symbol := range c.clients[conn] {
delete(c.clients[conn], symbol)
}
delete(c.clients, conn)
}
if err := conn.WriteControl(
websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
time.Now().Add(time.Second),
); err != nil {
log.Error().Err(err).Msg("failed to send close message")
}
conn.Close()
}

func (c *Hub) initializeBroadcastChannels(collector *collector.Collector) {
for configId, stream := range collector.OutgoingStream {
symbol := c.configIdToSymbol(configId)
c.broadcast[symbol] = stream
}
}

func (c *Hub) configIdToSymbol(id int32) string {
for symbol, config := range c.configs {
if config.ID == id {
return symbol
}
}
return ""
}

func (c *Hub) broadcastDataForSymbol(symbol string) {
for data := range c.broadcast[symbol] {
c.castSubmissionData(&data, &symbol)
}
}

// pass by pointer to reduce memory copy time
func (c *Hub) castSubmissionData(data *dalcommon.OutgoingSubmissionData, symbol *string) {
c.mu.Lock()
defer c.mu.Unlock()
for conn := range c.clients {
if _, ok := c.clients[conn][*symbol]; ok {
if err := conn.WriteJSON(*data); err != nil {
log.Error().Err(err).Msg("failed to write message")
c.unregister <- conn
}
}
}
}
5 changes: 1 addition & 4 deletions node/pkg/dal/api/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"sync"

"bisonai.com/orakl/node/pkg/common/types"
"bisonai.com/orakl/node/pkg/dal/collector"
dalcommon "bisonai.com/orakl/node/pkg/dal/common"
"github.com/gofiber/contrib/websocket"
)
Expand All @@ -14,9 +13,7 @@ type Subscription struct {
Params []string `json:"params"`
}

type Controller struct {
Collector *collector.Collector

type Hub struct {
configs map[string]types.Config
clients map[*websocket.Conn]map[string]bool
register chan *websocket.Conn
Expand Down
Loading