Skip to content

Commit

Permalink
Fix some leaks in authorize chain elements (networkservicemesh#1623)
Browse files Browse the repository at this point in the history
* fix some leaks in authorize chain elements

Signed-off-by: NikitaSkrynnik <[email protected]>

* fix go linter issues

Signed-off-by: NikitaSkrynnik <[email protected]>

* add tests for memory leaks

Signed-off-by: NikitaSkrynnik <[email protected]>

* fix go linter issues

Signed-off-by: NikitaSkrynnik <[email protected]>

* rerun CI

Signed-off-by: NikitaSkrynnik <[email protected]>

---------

Signed-off-by: NikitaSkrynnik <[email protected]>
  • Loading branch information
NikitaSkrynnik authored May 16, 2024
1 parent c2a3414 commit 3b79590
Show file tree
Hide file tree
Showing 6 changed files with 297 additions and 10 deletions.
13 changes: 10 additions & 3 deletions pkg/networkservice/common/authorize/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2020-2022 Doc.ai and/or its affiliates.
//
// Copyright (c) 2020-2023 Cisco Systems, Inc.
// Copyright (c) 2020-2024 Cisco Systems, Inc.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -83,7 +83,8 @@ func (a *authorizeServer) Request(ctx context.Context, request *networkservice.N
}
}

