Skip to content

Commit

Permalink
(DAL) Bulk insert user stats (#2145)
Browse files Browse the repository at this point in the history
* feat: implement stats middleware

* feat: bulk

* fix: add missing thread

* feat: statsApp implementation
  • Loading branch information
nick-bisonai authored Aug 21, 2024
1 parent c31e079 commit 2c35ed0
Show file tree
Hide file tree
Showing 5 changed files with 205 additions and 44 deletions.
28 changes: 17 additions & 11 deletions node/pkg/dal/apiv2/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func Start(ctx context.Context, opts ...ServerV2Option) error {
return err
}

wsServer := NewServer(config.Collector, config.KeyCache, config.Hub)
wsServer := NewServer(config.Collector, config.KeyCache, config.Hub, config.StatsApp)
httpServer := &http.Server{
Handler: wsServer,
BaseContext: func(_ net.Listener) context.Context {
Expand All @@ -65,21 +65,27 @@ func Start(ctx context.Context, opts ...ServerV2Option) error {
return nil
}

func NewServer(collector *collector.Collector, keyCache *keycache.KeyCache, hub *hub.Hub) *ServerV2 {
func NewServer(collector *collector.Collector, keyCache *keycache.KeyCache, hub *hub.Hub, statsApp *stats.StatsApp) *ServerV2 {
s := &ServerV2{
collector: collector,
keyCache: keyCache,
hub: hub,
serveMux: http.NewServeMux(),
}
s.serveMux.HandleFunc("/", s.HealthCheckHandler)
s.serveMux.HandleFunc("/ws", s.WSHandler)
serveMux := http.NewServeMux()

s.serveMux.HandleFunc("GET /symbols", s.SymbolsHandler)
s.serveMux.HandleFunc("GET /latest-data-feeds/all", s.AllLatestFeedsHandler)
s.serveMux.HandleFunc("GET /latest-data-feeds/transpose/all", s.AllLatestFeedsTransposedHandler)
s.serveMux.HandleFunc("GET /latest-data-feeds/transpose/{symbols}", s.TransposedLatestFeedsHandler)
s.serveMux.HandleFunc("GET /latest-data-feeds/{symbols}", s.LatestFeedsHandler)
serveMux.HandleFunc("/", s.HealthCheckHandler)
serveMux.HandleFunc("/ws", s.WSHandler)

serveMux.HandleFunc("GET /symbols", s.SymbolsHandler)
serveMux.HandleFunc("GET /latest-data-feeds/all", s.AllLatestFeedsHandler)
serveMux.HandleFunc("GET /latest-data-feeds/transpose/all", s.AllLatestFeedsTransposedHandler)
serveMux.HandleFunc("GET /latest-data-feeds/transpose/{symbols}", s.TransposedLatestFeedsHandler)
serveMux.HandleFunc("GET /latest-data-feeds/{symbols}", s.LatestFeedsHandler)

// Apply the RequestLoggerMiddleware to the ServeMux
loggedMux := statsApp.RequestLoggerMiddleware(serveMux)

s.handler = loggedMux

return s
}
Expand All @@ -96,7 +102,7 @@ func (s *ServerV2) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}

s.serveMux.ServeHTTP(w, r)
s.handler.ServeHTTP(w, r)
}

func (s *ServerV2) checkAPIKey(ctx context.Context, key string) bool {
Expand Down
10 changes: 9 additions & 1 deletion node/pkg/dal/apiv2/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"bisonai.com/orakl/node/pkg/dal/collector"
"bisonai.com/orakl/node/pkg/dal/hub"
"bisonai.com/orakl/node/pkg/dal/utils/keycache"
"bisonai.com/orakl/node/pkg/dal/utils/stats"
)

type BulkResponse struct {
Expand All @@ -21,14 +22,15 @@ type ServerV2 struct {
collector *collector.Collector
hub *hub.Hub
keyCache *keycache.KeyCache
serveMux *http.ServeMux
handler http.Handler
}

type ServerV2Config struct {
Port string
Collector *collector.Collector
Hub *hub.Hub
KeyCache *keycache.KeyCache
StatsApp *stats.StatsApp
}

type ServerV2Option func(*ServerV2Config)
Expand All @@ -51,6 +53,12 @@ func WithHub(h *hub.Hub) ServerV2Option {
}
}

func WithStatsApp(s *stats.StatsApp) ServerV2Option {
return func(config *ServerV2Config) {
config.StatsApp = s
}
}

func WithKeyCache(k *keycache.KeyCache) ServerV2Option {
return func(config *ServerV2Config) {
config.KeyCache = k
Expand Down
6 changes: 5 additions & 1 deletion node/pkg/dal/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"bisonai.com/orakl/node/pkg/dal/collector"
"bisonai.com/orakl/node/pkg/dal/hub"
"bisonai.com/orakl/node/pkg/dal/utils/keycache"
"bisonai.com/orakl/node/pkg/dal/utils/stats"
"bisonai.com/orakl/node/pkg/utils/request"

"github.com/rs/zerolog/log"
Expand All @@ -21,6 +22,9 @@ type Config = types.Config
func Run(ctx context.Context) error {
log.Debug().Msg("Starting DAL API server")

statsApp := stats.Start(ctx)
defer statsApp.Stop()

keyCache := keycache.NewAPIKeyCache(1 * time.Hour)
keyCache.CleanupLoop(10 * time.Minute)

Expand All @@ -45,7 +49,7 @@ func Run(ctx context.Context) error {
hub := hub.HubSetup(ctx, configs)
go hub.Start(ctx, collector)

err = apiv2.Start(ctx, apiv2.WithCollector(collector), apiv2.WithHub(hub), apiv2.WithKeyCache(keyCache))
err = apiv2.Start(ctx, apiv2.WithCollector(collector), apiv2.WithHub(hub), apiv2.WithKeyCache(keyCache), apiv2.WithStatsApp(statsApp))
if err != nil {
log.Error().Err(err).Msg("Failed to start DAL WS server")
return err
Expand Down
9 changes: 8 additions & 1 deletion node/pkg/dal/tests/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"bisonai.com/orakl/node/pkg/dal/collector"
"bisonai.com/orakl/node/pkg/dal/hub"
"bisonai.com/orakl/node/pkg/dal/utils/keycache"
"bisonai.com/orakl/node/pkg/dal/utils/stats"
"bisonai.com/orakl/node/pkg/db"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
Expand All @@ -32,6 +33,7 @@ type TestItems struct {
MockAdmin *httptest.Server
MockDal *httptest.Server
ApiKey string
StatsApp *stats.StatsApp
}

func testPublishData(ctx context.Context, submissionData aggregator.SubmissionData) error {
Expand Down Expand Up @@ -113,7 +115,10 @@ func setup(ctx context.Context) (func() error, *TestItems, error) {
hub := hub.HubSetup(ctx, configs)
go hub.Start(ctx, collector)

server := apiv2.NewServer(collector, keyCache, hub)
statsApp := stats.NewStatsApp(ctx, stats.WithBulkLogsCopyInterval(1*time.Second))
go statsApp.Run(ctx)

server := apiv2.NewServer(collector, keyCache, hub, statsApp)

mockDal := httptest.NewServer(server)

Expand All @@ -122,13 +127,15 @@ func setup(ctx context.Context) (func() error, *TestItems, error) {
testItems.Controller = hub
testItems.MockAdmin = mockAdminServer
testItems.MockDal = mockDal
testItems.StatsApp = statsApp

return cleanup(ctx, testItems), testItems, nil
}

func cleanup(ctx context.Context, testItems *TestItems) func() error {
return func() error {
testItems.MockDal.Close()
testItems.StatsApp.Stop()

testItems.Collector.Stop()
testItems.Controller = nil
Expand Down
196 changes: 166 additions & 30 deletions node/pkg/dal/utils/stats/stats.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package stats

import (
"bufio"
"bytes"
"context"
"errors"
"net"
"net/http"
"time"

"bisonai.com/orakl/node/pkg/db"
"github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log"
)

Expand All @@ -27,12 +31,138 @@ const (
`
)

type websocketId struct {
const (
DefaultBulkLogsCopyInterval = 10 * time.Minute
DefaultBufferSize = 20000
)

type StatsAppConfig struct {
BulkLogsCopyInterval time.Duration
BufferSize int
}

type StatsOption func(*StatsAppConfig)

func WithBulkLogsCopyInterval(interval time.Duration) StatsOption {
return func(config *StatsAppConfig) {
config.BulkLogsCopyInterval = interval
}
}

func WithBufferSize(size int) StatsOption {
return func(config *StatsAppConfig) {
config.BufferSize = size
}
}

type StatsApp struct {
BulkLogsCopyInterval time.Duration
RestEntryBuffer chan *RestEntry
Cancel context.CancelFunc
}

type WebsocketId struct {
Id int32 `db:"id"`
}

func InsertRestCall(ctx context.Context, apiKey string, endpoint string, statusCode int, responseTime time.Duration) error {
type RestEntry struct {
ApiKey string
Endpoint string
StatusCode int
ResponseTime time.Duration
}

func NewStatsApp(ctx context.Context, opts ...StatsOption) *StatsApp {
_, cancel := context.WithCancel(ctx)

config := &StatsAppConfig{
BulkLogsCopyInterval: DefaultBulkLogsCopyInterval,
BufferSize: DefaultBufferSize,
}

for _, opt := range opts {
opt(config)
}

return &StatsApp{
BulkLogsCopyInterval: config.BulkLogsCopyInterval,
RestEntryBuffer: make(chan *RestEntry, config.BufferSize),
Cancel: cancel,
}
}

func Start(ctx context.Context) *StatsApp {
app := NewStatsApp(ctx)
go app.Run(ctx)
return app
}

func (a *StatsApp) Stop() {
a.Cancel()
}

func (a *StatsApp) Run(ctx context.Context) {
ticker := time.NewTicker(a.BulkLogsCopyInterval)
defer ticker.Stop()

for {
select {
case <-ctx.Done():
return
case <-ticker.C:
bulkCopyEntries := [][]any{}
loop:
for {
select {
case entry := <-a.RestEntryBuffer:
bulkCopyEntries = append(bulkCopyEntries, []any{entry.ApiKey, entry.Endpoint, entry.StatusCode, entry.ResponseTime.Microseconds()})
default:
break loop
}
}

if len(bulkCopyEntries) > 0 {
_, err := db.BulkCopy(ctx, "rest_calls", []string{"api_key", "endpoint", "status_code", "response_time"}, bulkCopyEntries)
if err != nil {
log.Error().Err(err).Msg("failed to bulk copy rest calls")
}
}
}
}
}

func (a *StatsApp) RequestLoggerMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
sl := NewStatsLogger(w)
w.Header()
defer func() {
key := r.Header.Get("X-API-Key")
if key == "" {
log.Warn().Msg("X-API-Key header is empty")
return
}

endpoint := r.RequestURI
if endpoint == "/" {
return
}

statusCode := sl.statusCode
responseTime := time.Since(start)

a.RestEntryBuffer <- &RestEntry{
ApiKey: key,
Endpoint: endpoint,
StatusCode: *statusCode,
ResponseTime: responseTime,
}
}()
next.ServeHTTP(sl, r)
})
}

func InsertRestCall(ctx context.Context, apiKey string, endpoint string, statusCode int, responseTime time.Duration) error {
responseTimeMicro := responseTime.Microseconds()
return db.QueryWithoutResult(ctx, INSERT_REST_CALLS, map[string]any{
"api_key": apiKey,
Expand All @@ -43,7 +173,7 @@ func InsertRestCall(ctx context.Context, apiKey string, endpoint string, statusC
}

func InsertWebsocketConnection(ctx context.Context, apiKey string) (int32, error) {
result, err := db.QueryRow[websocketId](ctx, INSERT_WEBSOCKET_CONNECTIONS, map[string]any{
result, err := db.QueryRow[WebsocketId](ctx, INSERT_WEBSOCKET_CONNECTIONS, map[string]any{
"api_key": apiKey,
})
if err != nil {
Expand All @@ -67,35 +197,41 @@ func InsertWebsocketSubscriptions(ctx context.Context, connectionId int32, topic
return db.BulkInsert(ctx, "websocket_subscriptions", []string{"connection_id", "topic"}, entries)
}

func StatsMiddleware(c *fiber.Ctx) error {
start := time.Now()
if err := c.Next(); err != nil {
return err
}
duration := time.Since(start)
type StatsLogger struct {
w *http.ResponseWriter
body *bytes.Buffer
statusCode *int
}

if c.Path() == "/" {
return nil
func NewStatsLogger(w http.ResponseWriter) StatsLogger {
var buf bytes.Buffer
var statusCode int = 200
return StatsLogger{
w: &w,
body: &buf,
statusCode: &statusCode,
}
}

headers := c.GetReqHeaders()
apiKeyRaw, ok := headers["X-Api-Key"]
if !ok {
log.Warn().Str("ip", c.IP()).
Str("method", c.Method()).
Str("path", c.Path()).Msg("X-Api-Key header not found")
return nil
}
apiKey := apiKeyRaw[0]
if apiKey == "" {
log.Warn().Msg("X-Api-Key header is empty")
return nil
}
func (sl StatsLogger) Write(buf []byte) (int, error) {
sl.body.Write(buf)
return (*sl.w).Write(buf)
}

endpoint := c.Path()
statusCode := c.Response().StatusCode()
if err := InsertRestCall(c.Context(), apiKey, endpoint, statusCode, duration); err != nil {
log.Error().Err(err).Msg("failed to insert rest call")
func (sl StatsLogger) Header() http.Header {
return (*sl.w).Header()

}

func (sl StatsLogger) WriteHeader(statusCode int) {
(*sl.statusCode) = statusCode
(*sl.w).WriteHeader(statusCode)
}

func (sl StatsLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) {
h, ok := (*sl.w).(http.Hijacker)
if !ok {
return nil, nil, errors.New("hijack not supported")
}
return nil
return h.Hijack()
}

0 comments on commit 2c35ed0

Please sign in to comment.