Skip to content

Commit

Permalink
fix race condition isssue | fix ipLimit value on tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mrpalide committed Nov 3, 2023
1 parent d93a6ff commit 8c9b61e
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 11 deletions.
2 changes: 1 addition & 1 deletion internal/dmsg-discovery/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (a *API) RunBackgroundTasks(ctx context.Context, log logrus.FieldLogger) {
}

// AllServers is used to get all the available servers registered to the dmsg-discovery.
func (a *API) AllServers(ctx context.Context, log logrus.FieldLogger) (entries []*disc.Entry, err error) {
func (a *API) AllServers(ctx context.Context, _ logrus.FieldLogger) (entries []*disc.Entry, err error) {
entries, err = a.db.AllServers(ctx)
if err != nil {
return entries, err
Expand Down
16 changes: 8 additions & 8 deletions internal/dmsg-discovery/store/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func NewMock() Storer {
}

// Entry implements Storer Entry method for MockStore
func (ms *MockStore) Entry(ctx context.Context, staticPubKey cipher.PubKey) (*disc.Entry, error) {
func (ms *MockStore) Entry(_ context.Context, staticPubKey cipher.PubKey) (*disc.Entry, error) {
payload, ok := ms.entry(staticPubKey.Hex())
if !ok {
return nil, disc.ErrKeyNotFound
Expand All @@ -80,7 +80,7 @@ func (ms *MockStore) Entry(ctx context.Context, staticPubKey cipher.PubKey) (*di
}

// SetEntry implements Storer SetEntry method for MockStore
func (ms *MockStore) SetEntry(ctx context.Context, entry *disc.Entry, timeout time.Duration) error {
func (ms *MockStore) SetEntry(_ context.Context, entry *disc.Entry, _ time.Duration) error {
payload, err := json.Marshal(entry)
if err != nil {
return disc.ErrUnexpected
Expand All @@ -96,13 +96,13 @@ func (ms *MockStore) SetEntry(ctx context.Context, entry *disc.Entry, timeout ti
}

// DelEntry implements Storer DelEntry method for MockStore
func (ms *MockStore) DelEntry(ctx context.Context, staticPubKey cipher.PubKey) error {
func (ms *MockStore) DelEntry(_ context.Context, staticPubKey cipher.PubKey) error {
ms.delEntry(staticPubKey.Hex())
return nil
}

// RemoveOldServerEntries implements Storer RemoveOldServerEntries method for MockStore
func (ms *MockStore) RemoveOldServerEntries(ctx context.Context) error {
func (ms *MockStore) RemoveOldServerEntries(_ context.Context) error {
return nil
}

Expand All @@ -113,7 +113,7 @@ func (ms *MockStore) Clear() {
}

// AvailableServers implements Storer AvailableServers method for MockStore
func (ms *MockStore) AvailableServers(ctx context.Context, maxCount int) ([]*disc.Entry, error) {
func (ms *MockStore) AvailableServers(_ context.Context, _ int) ([]*disc.Entry, error) {
entries := make([]*disc.Entry, 0)

ms.serversLock.RLock()
Expand All @@ -135,7 +135,7 @@ func (ms *MockStore) AvailableServers(ctx context.Context, maxCount int) ([]*dis
}

// AllServers implements Storer AllServers method for MockStore
func (ms *MockStore) AllServers(ctx context.Context) ([]*disc.Entry, error) {
func (ms *MockStore) AllServers(_ context.Context) ([]*disc.Entry, error) {
entries := make([]*disc.Entry, 0)

ms.serversLock.RLock()
Expand All @@ -157,7 +157,7 @@ func (ms *MockStore) AllServers(ctx context.Context) ([]*disc.Entry, error) {
}

// CountEntries implements Storer CountEntries method for MockStore
func (ms *MockStore) CountEntries(ctx context.Context) (int64, int64, error) {
func (ms *MockStore) CountEntries(_ context.Context) (int64, int64, error) {
var numberOfServers int64
var numberOfClients int64
ms.serversLock.RLock()
Expand Down Expand Up @@ -198,7 +198,7 @@ func arrayFromMap(m map[string][]byte) [][]byte {
}

// AllEntries implements Storer CountEntries method for MockStore
func (ms *MockStore) AllEntries(ctx context.Context) ([]string, error) {
func (ms *MockStore) AllEntries(_ context.Context) ([]string, error) {
entries := []string{}

ms.mLock.RLock()
Expand Down
10 changes: 8 additions & 2 deletions pkg/dmsg/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ type Server struct {

maxSessions int

limitIP int
ipCounter map[string]int
limitIP int
ipCounter map[string]int
ipCounterLocker sync.RWMutex
}

// NewServer creates a new dmsg server entity.
Expand Down Expand Up @@ -158,15 +159,20 @@ func (s *Server) Serve(lis net.Listener, addr string) error {
Debug("Max sessions is reached, but still accepting so clients who delegated us can still listen.")
}
connIP := strings.Split(conn.RemoteAddr().String(), ":")[0]
s.ipCounterLocker.Lock()
if s.ipCounter[connIP] >= s.limitIP {
log.Warnf("Maximum client per IP for %s reached.", connIP)
s.ipCounterLocker.Unlock()
continue
}
s.ipCounter[connIP]++
s.ipCounterLocker.Unlock()
s.wg.Add(1)
go func(conn net.Conn) {
defer func() {
s.ipCounterLocker.Lock()
s.ipCounter[connIP]--
s.ipCounterLocker.Unlock()
err := recover()
if err != nil {
log.Warnf("panic in handleSession: %+v", err)
Expand Down
1 change: 1 addition & 0 deletions pkg/dmsg/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ func TestStream(t *testing.T) {
srvConf := &ServerConfig{
MaxSessions: maxSessions,
UpdateInterval: 0,
LimitIP: 200,
}
srv := NewServer(pkSrv, skSrv, dc, srvConf, nil)
srv.SetLogger(logging.MustGetLogger("server"))
Expand Down
1 change: 1 addition & 0 deletions pkg/dmsghttp/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func ExampleMakeHTTPTransport() {
srvConf := dmsg.ServerConfig{
MaxSessions: maxSessions,
UpdateInterval: 0,
LimitIP: 200,
}
srv := dmsg.NewServer(srvPK, srvSK, dc, &srvConf, nil)
defer func() {
Expand Down
1 change: 1 addition & 0 deletions pkg/dmsghttp/http_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ func startDmsgEnv(t *testing.T, nSrvs, maxSessions int) disc.APIClient {
conf := dmsg.ServerConfig{
MaxSessions: maxSessions,
UpdateInterval: 0,
LimitIP: 200,
}
srv := dmsg.NewServer(pk, sk, dc, &conf, nil)
srv.SetLogger(logging.MustGetLogger(fmt.Sprintf("server_%d", i)))
Expand Down
1 change: 1 addition & 0 deletions pkg/dmsgtest/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ func (env *Env) newServer(ctx context.Context, updateInterval time.Duration) (*d
conf := dmsg.ServerConfig{
MaxSessions: maxSessions,
UpdateInterval: updateInterval,
LimitIP: 200,
}
srv := dmsg.NewServer(pk, sk, env.d, &conf, nil)
env.s[pk] = srv
Expand Down

0 comments on commit 8c9b61e

Please sign in to comment.