diff --git a/node/pkg/dal/apiv2/controller.go b/node/pkg/dal/apiv2/controller.go index a5e1b96f5..cd69249e9 100644 --- a/node/pkg/dal/apiv2/controller.go +++ b/node/pkg/dal/apiv2/controller.go @@ -70,16 +70,22 @@ func NewServer(collector *collector.Collector, keyCache *keycache.KeyCache, hub collector: collector, keyCache: keyCache, hub: hub, - serveMux: http.NewServeMux(), } - s.serveMux.HandleFunc("/", s.HealthCheckHandler) - s.serveMux.HandleFunc("/ws", s.WSHandler) + serveMux := http.NewServeMux() - s.serveMux.HandleFunc("GET /symbols", s.SymbolsHandler) - s.serveMux.HandleFunc("GET /latest-data-feeds/all", s.AllLatestFeedsHandler) - s.serveMux.HandleFunc("GET /latest-data-feeds/transpose/all", s.AllLatestFeedsTransposedHandler) - s.serveMux.HandleFunc("GET /latest-data-feeds/transpose/{symbols}", s.TransposedLatestFeedsHandler) - s.serveMux.HandleFunc("GET /latest-data-feeds/{symbols}", s.LatestFeedsHandler) + serveMux.HandleFunc("/", s.HealthCheckHandler) + serveMux.HandleFunc("/ws", s.WSHandler) + + serveMux.HandleFunc("GET /symbols", s.SymbolsHandler) + serveMux.HandleFunc("GET /latest-data-feeds/all", s.AllLatestFeedsHandler) + serveMux.HandleFunc("GET /latest-data-feeds/transpose/all", s.AllLatestFeedsTransposedHandler) + serveMux.HandleFunc("GET /latest-data-feeds/transpose/{symbols}", s.TransposedLatestFeedsHandler) + serveMux.HandleFunc("GET /latest-data-feeds/{symbols}", s.LatestFeedsHandler) + + // Apply the RequestLoggerMiddleware to the ServeMux + loggedMux := stats.RequestLoggerMiddleware(serveMux) + + s.handler = loggedMux return s } @@ -96,7 +102,7 @@ func (s *ServerV2) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - s.serveMux.ServeHTTP(w, r) + s.handler.ServeHTTP(w, r) } func (s *ServerV2) checkAPIKey(ctx context.Context, key string) bool { diff --git a/node/pkg/dal/apiv2/types.go b/node/pkg/dal/apiv2/types.go index a772c506b..764b02214 100644 --- a/node/pkg/dal/apiv2/types.go +++ b/node/pkg/dal/apiv2/types.go @@ -21,7 +21,7 @@ type ServerV2 struct { collector *collector.Collector hub *hub.Hub keyCache *keycache.KeyCache - serveMux *http.ServeMux + handler http.Handler } type ServerV2Config struct { diff --git a/node/pkg/dal/utils/stats/stats.go b/node/pkg/dal/utils/stats/stats.go index 8116be806..c63359728 100644 --- a/node/pkg/dal/utils/stats/stats.go +++ b/node/pkg/dal/utils/stats/stats.go @@ -1,11 +1,15 @@ package stats import ( + "bufio" + "bytes" "context" + "errors" + "net" + "net/http" "time" "bisonai.com/orakl/node/pkg/db" - "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" ) @@ -67,35 +71,68 @@ func InsertWebsocketSubscriptions(ctx context.Context, connectionId int32, topic return db.BulkInsert(ctx, "websocket_subscriptions", []string{"connection_id", "topic"}, entries) } -func StatsMiddleware(c *fiber.Ctx) error { - start := time.Now() - if err := c.Next(); err != nil { - return err - } - duration := time.Since(start) +func RequestLoggerMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + sl := NewStatsLogger(w) + w.Header() + defer func() { + key := r.Header.Get("X-API-Key") + if key == "" { + log.Warn().Msg("X-API-Key header is empty") + return + } - if c.Path() == "/" { - return nil - } + endpoint := r.RequestURI + if endpoint == "/" { + return + } - headers := c.GetReqHeaders() - apiKeyRaw, ok := headers["X-Api-Key"] - if !ok { - log.Warn().Str("ip", c.IP()). - Str("method", c.Method()). - Str("path", c.Path()).Msg("X-Api-Key header not found") - return nil - } - apiKey := apiKeyRaw[0] - if apiKey == "" { - log.Warn().Msg("X-Api-Key header is empty") - return nil + statusCode := sl.statusCode + duration := time.Since(start) + if err := InsertRestCall(r.Context(), key, endpoint, *statusCode, duration); err != nil { + log.Error().Err(err).Msg("failed to insert rest call") + } + }() + next.ServeHTTP(sl, r) + }) +} + +type StatsLogger struct { + w *http.ResponseWriter + body *bytes.Buffer + statusCode *int +} + +func NewStatsLogger(w http.ResponseWriter) StatsLogger { + var buf bytes.Buffer + var statusCode int = 200 + return StatsLogger{ + w: &w, + body: &buf, + statusCode: &statusCode, } +} + +func (sl StatsLogger) Write(buf []byte) (int, error) { + sl.body.Write(buf) + return (*sl.w).Write(buf) +} - endpoint := c.Path() - statusCode := c.Response().StatusCode() - if err := InsertRestCall(c.Context(), apiKey, endpoint, statusCode, duration); err != nil { - log.Error().Err(err).Msg("failed to insert rest call") +func (sl StatsLogger) Header() http.Header { + return (*sl.w).Header() + +} + +func (sl StatsLogger) WriteHeader(statusCode int) { + (*sl.statusCode) = statusCode + (*sl.w).WriteHeader(statusCode) +} + +func (sl StatsLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) { + h, ok := (*sl.w).(http.Hijacker) + if !ok { + return nil, nil, errors.New("hijack not supported") } - return nil + return h.Hijack() }