From 2c35ed00ef2d519c91a24505159e827f549fe082 Mon Sep 17 00:00:00 2001 From: Nick <148735107+nick-bisonai@users.noreply.github.com> Date: Wed, 21 Aug 2024 12:11:50 +0900 Subject: [PATCH] (DAL) Bulk insert user stats (#2145) * feat: implement stats middleware * feat: bulk * fix: add missing thread * feat: statsApp implementation --- node/pkg/dal/apiv2/controller.go | 28 +++-- node/pkg/dal/apiv2/types.go | 10 +- node/pkg/dal/app.go | 6 +- node/pkg/dal/tests/main_test.go | 9 +- node/pkg/dal/utils/stats/stats.go | 196 +++++++++++++++++++++++++----- 5 files changed, 205 insertions(+), 44 deletions(-) diff --git a/node/pkg/dal/apiv2/controller.go b/node/pkg/dal/apiv2/controller.go index a5e1b96f5..b011ccf2f 100644 --- a/node/pkg/dal/apiv2/controller.go +++ b/node/pkg/dal/apiv2/controller.go @@ -49,7 +49,7 @@ func Start(ctx context.Context, opts ...ServerV2Option) error { return err } - wsServer := NewServer(config.Collector, config.KeyCache, config.Hub) + wsServer := NewServer(config.Collector, config.KeyCache, config.Hub, config.StatsApp) httpServer := &http.Server{ Handler: wsServer, BaseContext: func(_ net.Listener) context.Context { @@ -65,21 +65,27 @@ func Start(ctx context.Context, opts ...ServerV2Option) error { return nil } -func NewServer(collector *collector.Collector, keyCache *keycache.KeyCache, hub *hub.Hub) *ServerV2 { +func NewServer(collector *collector.Collector, keyCache *keycache.KeyCache, hub *hub.Hub, statsApp *stats.StatsApp) *ServerV2 { s := &ServerV2{ collector: collector, keyCache: keyCache, hub: hub, - serveMux: http.NewServeMux(), } - s.serveMux.HandleFunc("/", s.HealthCheckHandler) - s.serveMux.HandleFunc("/ws", s.WSHandler) + serveMux := http.NewServeMux() - s.serveMux.HandleFunc("GET /symbols", s.SymbolsHandler) - s.serveMux.HandleFunc("GET /latest-data-feeds/all", s.AllLatestFeedsHandler) - s.serveMux.HandleFunc("GET /latest-data-feeds/transpose/all", s.AllLatestFeedsTransposedHandler) - s.serveMux.HandleFunc("GET /latest-data-feeds/transpose/{symbols}", s.TransposedLatestFeedsHandler) - s.serveMux.HandleFunc("GET /latest-data-feeds/{symbols}", s.LatestFeedsHandler) + serveMux.HandleFunc("/", s.HealthCheckHandler) + serveMux.HandleFunc("/ws", s.WSHandler) + + serveMux.HandleFunc("GET /symbols", s.SymbolsHandler) + serveMux.HandleFunc("GET /latest-data-feeds/all", s.AllLatestFeedsHandler) + serveMux.HandleFunc("GET /latest-data-feeds/transpose/all", s.AllLatestFeedsTransposedHandler) + serveMux.HandleFunc("GET /latest-data-feeds/transpose/{symbols}", s.TransposedLatestFeedsHandler) + serveMux.HandleFunc("GET /latest-data-feeds/{symbols}", s.LatestFeedsHandler) + + // Apply the RequestLoggerMiddleware to the ServeMux + loggedMux := statsApp.RequestLoggerMiddleware(serveMux) + + s.handler = loggedMux return s } @@ -96,7 +102,7 @@ func (s *ServerV2) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - s.serveMux.ServeHTTP(w, r) + s.handler.ServeHTTP(w, r) } func (s *ServerV2) checkAPIKey(ctx context.Context, key string) bool { diff --git a/node/pkg/dal/apiv2/types.go b/node/pkg/dal/apiv2/types.go index a772c506b..2d2ced35f 100644 --- a/node/pkg/dal/apiv2/types.go +++ b/node/pkg/dal/apiv2/types.go @@ -6,6 +6,7 @@ import ( "bisonai.com/orakl/node/pkg/dal/collector" "bisonai.com/orakl/node/pkg/dal/hub" "bisonai.com/orakl/node/pkg/dal/utils/keycache" + "bisonai.com/orakl/node/pkg/dal/utils/stats" ) type BulkResponse struct { @@ -21,7 +22,7 @@ type ServerV2 struct { collector *collector.Collector hub *hub.Hub keyCache *keycache.KeyCache - serveMux *http.ServeMux + handler http.Handler } type ServerV2Config struct { @@ -29,6 +30,7 @@ type ServerV2Config struct { Collector *collector.Collector Hub *hub.Hub KeyCache *keycache.KeyCache + StatsApp *stats.StatsApp } type ServerV2Option func(*ServerV2Config) @@ -51,6 +53,12 @@ func WithHub(h *hub.Hub) ServerV2Option { } } +func WithStatsApp(s *stats.StatsApp) ServerV2Option { + return func(config *ServerV2Config) { + config.StatsApp = s + } +} + func WithKeyCache(k *keycache.KeyCache) ServerV2Option { return func(config *ServerV2Config) { config.KeyCache = k diff --git a/node/pkg/dal/app.go b/node/pkg/dal/app.go index 6e18e9dbd..7d18eeb04 100644 --- a/node/pkg/dal/app.go +++ b/node/pkg/dal/app.go @@ -11,6 +11,7 @@ import ( "bisonai.com/orakl/node/pkg/dal/collector" "bisonai.com/orakl/node/pkg/dal/hub" "bisonai.com/orakl/node/pkg/dal/utils/keycache" + "bisonai.com/orakl/node/pkg/dal/utils/stats" "bisonai.com/orakl/node/pkg/utils/request" "github.com/rs/zerolog/log" @@ -21,6 +22,9 @@ type Config = types.Config func Run(ctx context.Context) error { log.Debug().Msg("Starting DAL API server") + statsApp := stats.Start(ctx) + defer statsApp.Stop() + keyCache := keycache.NewAPIKeyCache(1 * time.Hour) keyCache.CleanupLoop(10 * time.Minute) @@ -45,7 +49,7 @@ func Run(ctx context.Context) error { hub := hub.HubSetup(ctx, configs) go hub.Start(ctx, collector) - err = apiv2.Start(ctx, apiv2.WithCollector(collector), apiv2.WithHub(hub), apiv2.WithKeyCache(keyCache)) + err = apiv2.Start(ctx, apiv2.WithCollector(collector), apiv2.WithHub(hub), apiv2.WithKeyCache(keyCache), apiv2.WithStatsApp(statsApp)) if err != nil { log.Error().Err(err).Msg("Failed to start DAL WS server") return err diff --git a/node/pkg/dal/tests/main_test.go b/node/pkg/dal/tests/main_test.go index 04f4ffa73..b748a891e 100644 --- a/node/pkg/dal/tests/main_test.go +++ b/node/pkg/dal/tests/main_test.go @@ -17,6 +17,7 @@ import ( "bisonai.com/orakl/node/pkg/dal/collector" "bisonai.com/orakl/node/pkg/dal/hub" "bisonai.com/orakl/node/pkg/dal/utils/keycache" + "bisonai.com/orakl/node/pkg/dal/utils/stats" "bisonai.com/orakl/node/pkg/db" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -32,6 +33,7 @@ type TestItems struct { MockAdmin *httptest.Server MockDal *httptest.Server ApiKey string + StatsApp *stats.StatsApp } func testPublishData(ctx context.Context, submissionData aggregator.SubmissionData) error { @@ -113,7 +115,10 @@ func setup(ctx context.Context) (func() error, *TestItems, error) { hub := hub.HubSetup(ctx, configs) go hub.Start(ctx, collector) - server := apiv2.NewServer(collector, keyCache, hub) + statsApp := stats.NewStatsApp(ctx, stats.WithBulkLogsCopyInterval(1*time.Second)) + go statsApp.Run(ctx) + + server := apiv2.NewServer(collector, keyCache, hub, statsApp) mockDal := httptest.NewServer(server) @@ -122,6 +127,7 @@ func setup(ctx context.Context) (func() error, *TestItems, error) { testItems.Controller = hub testItems.MockAdmin = mockAdminServer testItems.MockDal = mockDal + testItems.StatsApp = statsApp return cleanup(ctx, testItems), testItems, nil } @@ -129,6 +135,7 @@ func setup(ctx context.Context) (func() error, *TestItems, error) { func cleanup(ctx context.Context, testItems *TestItems) func() error { return func() error { testItems.MockDal.Close() + testItems.StatsApp.Stop() testItems.Collector.Stop() testItems.Controller = nil diff --git a/node/pkg/dal/utils/stats/stats.go b/node/pkg/dal/utils/stats/stats.go index 8116be806..413a6cf11 100644 --- a/node/pkg/dal/utils/stats/stats.go +++ b/node/pkg/dal/utils/stats/stats.go @@ -1,11 +1,15 @@ package stats import ( + "bufio" + "bytes" "context" + "errors" + "net" + "net/http" "time" "bisonai.com/orakl/node/pkg/db" - "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" ) @@ -27,12 +31,138 @@ const ( ` ) -type websocketId struct { +const ( + DefaultBulkLogsCopyInterval = 10 * time.Minute + DefaultBufferSize = 20000 +) + +type StatsAppConfig struct { + BulkLogsCopyInterval time.Duration + BufferSize int +} + +type StatsOption func(*StatsAppConfig) + +func WithBulkLogsCopyInterval(interval time.Duration) StatsOption { + return func(config *StatsAppConfig) { + config.BulkLogsCopyInterval = interval + } +} + +func WithBufferSize(size int) StatsOption { + return func(config *StatsAppConfig) { + config.BufferSize = size + } +} + +type StatsApp struct { + BulkLogsCopyInterval time.Duration + RestEntryBuffer chan *RestEntry + Cancel context.CancelFunc +} + +type WebsocketId struct { Id int32 `db:"id"` } -func InsertRestCall(ctx context.Context, apiKey string, endpoint string, statusCode int, responseTime time.Duration) error { +type RestEntry struct { + ApiKey string + Endpoint string + StatusCode int + ResponseTime time.Duration +} + +func NewStatsApp(ctx context.Context, opts ...StatsOption) *StatsApp { + _, cancel := context.WithCancel(ctx) + + config := &StatsAppConfig{ + BulkLogsCopyInterval: DefaultBulkLogsCopyInterval, + BufferSize: DefaultBufferSize, + } + + for _, opt := range opts { + opt(config) + } + return &StatsApp{ + BulkLogsCopyInterval: config.BulkLogsCopyInterval, + RestEntryBuffer: make(chan *RestEntry, config.BufferSize), + Cancel: cancel, + } +} + +func Start(ctx context.Context) *StatsApp { + app := NewStatsApp(ctx) + go app.Run(ctx) + return app +} + +func (a *StatsApp) Stop() { + a.Cancel() +} + +func (a *StatsApp) Run(ctx context.Context) { + ticker := time.NewTicker(a.BulkLogsCopyInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + bulkCopyEntries := [][]any{} + loop: + for { + select { + case entry := <-a.RestEntryBuffer: + bulkCopyEntries = append(bulkCopyEntries, []any{entry.ApiKey, entry.Endpoint, entry.StatusCode, entry.ResponseTime.Microseconds()}) + default: + break loop + } + } + + if len(bulkCopyEntries) > 0 { + _, err := db.BulkCopy(ctx, "rest_calls", []string{"api_key", "endpoint", "status_code", "response_time"}, bulkCopyEntries) + if err != nil { + log.Error().Err(err).Msg("failed to bulk copy rest calls") + } + } + } + } +} + +func (a *StatsApp) RequestLoggerMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + sl := NewStatsLogger(w) + w.Header() + defer func() { + key := r.Header.Get("X-API-Key") + if key == "" { + log.Warn().Msg("X-API-Key header is empty") + return + } + + endpoint := r.RequestURI + if endpoint == "/" { + return + } + + statusCode := sl.statusCode + responseTime := time.Since(start) + + a.RestEntryBuffer <- &RestEntry{ + ApiKey: key, + Endpoint: endpoint, + StatusCode: *statusCode, + ResponseTime: responseTime, + } + }() + next.ServeHTTP(sl, r) + }) +} + +func InsertRestCall(ctx context.Context, apiKey string, endpoint string, statusCode int, responseTime time.Duration) error { responseTimeMicro := responseTime.Microseconds() return db.QueryWithoutResult(ctx, INSERT_REST_CALLS, map[string]any{ "api_key": apiKey, @@ -43,7 +173,7 @@ func InsertRestCall(ctx context.Context, apiKey string, endpoint string, statusC } func InsertWebsocketConnection(ctx context.Context, apiKey string) (int32, error) { - result, err := db.QueryRow[websocketId](ctx, INSERT_WEBSOCKET_CONNECTIONS, map[string]any{ + result, err := db.QueryRow[WebsocketId](ctx, INSERT_WEBSOCKET_CONNECTIONS, map[string]any{ "api_key": apiKey, }) if err != nil { @@ -67,35 +197,41 @@ func InsertWebsocketSubscriptions(ctx context.Context, connectionId int32, topic return db.BulkInsert(ctx, "websocket_subscriptions", []string{"connection_id", "topic"}, entries) } -func StatsMiddleware(c *fiber.Ctx) error { - start := time.Now() - if err := c.Next(); err != nil { - return err - } - duration := time.Since(start) +type StatsLogger struct { + w *http.ResponseWriter + body *bytes.Buffer + statusCode *int +} - if c.Path() == "/" { - return nil +func NewStatsLogger(w http.ResponseWriter) StatsLogger { + var buf bytes.Buffer + var statusCode int = 200 + return StatsLogger{ + w: &w, + body: &buf, + statusCode: &statusCode, } +} - headers := c.GetReqHeaders() - apiKeyRaw, ok := headers["X-Api-Key"] - if !ok { - log.Warn().Str("ip", c.IP()). - Str("method", c.Method()). - Str("path", c.Path()).Msg("X-Api-Key header not found") - return nil - } - apiKey := apiKeyRaw[0] - if apiKey == "" { - log.Warn().Msg("X-Api-Key header is empty") - return nil - } +func (sl StatsLogger) Write(buf []byte) (int, error) { + sl.body.Write(buf) + return (*sl.w).Write(buf) +} - endpoint := c.Path() - statusCode := c.Response().StatusCode() - if err := InsertRestCall(c.Context(), apiKey, endpoint, statusCode, duration); err != nil { - log.Error().Err(err).Msg("failed to insert rest call") +func (sl StatsLogger) Header() http.Header { + return (*sl.w).Header() + +} + +func (sl StatsLogger) WriteHeader(statusCode int) { + (*sl.statusCode) = statusCode + (*sl.w).WriteHeader(statusCode) +} + +func (sl StatsLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) { + h, ok := (*sl.w).(http.Hijacker) + if !ok { + return nil, nil, errors.New("hijack not supported") } - return nil + return h.Hijack() }