Skip to content

Commit

Permalink
feat: add caching option to fetcher
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr committed Jan 18, 2024
1 parent 728f15d commit 93371db
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 28 deletions.
87 changes: 61 additions & 26 deletions fetcher/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@ package fetcher
import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
stderrors "errors"
"io"
"net/http"
"os"
"strings"
"time"

"github.com/dgraph-io/ristretto"
"github.com/hashicorp/go-retryablehttp"
"github.com/pkg/errors"

Expand All @@ -24,11 +27,15 @@ import (
// Fetcher loads content from http(s)://, file:// and base64:// sources.
// Instances are built with NewFetcher, which copies the configured options
// into the fields below.
type Fetcher struct {
	// hc is the retrying HTTP client used to issue GET requests for remote
	// sources.
	hc *retryablehttp.Client
	// limit is the value configured via WithMaxHTTPMaxBytes; presumably the
	// maximum number of response bytes accepted from a remote source
	// (enforcement is not visible in this view — TODO confirm).
	limit int64
	// cache, when non-nil, stores remote response bodies keyed by the SHA-256
	// of the source URL. A nil cache disables caching.
	cache *ristretto.Cache
	// ttl is the lifetime passed to the cache for every stored entry.
	ttl time.Duration
}

// opts collects the settings that Modifier functions adjust before
// NewFetcher copies them into a Fetcher.
type opts struct {
	// hc is the HTTP client the resulting Fetcher will use.
	hc *retryablehttp.Client
	// limit is the value set by WithMaxHTTPMaxBytes.
	limit int64
	// cache is the optional response cache created by WithCache; nil means
	// caching stays disabled.
	cache *ristretto.Cache
	// ttl is the entry lifetime supplied to WithCache.
	ttl time.Duration
}

var ErrUnknownScheme = stderrors.New("unknown scheme")
Expand All @@ -48,6 +55,21 @@ func WithMaxHTTPMaxBytes(limit int64) Modifier {
}
}

// WithCache returns a Modifier that enables an in-memory cache for remotely
// fetched sources.
//
// maxBytes is the total cost budget of the cache; entries are costed by the
// length of the cached body. avgSize is the expected average entry size and
// is only used to estimate how many entries will fit, which in turn sizes the
// cache's frequency counters. ttl is the lifetime applied to every entry.
//
// The option is a no-op (caching stays as previously configured) when the
// arguments are unusable: ttl < 0, maxBytes <= 0, avgSize <= 0,
// avgSize >= maxBytes, or when the cache cannot be constructed.
func WithCache(maxBytes, avgSize int64, ttl time.Duration) Modifier {
	return func(o *opts) {
		// avgSize must be strictly positive: zero would panic in the division
		// below and a negative value would yield a negative counter count.
		if ttl < 0 || maxBytes <= 0 || avgSize <= 0 || avgSize >= maxBytes {
			return
		}

		maxItems := maxBytes / avgSize
		cache, err := ristretto.NewCache(&ristretto.Config{
			// Ten counters per expected item, as used by the original code.
			NumCounters: maxItems * 10,
			MaxCost:     maxBytes,
			BufferItems: 64,
		})
		if err != nil {
			// A Modifier cannot report errors; leave the options untouched so
			// the fetcher simply runs without this cache.
			return
		}

		o.cache = cache
		o.ttl = ttl
	}
}

