From 8c9b61e956e0eac338c4394bc25a50772a574075 Mon Sep 17 00:00:00 2001 From: MohammadReza Palide Date: Fri, 3 Nov 2023 05:40:20 +0000 Subject: [PATCH] fix race condition isssue | fix ipLimit value on tests --- internal/dmsg-discovery/api/api.go | 2 +- internal/dmsg-discovery/store/testing.go | 16 ++++++++-------- pkg/dmsg/server.go | 10 ++++++++-- pkg/dmsg/stream_test.go | 1 + pkg/dmsghttp/examples_test.go | 1 + pkg/dmsghttp/http_transport_test.go | 1 + pkg/dmsgtest/env.go | 1 + 7 files changed, 21 insertions(+), 11 deletions(-) diff --git a/internal/dmsg-discovery/api/api.go b/internal/dmsg-discovery/api/api.go index 837394779..8ac21b3fc 100644 --- a/internal/dmsg-discovery/api/api.go +++ b/internal/dmsg-discovery/api/api.go @@ -110,7 +110,7 @@ func (a *API) RunBackgroundTasks(ctx context.Context, log logrus.FieldLogger) { } // AllServers is used to get all the available servers registered to the dmsg-discovery. -func (a *API) AllServers(ctx context.Context, log logrus.FieldLogger) (entries []*disc.Entry, err error) { +func (a *API) AllServers(ctx context.Context, _ logrus.FieldLogger) (entries []*disc.Entry, err error) { entries, err = a.db.AllServers(ctx) if err != nil { return entries, err diff --git a/internal/dmsg-discovery/store/testing.go b/internal/dmsg-discovery/store/testing.go index ba8c6ba97..69c5e4c63 100644 --- a/internal/dmsg-discovery/store/testing.go +++ b/internal/dmsg-discovery/store/testing.go @@ -57,7 +57,7 @@ func NewMock() Storer { } // Entry implements Storer Entry method for MockStore -func (ms *MockStore) Entry(ctx context.Context, staticPubKey cipher.PubKey) (*disc.Entry, error) { +func (ms *MockStore) Entry(_ context.Context, staticPubKey cipher.PubKey) (*disc.Entry, error) { payload, ok := ms.entry(staticPubKey.Hex()) if !ok { return nil, disc.ErrKeyNotFound @@ -80,7 +80,7 @@ func (ms *MockStore) Entry(ctx context.Context, staticPubKey cipher.PubKey) (*di } // SetEntry implements Storer SetEntry method for MockStore -func (ms *MockStore) SetEntry(ctx context.Context, entry *disc.Entry, timeout time.Duration) error { +func (ms *MockStore) SetEntry(_ context.Context, entry *disc.Entry, _ time.Duration) error { payload, err := json.Marshal(entry) if err != nil { return disc.ErrUnexpected @@ -96,13 +96,13 @@ func (ms *MockStore) SetEntry(ctx context.Context, entry *disc.Entry, timeout ti } // DelEntry implements Storer DelEntry method for MockStore -func (ms *MockStore) DelEntry(ctx context.Context, staticPubKey cipher.PubKey) error { +func (ms *MockStore) DelEntry(_ context.Context, staticPubKey cipher.PubKey) error { ms.delEntry(staticPubKey.Hex()) return nil } // RemoveOldServerEntries implements Storer RemoveOldServerEntries method for MockStore -func (ms *MockStore) RemoveOldServerEntries(ctx context.Context) error { +func (ms *MockStore) RemoveOldServerEntries(_ context.Context) error { return nil } @@ -113,7 +113,7 @@ func (ms *MockStore) Clear() { } // AvailableServers implements Storer AvailableServers method for MockStore -func (ms *MockStore) AvailableServers(ctx context.Context, maxCount int) ([]*disc.Entry, error) { +func (ms *MockStore) AvailableServers(_ context.Context, _ int) ([]*disc.Entry, error) { entries := make([]*disc.Entry, 0) ms.serversLock.RLock() @@ -135,7 +135,7 @@ func (ms *MockStore) AvailableServers(ctx context.Context, maxCount int) ([]*dis } // AllServers implements Storer AllServers method for MockStore -func (ms *MockStore) AllServers(ctx context.Context) ([]*disc.Entry, error) { +func (ms *MockStore) AllServers(_ context.Context) ([]*disc.Entry, error) { entries := make([]*disc.Entry, 0) ms.serversLock.RLock() @@ -157,7 +157,7 @@ func (ms *MockStore) AllServers(ctx context.Context) ([]*disc.Entry, error) { } // CountEntries implements Storer CountEntries method for MockStore -func (ms *MockStore) CountEntries(ctx context.Context) (int64, int64, error) { +func (ms *MockStore) CountEntries(_ context.Context) (int64, int64, error) { var numberOfServers int64 var numberOfClients int64 ms.serversLock.RLock() @@ -198,7 +198,7 @@ func arrayFromMap(m map[string][]byte) [][]byte { } // AllEntries implements Storer CountEntries method for MockStore -func (ms *MockStore) AllEntries(ctx context.Context) ([]string, error) { +func (ms *MockStore) AllEntries(_ context.Context) ([]string, error) { entries := []string{} ms.mLock.RLock() diff --git a/pkg/dmsg/server.go b/pkg/dmsg/server.go index 8f659c00e..af40b75a2 100644 --- a/pkg/dmsg/server.go +++ b/pkg/dmsg/server.go @@ -51,8 +51,9 @@ type Server struct { maxSessions int - limitIP int - ipCounter map[string]int + limitIP int + ipCounter map[string]int + ipCounterLocker sync.RWMutex } // NewServer creates a new dmsg server entity. @@ -158,15 +159,20 @@ func (s *Server) Serve(lis net.Listener, addr string) error { Debug("Max sessions is reached, but still accepting so clients who delegated us can still listen.") } connIP := strings.Split(conn.RemoteAddr().String(), ":")[0] + s.ipCounterLocker.Lock() if s.ipCounter[connIP] >= s.limitIP { log.Warnf("Maximum client per IP for %s reached.", connIP) + s.ipCounterLocker.Unlock() continue } s.ipCounter[connIP]++ + s.ipCounterLocker.Unlock() s.wg.Add(1) go func(conn net.Conn) { defer func() { + s.ipCounterLocker.Lock() s.ipCounter[connIP]-- + s.ipCounterLocker.Unlock() err := recover() if err != nil { log.Warnf("panic in handleSession: %+v", err) diff --git a/pkg/dmsg/stream_test.go b/pkg/dmsg/stream_test.go index c1081dffe..2356f4b08 100644 --- a/pkg/dmsg/stream_test.go +++ b/pkg/dmsg/stream_test.go @@ -29,6 +29,7 @@ func TestStream(t *testing.T) { srvConf := &ServerConfig{ MaxSessions: maxSessions, UpdateInterval: 0, + LimitIP: 200, } srv := NewServer(pkSrv, skSrv, dc, srvConf, nil) srv.SetLogger(logging.MustGetLogger("server")) diff --git a/pkg/dmsghttp/examples_test.go b/pkg/dmsghttp/examples_test.go index df9493137..010bd4855 100644 --- a/pkg/dmsghttp/examples_test.go +++ b/pkg/dmsghttp/examples_test.go @@ -33,6 +33,7 @@ func ExampleMakeHTTPTransport() { srvConf := dmsg.ServerConfig{ MaxSessions: maxSessions, UpdateInterval: 0, + LimitIP: 200, } srv := dmsg.NewServer(srvPK, srvSK, dc, &srvConf, nil) defer func() { diff --git a/pkg/dmsghttp/http_transport_test.go b/pkg/dmsghttp/http_transport_test.go index a7edcd3a6..4e4364cef 100644 --- a/pkg/dmsghttp/http_transport_test.go +++ b/pkg/dmsghttp/http_transport_test.go @@ -107,6 +107,7 @@ func startDmsgEnv(t *testing.T, nSrvs, maxSessions int) disc.APIClient { conf := dmsg.ServerConfig{ MaxSessions: maxSessions, UpdateInterval: 0, + LimitIP: 200, } srv := dmsg.NewServer(pk, sk, dc, &conf, nil) srv.SetLogger(logging.MustGetLogger(fmt.Sprintf("server_%d", i))) diff --git a/pkg/dmsgtest/env.go b/pkg/dmsgtest/env.go index 9a294ea84..97520be36 100644 --- a/pkg/dmsgtest/env.go +++ b/pkg/dmsgtest/env.go @@ -90,6 +90,7 @@ func (env *Env) newServer(ctx context.Context, updateInterval time.Duration) (*d conf := dmsg.ServerConfig{ MaxSessions: maxSessions, UpdateInterval: updateInterval, + LimitIP: 200, } srv := dmsg.NewServer(pk, sk, env.d, &conf, nil) env.s[pk] = srv