From 286e58caa4b62377423d15bd8af147df07858d78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1ximo=20Cuadros?= Date: Fri, 21 May 2021 03:56:49 +0200 Subject: [PATCH] Config.StatusCode, ignore non 202 or 404 response --- cache.go | 20 ++++++++++++++++++-- cache_test.go | 19 ++++++++++++++++++- response.go | 1 - 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/cache.go b/cache.go index 30eb9f5..828f06b 100644 --- a/cache.go +++ b/cache.go @@ -16,6 +16,8 @@ type Config struct { TTL time.Duration `default:"1m"` // Methods methods to be cached. Methods []string `default:"[GET]"` + // StatusCode method to be cached. + StatusCode []int `default:"[200,404]"` // IgnoreQuery if true the Query values from the requests are ignored on // the key generation. IgnoreQuery bool @@ -32,7 +34,6 @@ func New(cfg *Config, cache *freecache.Cache) echo.MiddlewareFunc { } defaults.SetDefaults(cfg) - m := &CacheMiddleware{cfg: cfg, cache: cache} return m.Handler } @@ -93,14 +94,29 @@ func (m *CacheMiddleware) readCache(key []byte, c echo.Context) error { } func (m *CacheMiddleware) cacheResult(key []byte, r *ResponseRecorder) error { - b, err := r.Result().Encode() + e := r.Result() + b, err := e.Encode() if err != nil { return fmt.Errorf("unable to read recorded response: %s", err) } + if !m.isStatusCacheable(e) { + return nil + } + return m.cache.Set(key, b, int(m.cfg.TTL.Seconds())) } +func (m *CacheMiddleware) isStatusCacheable(e *CacheEntry) bool { + for _, status := range m.cfg.StatusCode { + if e.StatusCode == status { + return true + } + } + + return false +} + func (m *CacheMiddleware) isCacheable(r *http.Request) bool { if m.cfg.Cache != nil { return m.cfg.Cache(r) diff --git a/cache_test.go b/cache_test.go index 7c34110..e26c79d 100644 --- a/cache_test.go +++ b/cache_test.go @@ -136,13 +136,30 @@ func TestCache_Methods(t *testing.T) { assertRequest(t, resp, http.StatusOK, "test_4") } +func TestCache_StatusCode(t *testing.T) { + client := getCachedServerWithCode(t, &Config{StatusCode: []int{200, 404}}, http.StatusInternalServerError) + defer client.Close() + + resp, err := http.Get(client.URL) + assert.NoError(t, err) + assertRequest(t, resp, http.StatusOK, "test_1") + + resp, err = http.Get(client.URL) + assert.NoError(t, err) + assertRequest(t, resp, http.StatusOK, "test_2") +} + func getCachedServer(t *testing.T, cfg *Config) *httptest.Server { + return getCachedServerWithCode(t, cfg, http.StatusOK) +} + +func getCachedServerWithCode(t *testing.T, cfg *Config, status int) *httptest.Server { e := echo.New() var i int h := New(cfg, freecache.NewCache(42*1024*1024))(func(c echo.Context) error { i++ - return c.String(http.StatusOK, fmt.Sprintf("test_%d", i)) + return c.String(status, fmt.Sprintf("test_%d", i)) }) return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/response.go b/response.go index 4153ecd..bb62a10 100644 --- a/response.go +++ b/response.go @@ -50,7 +50,6 @@ func (w *ResponseRecorder) WriteHeader(statusCode int) { func (r *ResponseRecorder) Result() *CacheEntry { r.copyHeaders() - r.ResponseWriter = nil return &CacheEntry{ Header: r.headers,