func newOpts() *opts {
return &opts{
hc: httpx.NewResilientClient(),
Expand All @@ -62,41 +84,58 @@ func NewFetcher(opts ...Modifier) *Fetcher {
for _, f := range opts {
f(o)
}
return &Fetcher{hc: o.hc, limit: o.limit}
return &Fetcher{hc: o.hc, limit: o.limit, cache: o.cache, ttl: o.ttl}
}

// Fetch fetches the file contents from the source.
func (f *Fetcher) Fetch(source string) (*bytes.Buffer, error) {
func (f *Fetcher) Fetch(source string) ([]byte, error) {
return f.FetchContext(context.Background(), source)
}

// FetchContext fetches the file contents from the source and allows to pass a
// context that is used for HTTP requests.
func (f *Fetcher) FetchContext(ctx context.Context, source string) (*bytes.Buffer, error) {
func (f *Fetcher) FetchContext(ctx context.Context, source string) ([]byte, error) {
switch s := stringsx.SwitchPrefix(source); {
case s.HasPrefix("http://"), s.HasPrefix("https://"):
return f.fetchRemote(ctx, source)
case s.HasPrefix("file://"):
return f.fetchFile(strings.Replace(source, "file://", "", 1))
return f.fetchFile(strings.TrimPrefix(source, "file://"))
case s.HasPrefix("base64://"):
src, err := base64.StdEncoding.DecodeString(strings.Replace(source, "base64://", "", 1))
src, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(source, "base64://"))
if err != nil {
return nil, errors.Wrapf(err, "rule: %s", source)
return nil, errors.Wrapf(err, "base64decode: %s", source)
}
return bytes.NewBuffer(src), nil
return src, nil
default:
return nil, errors.Wrap(ErrUnknownScheme, s.ToUnknownPrefixErr().Error())
}
}

func (f *Fetcher) fetchRemote(ctx context.Context, source string) (*bytes.Buffer, error) {
func (f *Fetcher) fetchRemote(ctx context.Context, source string) (b []byte, err error) {
if f.cache != nil {
cacheKey := sha256.Sum256([]byte(source))
if v, ok := f.cache.Get(cacheKey[:]); ok {
cached := v.([]byte)
b = make([]byte, len(cached))
copy(b, cached)
return b, nil
}
defer func() {
if err == nil && len(b) > 0 {
toCache := make([]byte, len(b))
copy(toCache, b)
f.cache.SetWithTTL(cacheKey[:], toCache, int64(len(toCache)), f.ttl)
}
}()
}

req, err := retryablehttp.NewRequestWithContext(ctx, http.MethodGet, source, nil)
if err != nil {
return nil, errors.Wrapf(err, "rule: %s", source)
return nil, errors.Wrapf(err, "new request: %s", source)
}
res, err := f.hc.Do(req)
if err != nil {
return nil, errors.Wrapf(err, "rule: %s", source)
return nil, errors.Wrap(err, source)
}
defer res.Body.Close()

Expand All @@ -113,27 +152,23 @@ func (f *Fetcher) fetchRemote(ctx context.Context, source string) (*bytes.Buffer
if err != nil {
return nil, err
}
return &buf, nil
return buf.Bytes(), nil
}
return f.toBuffer(res.Body)
return io.ReadAll(res.Body)
}

func (f *Fetcher) fetchFile(source string) (*bytes.Buffer, error) {
func (f *Fetcher) fetchFile(source string) ([]byte, error) {
fp, err := os.Open(source) // #nosec:G304
if err != nil {
return nil, errors.Wrapf(err, "unable to fetch from source: %s", source)
return nil, errors.Wrapf(err, "unable to open file: %s", source)
}
defer func() {
_ = fp.Close()
}()

return f.toBuffer(fp)
}

func (f *Fetcher) toBuffer(r io.Reader) (*bytes.Buffer, error) {
var b bytes.Buffer
if _, err := io.Copy(&b, r); err != nil {
return nil, err
defer fp.Close()
b, err := io.ReadAll(fp)
if err != nil {
return nil, errors.Wrapf(err, "unable to read file: %s", source)
}
if err := fp.Close(); err != nil {
return nil, errors.Wrapf(err, "unable to close file: %s", source)
}
return &b, nil
return b, nil
}
33 changes: 32 additions & 1 deletion fetcher/fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"net/http"
"os"
"sync/atomic"
"testing"
"time"

Expand Down Expand Up @@ -61,7 +62,7 @@ func TestFetcher(t *testing.T) {
t.Run(fmt.Sprintf("config=%d/case=%d", fc, k), func(t *testing.T) {
actual, err := fetcher.Fetch(tc.source)
require.NoError(t, err)
assert.JSONEq(t, tc.expect, actual.String())
assert.JSONEq(t, tc.expect, string(actual))
})
}
}
Expand Down Expand Up @@ -93,4 +94,34 @@ func TestFetcher(t *testing.T) {
_, err = NewFetcher(WithMaxHTTPMaxBytes(4000)).Fetch(srv.URL)
assert.NoError(t, err)
})

t.Run("case=with-cache", func(t *testing.T) {
	// hits counts how many requests actually reach the upstream server.
	var hits int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Count before writing so the counter is updated no matter when the
		// response bytes become visible to the client.
		atomic.AddInt32(&hits, 1)
		_, _ = w.Write([]byte("toodaloo"))
	}))
	t.Cleanup(srv.Close)

	f := NewFetcher(WithCache(10000, 100, 0))

	res, err := f.Fetch(srv.URL)
	require.NoError(t, err)
	require.Equal(t, "toodaloo", string(res))

	require.EqualValues(t, 1, atomic.LoadInt32(&hits))

	// Cache writes are applied asynchronously; wait until the entry is visible.
	f.cache.Wait()

	// prev holds the slice returned by the previous cached fetch.
	var prev []byte
	for i := 0; i < 100; i++ {
		res2, err := f.Fetch(srv.URL)
		require.NoError(t, err)
		require.Equal(t, "toodaloo", string(res2))
		// Compare the backing arrays, not the addresses of the local slice
		// variables (which always differ and would make the check vacuous).
		if len(prev) > 0 && &prev[0] == &res2[0] {
			t.Fatalf("cache should not return the same pointer")
		}
		// Scribble over the returned bytes: if the cache handed out its own
		// storage, the equality assertion in the next iteration would fail.
		res2[0] = 'X'
		prev = res2
	}

	// Every fetch after the first must have been served from the cache.
	require.EqualValues(t, 1, atomic.LoadInt32(&hits))
})
}
3 changes: 2 additions & 1 deletion jwksx/fetcher_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package jwksx

import (
"bytes"
"context"
"crypto/sha256"
"time"
Expand Down Expand Up @@ -156,7 +157,7 @@ func (f *FetcherNext) fetch(ctx context.Context, location string, opts *fetcherN
return nil, err
}

set, err := jwk.ParseReader(result)
set, err := jwk.ParseReader(bytes.NewBuffer(result))
if err != nil {
return nil, errors.WithStack(herodot.ErrBadRequest.WithReason("failed to parse JWK set").WithWrap(err))
}
Expand Down

0 comments on commit 93371db

Please sign in to comment.