if spiffeID, err := spire.PeerSpiffeIDFromContext(ctx); err == nil {
spiffeID, loadErr := spire.PeerSpiffeIDFromContext(ctx)
if loadErr == nil {
connID := conn.GetPath().GetPathSegments()[index-1].GetId()
ids, ok := a.spiffeIDConnectionMap.Load(spiffeID)
if !ok {
Expand All @@ -92,7 +93,13 @@ func (a *authorizeServer) Request(ctx context.Context, request *networkservice.N
ids.Store(connID, struct{}{})
a.spiffeIDConnectionMap.Store(spiffeID, ids)
}
return next.Server(ctx).Request(ctx, request)

conn, err := next.Server(ctx).Request(ctx, request)
if loadErr == nil && err != nil {
a.spiffeIDConnectionMap.Delete(spiffeID)
}

return conn, err
}

func (a *authorizeServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) {
Expand Down
102 changes: 101 additions & 1 deletion pkg/networkservice/common/authorize/server_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2020-2021 Doc.ai and/or its affiliates.
//
// Copyright (c) 2022-2023 Cisco and/or its affiliates.
// Copyright (c) 2022-2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -33,8 +33,13 @@ import (
"testing"
"time"

mathrand "math/rand"

"github.com/edwarnicke/genericsync"
"github.com/golang/protobuf/ptypes/empty"
"github.com/networkservicemesh/api/pkg/api/networkservice"
"github.com/pkg/errors"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"google.golang.org/grpc/codes"
Expand All @@ -43,6 +48,8 @@ import (
"google.golang.org/grpc/status"

"github.com/networkservicemesh/sdk/pkg/networkservice/common/authorize"
"github.com/networkservicemesh/sdk/pkg/networkservice/core/next"
"github.com/networkservicemesh/sdk/pkg/tools/nanoid"
)

func generateCert(u *url.URL) []byte {
Expand Down Expand Up @@ -208,3 +215,96 @@ func TestAuthorize_EmptySpiffeIDConnectionMapOnClose(t *testing.T) {
_, err = server.Close(ctx, conn)
require.NoError(t, err)
}

type randomErrorServer struct {
errorChance float32
}

// NewServer returns a server chain element returning error on Close/Request on given times
func NewServer(errorChance float32) networkservice.NetworkServiceServer {
return &randomErrorServer{errorChance: errorChance}
}

func (s randomErrorServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) {
// nolint
val := mathrand.Float32()
if val > s.errorChance {
return nil, errors.New("random error")
}
return next.Server(ctx).Request(ctx, request)
}

func (s randomErrorServer) Close(ctx context.Context, conn *networkservice.Connection) (*empty.Empty, error) {
return next.Server(ctx).Close(ctx, conn)
}

type closeData struct {
conn *networkservice.Connection
cert []byte
}

func TestAuthorize_SpiffeIDConnectionMapHaveNoLeaks(t *testing.T) {
dir := filepath.Clean(path.Join(os.TempDir(), t.Name()))
defer func() {
_ = os.RemoveAll(dir)
}()

err := os.MkdirAll(dir, os.ModePerm)
require.Nil(t, err)

policyPath := filepath.Clean(path.Join(dir, "policy.rego"))
err = os.WriteFile(policyPath, []byte(testPolicy()), os.ModePerm)
require.Nil(t, err)

var authorizeMap genericsync.Map[spiffeid.ID, *genericsync.Map[string, struct{}]]
chain := next.NewNetworkServiceServer(
authorize.NewServer(authorize.WithPolicies(policyPath), authorize.WithSpiffeIDConnectionMap(&authorizeMap)),
NewServer(0.2),
)

request := &networkservice.NetworkServiceRequest{
Connection: &networkservice.Connection{
Id: "id",
Path: &networkservice.Path{
Index: 2,
PathSegments: []*networkservice.PathSegment{
{Id: "client", Name: "client", Token: "allowed"},
{Id: "nsmgr", Name: "nsmgr", Token: "allowed"},
{Id: "forwarder", Name: "forwarder", Token: "allowed"},
},
},
},
}

// Make 1000 requests with random spiffe IDs
count := 1000
data := make([]closeData, 0)
for i := 0; i < count; i++ {
spiffeidPath, err := nanoid.GenerateString(10, nanoid.WithAlphabet("abcdefghijklmnopqrstuvwxyz"))
require.NoError(t, err)

certBytes := generateCert(&url.URL{Scheme: "spiffe", Host: "test.com", Path: spiffeidPath})
ctx, err := withPeer(context.Background(), certBytes)
require.NoError(t, err)

conn, err := chain.Request(ctx, request)
if err == nil {
data = append(data, closeData{conn: conn, cert: certBytes})
}
}

// Close the connections established in the previous loop
for _, closeData := range data {
ctx, err := withPeer(context.Background(), closeData.cert)
require.NoError(t, err)
_, err = chain.Close(ctx, closeData.conn)
require.NoError(t, err)
}

mapLen := 0
authorizeMap.Range(func(key spiffeid.ID, value *genericsync.Map[string, struct{}]) bool {
mapLen++
return true
})
require.Equal(t, mapLen, 0)
}
8 changes: 6 additions & 2 deletions pkg/registry/common/authorize/ns_server.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022-2023 Cisco and/or its affiliates.
// Copyright (c) 2022-2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -72,8 +72,12 @@ func (s *authorizeNSServer) Register(ctx context.Context, ns *registry.NetworkSe
return nil, err
}

ns, err := next.NetworkServiceRegistryServer(ctx).Register(ctx, ns)
if err != nil {
return nil, err
}
s.nsPathIdsMap.Store(ns.Name, ns.PathIds)
return next.NetworkServiceRegistryServer(ctx).Register(ctx, ns)
return ns, nil
}

func (s *authorizeNSServer) Find(query *registry.NetworkServiceQuery, server registry.NetworkServiceRegistry_FindServer) error {
Expand Down
88 changes: 87 additions & 1 deletion pkg/registry/common/authorize/ns_server_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022 Cisco and/or its affiliates.
// Copyright (c) 2022-2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand All @@ -18,13 +18,20 @@ package authorize_test

import (
"context"
"math/rand"
"net/url"
"testing"

"github.com/edwarnicke/genericsync"
"github.com/golang/protobuf/ptypes/empty"
"github.com/networkservicemesh/api/pkg/api/registry"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"

"github.com/networkservicemesh/sdk/pkg/registry/common/authorize"
"github.com/networkservicemesh/sdk/pkg/registry/common/grpcmetadata"
"github.com/networkservicemesh/sdk/pkg/registry/core/next"
"github.com/networkservicemesh/sdk/pkg/tools/nanoid"

"go.uber.org/goleak"
)
Expand Down Expand Up @@ -62,3 +69,82 @@ func TestNetworkServiceRegistryAuthorization(t *testing.T) {
_, err = server.Unregister(ctx1, ns)
require.NoError(t, err)
}

type randomErrorNSServer struct {
errorChance float32
}

func NewNetworkServiceRegistryServer(errorChance float32) registry.NetworkServiceRegistryServer {
return &randomErrorNSServer{
errorChance: errorChance,
}
}

func (s *randomErrorNSServer) Register(ctx context.Context, ns *registry.NetworkService) (*registry.NetworkService, error) {
// nolint
val := rand.Float32()
if val > s.errorChance {
return nil, errors.New("random error")
}
return next.NetworkServiceRegistryServer(ctx).Register(ctx, ns)
}

func (s *randomErrorNSServer) Find(query *registry.NetworkServiceQuery, server registry.NetworkServiceRegistry_FindServer) error {
return next.NetworkServiceRegistryServer(server.Context()).Find(query, server)
}

func (s *randomErrorNSServer) Unregister(ctx context.Context, ns *registry.NetworkService) (*empty.Empty, error) {
return next.NetworkServiceRegistryServer(ctx).Unregister(ctx, ns)
}

type closeNSData struct {
ns *registry.NetworkService
path *grpcmetadata.Path
}

func TestNetworkServiceRegistryAuthorize_ResourcePathIdMapHaveNoLeaks(t *testing.T) {
var authorizeMap genericsync.Map[string, []string]
server := next.NewNetworkServiceRegistryServer(
authorize.NewNetworkServiceRegistryServer(
authorize.WithPolicies("etc/nsm/opa/registry/client_allowed.rego"),
authorize.WithResourcePathIdsMap(&authorizeMap)),
NewNetworkServiceRegistryServer(0.5),
)

// Make 1000 requests with random spiffe IDs
count := 1000
data := make([]closeNSData, 0)
for i := 0; i < count; i++ {
nsName, err := nanoid.GenerateString(10, nanoid.WithAlphabet("abcdefghijklmnopqrstuvwxyz"))
require.NoError(t, err)
spiffeidPath, err := nanoid.GenerateString(10, nanoid.WithAlphabet("abcdefghijklmnopqrstuvwxyz"))
require.NoError(t, err)

u := &url.URL{Scheme: "spiffe", Host: "test.com", Path: spiffeidPath}
spiffeid := u.String()

ns := &registry.NetworkService{Name: nsName}
ns.PathIds = []string{spiffeid}

path := getPath(t, spiffeid)
ctx := grpcmetadata.PathWithContext(context.Background(), path)

ns, err = server.Register(ctx, ns)
if err == nil {
data = append(data, closeNSData{ns: ns, path: path})
}
}

// Close the connections established in the previous loop
for _, closeData := range data {
ctx := grpcmetadata.PathWithContext(context.Background(), closeData.path)
_, err := server.Unregister(ctx, closeData.ns)
require.NoError(t, err)
}
mapLen := 0
authorizeMap.Range(func(key string, value []string) bool {
mapLen++
return true
})
require.Equal(t, mapLen, 0)
}
8 changes: 6 additions & 2 deletions pkg/registry/common/authorize/nse_server.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022-2023 Cisco and/or its affiliates.
// Copyright (c) 2022-2024 Cisco and/or its affiliates.
//
// SPDX-License-Identifier: Apache-2.0
//
Expand Down Expand Up @@ -73,8 +73,12 @@ func (s *authorizeNSEServer) Register(ctx context.Context, nse *registry.Network
return nil, err
}

nse, err := next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse)
if err != nil {
return nil, err
}
s.nsePathIdsMap.Store(nse.Name, nse.PathIds)
return next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse)
return nse, nil
}

func (s *authorizeNSEServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error {
Expand Down
Loading

0 comments on commit 3b79590

Please sign in to comment.