Skip to content

Commit

Permalink
automod: extracting some packages, and concurrency safety (#464)
Browse files Browse the repository at this point in the history
A few minor hopefully-advancements to the automod/hepa packages:

- `countstore` is now extracted as a package.
- Synchronization primitives are added to MemCountStore. This fixes
crashes when running `hepa` without the use of Redis.
- Additional testing, to make sure that MemCountStore gets enough
exercise that the race detector has something to work on in tests.
- Attempted to sneak in some docs along the way!
  • Loading branch information
warpfork authored Dec 10, 2023
2 parents 1c05892 + ab96044 commit c54aff4
Show file tree
Hide file tree
Showing 17 changed files with 265 additions and 156 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ test-coverage.out
/sonar-cli
/stress
/supercollider
/hepa

# Don't ignore this file itself, or other specific dotfiles
!.gitignore
Expand Down
93 changes: 0 additions & 93 deletions automod/countstore.go

This file was deleted.

66 changes: 66 additions & 0 deletions automod/countstore/countstore.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package countstore

import (
"context"
"fmt"
"log/slog"
"time"
)

// Period identifiers accepted by the CountStore "GetCount*" methods.
// They select which time bucket a count is read from; see periodBucket
// for how each one is turned into a storage key.
const (
// PeriodTotal selects the all-time bucket (the key carries no time suffix).
PeriodTotal = "total"
// PeriodDay selects the bucket for the current UTC calendar day.
PeriodDay = "day"
// PeriodHour selects the bucket for the current UTC hour.
PeriodHour = "hour"
)

// CountStore is an interface for storing incrementing event counts, bucketed into periods.
// It is implemented by MemCountStore and by RedisCountStore.
//
// Period bucketing works on the basis of the current date (as determined mid-call).
// See the `Period*` consts for the available period types.
//
// The "GetCount" and "Increment" methods perform actual counting.
// The "*Distinct" methods have a different behavior:
// "IncrementDistinct" marks a value as seen at least once,
// and "GetCountDistinct" asks how _many_ values have been seen at least once.
//
// Incrementing -- both the "Increment" and "IncrementDistinct" variants -- increases
// a count in each supported period bucket size.
// In other words, one call to CountStore.Increment causes three increments internally:
// one to the count for the hour, one to the count for the day, and one to the all-time count.
//
// The exact implementation and precision of the "*Distinct" methods may vary:
// in the MemCountStore implementation, it is precise (it's based on large maps);
// in the RedisCountStore implementation, it uses the Redis "pfcount" feature,
// which is based on a HyperLogLog data structure which has probabilistic properties
// (see https://redis.io/commands/pfcount/ ).
//
// Memory growth and availability of information over time also varies by implementation.
// The RedisCountStore implementation uses Redis's key expiration primitives;
// only the all-time counts go without expiration.
// The MemCountStore grows without bound (it's intended to be used in testing
// and other non-production operations).
type CountStore interface {
// GetCount returns the current count for the (name, val) pair in the
// bucket selected by period (one of the `Period*` consts).
GetCount(ctx context.Context, name, val, period string) (int, error)
// Increment adds one to the (name, val) count in every period bucket.
Increment(ctx context.Context, name, val string) error
// TODO: batch increment method
// GetCountDistinct returns how many distinct values have been recorded
// for the (name, bucket) pair in the bucket selected by period.
GetCountDistinct(ctx context.Context, name, bucket, period string) (int, error)
// IncrementDistinct records val as seen for the (name, bucket) pair in
// every period bucket; recording the same val again does not raise the count.
IncrementDistinct(ctx context.Context, name, bucket, val string) error
}

func periodBucket(name, val, period string) string {
switch period {
case PeriodTotal:
return fmt.Sprintf("%s/%s", name, val)
case PeriodDay:
t := time.Now().UTC().Format(time.DateOnly)
return fmt.Sprintf("%s/%s/%s", name, val, t)
case PeriodHour:
t := time.Now().UTC().Format(time.RFC3339)[0:13]
return fmt.Sprintf("%s/%s/%s", name, val, t)
default:
slog.Warn("unhandled counter period", "period", period)
return fmt.Sprintf("%s/%s", name, val)
}
}
64 changes: 64 additions & 0 deletions automod/countstore/countstore_mem.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package countstore

import (
"context"

"github.com/puzpuzpuz/xsync/v3"
)

// MemCountStore is an in-memory store of event counts, held in xsync maps.
// Both fields are pointers, so copies of a MemCountStore value share the
// same underlying maps. Use NewMemCountStore to obtain an initialized value.
type MemCountStore struct {
// Counts is keyed by a string that is a munge of "{name}/{val}[/{period}]",
// where period is either absent (meaning all-time total)
// or a string describing that time period (either "YYYY-MM-DD" or that plus a literal "T" and "HH").
//
// (Using values for `name` and `val` with slashes in them is perhaps inadvisable, as it may be ambiguous.)
Counts *xsync.MapOf[string, int]
// DistinctCounts maps a bucket key (built by the same periodBucket helper
// as the Counts keys) to the set of values seen in that bucket. The inner
// map is used as a set: IncrementDistinct only ever stores `true`, and the
// distinct count is the inner map's size.
DistinctCounts *xsync.MapOf[string, *xsync.MapOf[string, bool]]
}

// NewMemCountStore returns a MemCountStore whose plain and distinct count
// maps are both allocated and empty.
func NewMemCountStore() MemCountStore {
	var store MemCountStore
	store.Counts = xsync.NewMapOf[string, int]()
	store.DistinctCounts = xsync.NewMapOf[string, *xsync.MapOf[string, bool]]()
	return store
}

// GetCount returns the stored count for (name, val) in the bucket selected
// by period. A bucket that has never been incremented reads as 0.
// The returned error is always nil.
func (s MemCountStore) GetCount(ctx context.Context, name, val, period string) (int, error) {
	key := periodBucket(name, val, period)
	if count, found := s.Counts.Load(key); found {
		return count, nil
	}
	return 0, nil
}

// Increment adds one to the (name, val) counter in each of the total, day
// and hour buckets. Each bucket is updated through the map's Compute call.
// The returned error is always nil.
func (s MemCountStore) Increment(ctx context.Context, name, val string) error {
	// The second result is Compute's "delete" flag; false keeps the entry.
	addOne := func(current int, _ bool) (int, bool) {
		return current + 1, false
	}

	periods := [...]string{PeriodTotal, PeriodDay, PeriodHour}
	for i := range periods {
		s.Counts.Compute(periodBucket(name, val, periods[i]), addOne)
	}
	return nil
}

// GetCountDistinct returns how many distinct values have been recorded for
// (name, bucket) in the bucket selected by period. A bucket with no recorded
// values reads as 0. The returned error is always nil.
func (s MemCountStore) GetCountDistinct(ctx context.Context, name, bucket, period string) (int, error) {
	key := periodBucket(name, bucket, period)
	if seen, found := s.DistinctCounts.Load(key); found {
		return seen.Size(), nil
	}
	return 0, nil
}

// IncrementDistinct records val as seen for (name, bucket) in each of the
// total, day and hour buckets. Recording a value that is already present
// leaves the distinct count unchanged. The returned error is always nil.
func (s MemCountStore) IncrementDistinct(ctx context.Context, name, bucket, val string) error {
	// markSeen runs inside Compute: it creates the per-bucket set on first
	// use, adds val to it, and keeps the entry (delete flag false).
	markSeen := func(seen *xsync.MapOf[string, bool], _ bool) (*xsync.MapOf[string, bool], bool) {
		if seen == nil {
			seen = xsync.NewMapOf[string, bool]()
		}
		seen.Store(val, true)
		return seen, false
	}

	periods := [...]string{PeriodTotal, PeriodDay, PeriodHour}
	for i := range periods {
		s.DistinctCounts.Compute(periodBucket(name, bucket, periods[i]), markSeen)
	}
	return nil
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package automod
package countstore

import (
"context"
Expand Down Expand Up @@ -32,7 +32,7 @@ func NewRedisCountStore(redisURL string) (*RedisCountStore, error) {
}

func (s *RedisCountStore) GetCount(ctx context.Context, name, val, period string) (int, error) {
key := redisCountPrefix + PeriodBucket(name, val, period)
key := redisCountPrefix + periodBucket(name, val, period)
c, err := s.Client.Get(ctx, key).Int()
if err == redis.Nil {
return 0, nil
Expand All @@ -49,15 +49,15 @@ func (s *RedisCountStore) Increment(ctx context.Context, name, val string) error
// increment multiple counters in a single redis round-trip
multi := s.Client.Pipeline()

key = redisCountPrefix + PeriodBucket(name, val, PeriodHour)
key = redisCountPrefix + periodBucket(name, val, PeriodHour)
multi.Incr(ctx, key)
multi.Expire(ctx, key, 2*time.Hour)

key = redisCountPrefix + PeriodBucket(name, val, PeriodDay)
key = redisCountPrefix + periodBucket(name, val, PeriodDay)
multi.Incr(ctx, key)
multi.Expire(ctx, key, 48*time.Hour)

key = redisCountPrefix + PeriodBucket(name, val, PeriodTotal)
key = redisCountPrefix + periodBucket(name, val, PeriodTotal)
multi.Incr(ctx, key)
// no expiration for total

Expand All @@ -66,7 +66,7 @@ func (s *RedisCountStore) Increment(ctx context.Context, name, val string) error
}

func (s *RedisCountStore) GetCountDistinct(ctx context.Context, name, val, period string) (int, error) {
key := redisDistinctPrefix + PeriodBucket(name, val, period)
key := redisDistinctPrefix + periodBucket(name, val, period)
c, err := s.Client.PFCount(ctx, key).Result()
if err == redis.Nil {
return 0, nil
Expand All @@ -83,15 +83,15 @@ func (s *RedisCountStore) IncrementDistinct(ctx context.Context, name, bucket, v
// increment multiple counters in a single redis round-trip
multi := s.Client.Pipeline()

key = redisDistinctPrefix + PeriodBucket(name, bucket, PeriodHour)
key = redisDistinctPrefix + periodBucket(name, bucket, PeriodHour)
multi.PFAdd(ctx, key, val)
multi.Expire(ctx, key, 2*time.Hour)

key = redisDistinctPrefix + PeriodBucket(name, bucket, PeriodDay)
key = redisDistinctPrefix + periodBucket(name, bucket, PeriodDay)
multi.PFAdd(ctx, key, val)
multi.Expire(ctx, key, 48*time.Hour)

key = redisDistinctPrefix + PeriodBucket(name, bucket, PeriodTotal)
key = redisDistinctPrefix + periodBucket(name, bucket, PeriodTotal)
multi.PFAdd(ctx, key, val)
// no expiration for total

Expand Down
100 changes: 100 additions & 0 deletions automod/countstore/countstore_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package countstore

import (
"context"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

// TestMemCountStoreBasics exercises a fresh MemCountStore sequentially:
// plain counts start at zero and rise by one per Increment, and distinct
// counts rise only when a previously unseen value is recorded.
func TestMemCountStoreBasics(t *testing.T) {
	check := assert.New(t)
	ctx := context.Background()
	store := NewMemCountStore()

	// expectTotal asserts the all-time plain count for test1/val1.
	expectTotal := func(want int) {
		got, err := store.GetCount(ctx, "test1", "val1", PeriodTotal)
		check.NoError(err)
		check.Equal(want, got)
	}
	// expectDistinct asserts the all-time distinct count for test2/val2.
	expectDistinct := func(want int) {
		got, err := store.GetCountDistinct(ctx, "test2", "val2", PeriodTotal)
		check.NoError(err)
		check.Equal(want, got)
	}

	expectTotal(0)
	for i := 0; i < 2; i++ {
		check.NoError(store.Increment(ctx, "test1", "val1"))
	}
	expectTotal(2)

	// Recording the same value repeatedly counts once.
	expectDistinct(0)
	for i := 0; i < 3; i++ {
		check.NoError(store.IncrementDistinct(ctx, "test2", "val2", "one"))
	}
	expectDistinct(1)

	// Two further new values bring the distinct count to three.
	for _, v := range []string{"two", "three"} {
		check.NoError(store.IncrementDistinct(ctx, "test2", "val2", v))
	}
	expectDistinct(3)
}

// TestMemCountStoreConcurrent hammers a MemCountStore from several goroutines
// at once so that the race detector has something to observe (run with
// `-race`), then checks the final totals once every goroutine has finished.
func TestMemCountStoreConcurrent(t *testing.T) {
	assert := assert.New(t)
	ctx := context.Background()

	cs := NewMemCountStore()

	c, err := cs.GetCount(ctx, "test1", "val1", PeriodTotal)
	assert.NoError(err)
	assert.Equal(0, c)

	// Increment two different values from four different goroutines.
	// Read from two more (don't assert values; just that there's no error
	// and no race -- run this with `-race`!).
	// A short sleep ensures the scheduler is yielded to, so that order is decently random,
	// and reads are interleaved with writes.
	//
	// Every goroutine, readers included, is registered with the WaitGroup:
	// otherwise a reader could still be running (and asserting on t) after
	// the test function has moved on or returned.
	var wg sync.WaitGroup
	fnInc := func(name, val string, times int) {
		defer wg.Done()
		for i := 0; i < times; i++ {
			assert.NoError(cs.Increment(ctx, name, val))
			assert.NoError(cs.IncrementDistinct(ctx, name, name, val))
			time.Sleep(time.Nanosecond)
		}
	}
	fnRead := func(name, val string, times int) {
		defer wg.Done()
		for i := 0; i < times; i++ {
			_, err := cs.GetCount(ctx, name, val, PeriodTotal)
			assert.NoError(err)
			time.Sleep(time.Nanosecond)
		}
	}
	wg.Add(6) // four writers + two readers
	go fnInc("test1", "val1", 10)
	go fnInc("test1", "val1", 10)
	go fnRead("test1", "val1", 10)
	go fnInc("test2", "val2", 6)
	go fnInc("test2", "val2", 6)
	go fnRead("test2", "val2", 6)
	wg.Wait()

	// One final read for each value after all goroutines are collected.
	// This one should match a fixed value of the sum of all writes.
	c, err = cs.GetCount(ctx, "test1", "val1", PeriodTotal)
	assert.NoError(err)
	assert.Equal(20, c)
	c, err = cs.GetCount(ctx, "test2", "val2", PeriodTotal)
	assert.NoError(err)
	assert.Equal(12, c)

	// And what of distinct counts? Those should be 1.
	c, err = cs.GetCountDistinct(ctx, "test1", "test1", PeriodTotal)
	assert.NoError(err)
	assert.Equal(1, c)
	c, err = cs.GetCountDistinct(ctx, "test2", "test2", PeriodTotal)
	assert.NoError(err)
	assert.Equal(1, c)
}
Loading

0 comments on commit c54aff4

Please sign in to comment.