diff --git a/driver.go b/driver.go index 542844529..488c78525 100644 --- a/driver.go +++ b/driver.go @@ -20,7 +20,6 @@ import ( internalDiscovery "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery" discoveryConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/discovery/config" "github.com/ydb-platform/ydb-go-sdk/v3/internal/dsn" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" internalQuery "github.com/ydb-platform/ydb-go-sdk/v3/internal/query" queryConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/config" internalRatelimiter "github.com/ydb-platform/ydb-go-sdk/v3/internal/ratelimiter" @@ -488,7 +487,7 @@ func (d *Driver) connect(ctx context.Context) (err error) { d.discovery = xsync.OnceValue(func() (*internalDiscovery.Client, error) { return internalDiscovery.New(xcontext.ValueOnly(ctx), - d.pool.Get(endpoint.New(d.config.Endpoint())), + d.balancer, discoveryConfig.New( append( // prepend common params from root config diff --git a/internal/balancer/balancer.go b/internal/balancer/balancer.go index cbb88135c..567024ee1 100644 --- a/internal/balancer/balancer.go +++ b/internal/balancer/balancer.go @@ -9,6 +9,7 @@ import ( "google.golang.org/grpc" "github.com/ydb-platform/ydb-go-sdk/v3/config" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/cluster" balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" "github.com/ydb-platform/ydb-go-sdk/v3/internal/closer" "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" @@ -26,8 +27,6 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/trace" ) -var ErrNoEndpoints = xerrors.Wrap(fmt.Errorf("no endpoints")) - type discoveryClient interface { closer.Closer @@ -40,9 +39,12 @@ type Balancer struct { pool *conn.Pool discoveryClient discoveryClient discoveryRepeater repeater.Repeater - localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error) - connectionsState atomic.Pointer[connectionsState] + cluster atomic.Pointer[cluster.Cluster] + conns xsync.Map[endpoint.Endpoint, conn.Conn] + banned xsync.Set[endpoint.Endpoint] + + localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error) mu xsync.RWMutex onApplyDiscoveredEndpoints []func(ctx context.Context, endpoints []endpoint.Info) @@ -124,19 +126,49 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) { } func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoint.Endpoint, localDC string) { - var ( - onDone = trace.DriverOnBalancerUpdate( - b.driverConfig.Trace(), &ctx, - stack.FunctionID( - "github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"), - b.config.DetectLocalDC, - ) - previous = b.connections().All() + onDone := trace.DriverOnBalancerUpdate( + b.driverConfig.Trace(), &ctx, + stack.FunctionID( + "github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"), + b.config.DetectLocalDC, ) + + state := cluster.New(newest, + cluster.From(b.cluster.Load()), + cluster.WithFallback(b.config.AllowFallback), + cluster.WithFilter(func(e endpoint.Info) bool { + if b.config.Filter == nil { + return true + } + + return b.config.Filter.Allow(balancerConfig.Info{SelfLocation: localDC}, e) + }), + ) + + previous := b.cluster.Swap(state) + + _, added, dropped := xslices.Diff(previous.All(), newest, func(lhs, rhs endpoint.Endpoint) int { + return strings.Compare(lhs.Address(), rhs.Address()) + }) + + for _, e := range dropped { + c, ok := b.conns.Extract(e) + if !ok { + 
panic("wrong balancer state") + } + b.pool.Put(ctx, c) + } + + for _, e := range added { + cc, err := b.pool.Get(ctx, e) + if err != nil { + b.banned.Add(e) + } else { + b.conns.Set(e, cc) + } + } + defer func() { - _, added, dropped := xslices.Diff(previous, newest, func(lhs, rhs endpoint.Endpoint) int { - return strings.Compare(lhs.Address(), rhs.Address()) - }) onDone( xslices.Transform(newest, func(t endpoint.Endpoint) trace.EndpointInfo { return t }), xslices.Transform(added, func(t endpoint.Endpoint) trace.EndpointInfo { return t }), @@ -145,25 +177,13 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoi ) }() - connections := endpointsToConnections(b.pool, newest) - for _, c := range connections { - b.pool.Allow(ctx, c) - c.Endpoint().Touch() - } - - info := balancerConfig.Info{SelfLocation: localDC} - state := newConnectionsState(connections, b.config.Filter, info, b.config.AllowFallback) - - endpointsInfo := make([]endpoint.Info, len(newest)) - for i, e := range newest { - endpointsInfo[i] = e - } - - b.connectionsState.Store(state) + endpoints := xslices.Transform(newest, func(e endpoint.Endpoint) endpoint.Info { + return e + }) b.mu.WithLock(func() { for _, onApplyDiscoveredEndpoints := range b.onApplyDiscoveredEndpoints { - onApplyDiscoveredEndpoints(ctx, endpointsInfo) + onApplyDiscoveredEndpoints(ctx, endpoints) } }) } @@ -212,18 +232,20 @@ func New( onDone(finalErr) }() + cc, err := pool.Get(ctx, endpoint.New(driverConfig.Endpoint())) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + b = &Balancer{ - driverConfig: driverConfig, - pool: pool, - discoveryClient: internalDiscovery.New(ctx, pool.Get( - endpoint.New(driverConfig.Endpoint()), - ), discoveryConfig), + config: balancerConfig.Config{}, + driverConfig: driverConfig, + pool: pool, + discoveryClient: internalDiscovery.New(ctx, cc, discoveryConfig), localDCDetector: detectLocalDC, } - if config := driverConfig.Balancer(); config == nil { - b.config = balancerConfig.Config{} - } else { + if config := driverConfig.Balancer(); config != nil { b.config = *config } @@ -289,10 +311,10 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc defer func() { if err == nil { if cc.GetState() == conn.Banned { - b.pool.Allow(ctx, cc) + b.banned.Remove(cc.Endpoint()) } } else if conn.IsBadConn(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) 
{ - b.pool.Ban(ctx, cc, err) + b.banned.Add(cc.Endpoint()) } }() @@ -319,53 +341,45 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc return nil } -func (b *Balancer) connections() *connectionsState { - return b.connectionsState.Load() -} - -func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) { +func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, finalErr error) { onDone := trace.DriverOnBalancerChooseEndpoint( b.driverConfig.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).getConn"), ) + defer func() { - if err == nil { + if finalErr == nil { onDone(c.Endpoint(), nil) } else { - onDone(nil, err) + if b.cluster.Load().Availability() < 0.5 && b.discoveryRepeater != nil { + b.discoveryRepeater.Force() + } + + onDone(nil, finalErr) } }() - if err = ctx.Err(); err != nil { - return nil, xerrors.WithStackTrace(err) - } + for attempts := 1; ; attempts++ { + if err := ctx.Err(); err != nil { + return nil, xerrors.WithStackTrace(err) + } - var ( - state = b.connections() - failedCount int - ) + state := b.cluster.Load() - defer func() { - if failedCount*2 > state.PreferredCount() && b.discoveryRepeater != nil { - b.discoveryRepeater.Force() + e, err := state.Next(ctx) + if err != nil { + return nil, xerrors.WithStackTrace( + fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", cluster.ErrNoEndpoints, attempts), + ) } - }() - - c, failedCount = state.GetConnection(ctx) - if c == nil { - return nil, xerrors.WithStackTrace( - fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", ErrNoEndpoints, failedCount), - ) - } - return c, nil -} + cc, err := b.pool.Get(ctx, e) + if err == nil { + return cc, nil + } -func endpointsToConnections(p *conn.Pool, endpoints []endpoint.Endpoint) []conn.Conn { - conns := make([]conn.Conn, 0, len(endpoints)) - for _, e := range endpoints { - conns = append(conns, p.Get(e)) + if b.cluster.CompareAndSwap(state, cluster.Without(b.cluster.Load(), e)) { + b.banned.Add(e) + } } - - return conns } diff --git a/internal/balancer/cluster/cluster.go b/internal/balancer/cluster/cluster.go index d8ef64214..4029863dc 100644 --- a/internal/balancer/cluster/cluster.go +++ b/internal/balancer/cluster/cluster.go @@ -37,62 +37,92 @@ func WithFallback(allowFallback bool) option { } } +func From(parent *Cluster) option { + if parent == nil { + return nil + } + + return func(s *Cluster) { + s.rand = parent.rand + s.filter = parent.filter + s.allowFallback = parent.allowFallback + } +} + func New(endpoints []endpoint.Endpoint, opts ...option) *Cluster { - s := &Cluster{ + c := &Cluster{ filter: func(e endpoint.Info) bool { return true }, } for _, opt := range opts { - opt(s) + if opt != nil { + opt(c) + } } - if s.rand == nil { - s.rand = xrand.New(xrand.WithLock()) + if c.rand == nil { + c.rand = xrand.New(xrand.WithLock()) } - s.prefer, s.fallback = xslices.Split(endpoints, func(e endpoint.Endpoint) bool { - return s.filter(e) + c.prefer, c.fallback = xslices.Split(endpoints, func(e endpoint.Endpoint) bool { + return c.filter(e) }) - if s.allowFallback { - s.all = endpoints - s.index = xslices.Map(endpoints, func(e endpoint.Endpoint) uint32 { return e.NodeID() }) + if c.allowFallback { + c.all = endpoints + c.index = xslices.Map(endpoints, func(e endpoint.Endpoint) uint32 { return e.NodeID() }) } else { - s.all = s.prefer - s.fallback = nil - s.index = xslices.Map(s.prefer, func(e endpoint.Endpoint) uint32 { return e.NodeID() }) + c.all = 
c.prefer
+		c.fallback = nil
+		c.index = xslices.Map(c.prefer, func(e endpoint.Endpoint) uint32 { return e.NodeID() })
 	}
 
-	return s
+	return c
 }
 
-func (s *Cluster) All() (all []endpoint.Endpoint) {
-	if s == nil {
+func (c *Cluster) All() (all []endpoint.Endpoint) {
+	if c == nil {
 		return nil
 	}
 
-	return s.all
+	return c.all
 }
 
-func Without(s *Cluster, endpoints ...endpoint.Endpoint) *Cluster {
-	prefer := make([]endpoint.Endpoint, 0, len(s.prefer))
-	fallback := s.fallback
-	for _, endpoint := range endpoints {
-		for i := range s.prefer {
-			if s.prefer[i].Address() != endpoint.Address() {
-				prefer = append(prefer, s.prefer[i])
-			} else {
-				fallback = append(fallback, s.prefer[i])
-			}
-		}
-	}
-
+func (c *Cluster) Availability() (percent float64) {
+	if c == nil || len(c.all) == 0 {
+		return 0
+	}
+
+	return float64(len(c.prefer)+len(c.fallback)) / float64(len(c.all))
+}
+
+func Without(s *Cluster, required endpoint.Endpoint, other ...endpoint.Endpoint) *Cluster {
+	// check every excluded endpoint once per survivor: re-scanning s.prefer
+	// for each removed endpoint would append the survivors multiple times
+	excluded := func(e endpoint.Endpoint) bool {
+		if e.Address() == required.Address() {
+			return true
+		}
+		for _, o := range other {
+			if o.Address() == e.Address() {
+				return true
+			}
+		}
+
+		return false
+	}
+
+	prefer := make([]endpoint.Endpoint, 0, len(s.prefer))
+	for i := range s.prefer {
+		if !excluded(s.prefer[i]) {
+			prefer = append(prefer, s.prefer[i])
+		}
+	}
+
+	fallback := make([]endpoint.Endpoint, 0, len(s.fallback))
+	for i := range s.fallback {
+		if !excluded(s.fallback[i]) {
+			fallback = append(fallback, s.fallback[i])
+		}
+	}
+
+	if len(prefer)+len(fallback) == len(s.prefer)+len(s.fallback) {
+		return s
+	}
+
+	all := append(append(make([]endpoint.Endpoint, 0, len(prefer)+len(fallback)), prefer...), fallback...)
+
 	return &Cluster{
 		filter:        s.filter,
 		allowFallback: s.allowFallback,
-		index:         s.index,
+		index:         xslices.Map(all, func(e endpoint.Endpoint) uint32 { return e.NodeID() }),
 		prefer:        prefer,
 		fallback:      fallback,
 		all:           s.all,
@@ -100,8 +130,8 @@ func Without(s *Cluster, endpoints ...endpoint.Endpoint) *Cluster {
 	}
 }
 
-func (s *Cluster) Next(ctx context.Context) (endpoint.Endpoint, error) {
-	if s == nil {
+func (c *Cluster) Next(ctx context.Context) (endpoint.Endpoint, error) {
+	if c == nil {
 		return nil, ErrNilPtr
 	}
 
@@ -110,18 +140,18 @@ func (s *Cluster) Next(ctx context.Context) (endpoint.Endpoint, error) {
 	}
 
 	if nodeID, wantEndpointByNodeID := endpoint.ContextNodeID(ctx); wantEndpointByNodeID {
-		e, has := s.index[nodeID]
+		e, has := c.index[nodeID]
 		if has {
 			return e, nil
 		}
 	}
 
-	if l := len(s.prefer); l > 0 {
-		return s.prefer[s.rand.Int(l)], nil
+	if l := len(c.prefer); l > 0 {
+		return c.prefer[c.rand.Int(l)], nil
 	}
 
-	if l := len(s.fallback); l > 0 {
-		return s.fallback[s.rand.Int(l)], nil
+	if l := len(c.fallback); l > 0 {
+		return c.fallback[c.rand.Int(l)], nil
 	}
 
 	return nil, xerrors.WithStackTrace(ErrNoEndpoints)
diff --git a/internal/balancer/cluster/cluster_test.go b/internal/balancer/cluster/cluster_test.go
index 828ae5349..1a36d51cd 100644
--- a/internal/balancer/cluster/cluster_test.go
+++ b/internal/balancer/cluster/cluster_test.go
@@ -2,7 +2,6 @@ package cluster
 
 import (
 	"context"
-	"math"
 	"strconv"
 	"sync"
 	"testing"
@@ -90,73 +89,82 @@ func TestCluster(t *testing.T) {
 			AddrField:   "5",
 			NodeIDField: 5,
 		},
-	})
+	}, WithFilter(func(e endpoint.Info) bool {
+		return e.NodeID()%2 == 0
+	}), WithFallback(true))
 
 	{ // initial state
 		require.Len(t, s.All(), 5)
+		require.InEpsilon(t, 1.0, s.Availability(), 0.001)
 		require.Len(t, s.index, 5)
-		require.Len(t, s.prefer, 5)
+		require.Len(t, s.prefer, 2)
+		require.Len(t, s.fallback, 3)
 	}
 
-	{ // without first endpoint
+	{ // without first endpoint (excluded from prefer)
 		e, err := s.Next(ctx)
 		require.NoError(t, err)
 		require.NotNil(t, e)
 		s = Without(s, e)
 		require.Len(t, s.All(), 5)
-		require.Len(t, s.index, 5)
-		require.Len(t, s.prefer, 4)
-		require.Len(t, s.fallback, 
1)
+		require.InEpsilon(t, 4.0/5.0, s.Availability(), 0.001)
+		require.Len(t, s.index, 4)
+		require.Len(t, s.prefer, 1)
+		require.Len(t, s.fallback, 3)
 	}
 
-	{ // without second endpoint
+	{ // without second endpoint (excluded from prefer)
 		e, err := s.Next(ctx)
 		require.NoError(t, err)
 		require.NotNil(t, e)
 		s = Without(s, e)
 		require.Len(t, s.All(), 5)
-		require.Len(t, s.index, 5)
-		require.Len(t, s.prefer, 3)
-		require.Len(t, s.fallback, 2)
+		require.InEpsilon(t, 3.0/5.0, s.Availability(), 0.001)
+		require.Len(t, s.index, 3)
+		require.Empty(t, s.prefer)
+		require.Len(t, s.fallback, 3)
 	}
 
-	{ // without third endpoint
+	{ // without third endpoint (excluded from fallback)
 		e, err := s.Next(ctx)
 		require.NoError(t, err)
 		require.NotNil(t, e)
 		s = Without(s, e)
 		require.Len(t, s.All(), 5)
-		require.Len(t, s.index, 5)
-		require.Len(t, s.prefer, 2)
-		require.Len(t, s.fallback, 3)
+		require.InEpsilon(t, 2.0/5.0, s.Availability(), 0.001)
+		require.Len(t, s.index, 2)
+		require.Empty(t, s.prefer)
+		require.Len(t, s.fallback, 2)
 	}
 
-	{ // without fourth endpoint
+	{ // without fourth endpoint (excluded from fallback)
 		e, err := s.Next(ctx)
 		require.NoError(t, err)
 		require.NotNil(t, e)
 		s = Without(s, e)
 		require.Len(t, s.All(), 5)
-		require.Len(t, s.index, 5)
-		require.Len(t, s.prefer, 1)
-		require.Len(t, s.fallback, 4)
+		require.InEpsilon(t, 1.0/5.0, s.Availability(), 0.001)
+		require.Len(t, s.index, 1)
+		require.Empty(t, s.prefer)
+		require.Len(t, s.fallback, 1)
 	}
 
-	{ // without fifth endpoint
+	{ // without fifth endpoint (excluded from fallback)
 		e, err := s.Next(ctx)
 		require.NoError(t, err)
 		require.NotNil(t, e)
 		s = Without(s, e)
 		require.Len(t, s.All(), 5)
-		require.Len(t, s.index, 5)
+		require.Zero(t, s.Availability())
+		require.Empty(t, s.index)
 		require.Empty(t, s.prefer)
-		require.Len(t, s.fallback, 5)
+		require.Empty(t, s.fallback)
 	}
 
-	{ // next from fallback is ok
+	{ // empty prefer and fallback lists
 		e, err := s.Next(ctx)
-		require.NoError(t, err)
-		require.NotNil(t, e)
+		require.ErrorIs(t, err, ErrNoEndpoints)
+		require.Nil(t, e)
 	}
 })

@@ -272,7 +280,7 @@ func TestCluster(t *testing.T) {
 		const (
 			buckets = 10
 			total   = 1000000
-			epsilon = int(float64(total) / float64(buckets) * 0.015)
+			epsilon = 0.015 // relative tolerance for require.InEpsilon: ±1.5% of the expected bucket size
 		)
 
 		endpoints := make([]endpoint.Endpoint, buckets)
@@ -294,11 +302,7 @@ func TestCluster(t *testing.T) {
 		}
 
 		for i := range distribution {
-			if distribution[i] < total/buckets-epsilon || distribution[i] > total/buckets+epsilon {
-				t.Errorf("unexpected distribuition[%d] = %0.1f%%", i,
-					math.Abs(float64(distribution[i]-total/buckets)/float64(total/buckets)*100),
-				)
-			}
+			require.InEpsilon(t, total/buckets, distribution[i], epsilon)
 		}
 	})
 }
diff --git a/internal/balancer/connections_state.go b/internal/balancer/connections_state.go
deleted file mode 100644
index fecbc6db1..000000000
--- a/internal/balancer/connections_state.go
+++ /dev/null
@@ -1,179 +0,0 @@
-package balancer
-
-import (
-	"context"
-
-	balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config"
-	"github.com/ydb-platform/ydb-go-sdk/v3/internal/conn"
-	"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
-	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xrand"
-)
-
-type connectionsState struct {
-	connByNodeID map[uint32]conn.Conn
-
-	prefer   []conn.Conn
-	fallback []conn.Conn
-	all      []conn.Conn
-
-	rand xrand.Rand
-}
-
-func newConnectionsState(
-	conns []conn.Conn,
-	filter balancerConfig.Filter,
-	info balancerConfig.Info,
-	allowFallback bool,
-) *connectionsState {
-	res := 
&connectionsState{ - connByNodeID: connsToNodeIDMap(conns), - rand: xrand.New(xrand.WithLock()), - } - - res.prefer, res.fallback = sortPreferConnections(conns, filter, info, allowFallback) - if allowFallback { - res.all = conns - } else { - res.all = res.prefer - } - - return res -} - -func (s *connectionsState) PreferredCount() int { - return len(s.prefer) -} - -func (s *connectionsState) All() (all []endpoint.Endpoint) { - if s == nil { - return nil - } - - all = make([]endpoint.Endpoint, len(s.all)) - for i, c := range s.all { - all[i] = c.Endpoint() - } - - return all -} - -func (s *connectionsState) GetConnection(ctx context.Context) (_ conn.Conn, failedCount int) { - if err := ctx.Err(); err != nil { - return nil, 0 - } - - if c := s.preferConnection(ctx); c != nil { - return c, 0 - } - - try := func(conns []conn.Conn) conn.Conn { - c, tryFailed := s.selectRandomConnection(conns, false) - failedCount += tryFailed - - return c - } - - if c := try(s.prefer); c != nil { - return c, failedCount - } - - if c := try(s.fallback); c != nil { - return c, failedCount - } - - c, _ := s.selectRandomConnection(s.all, true) - - return c, failedCount -} - -func (s *connectionsState) preferConnection(ctx context.Context) conn.Conn { - if nodeID, hasPreferEndpoint := endpoint.ContextNodeID(ctx); hasPreferEndpoint { - c := s.connByNodeID[nodeID] - if c != nil && isOkConnection(c, true) { - return c - } - } - - return nil -} - -func (s *connectionsState) selectRandomConnection(conns []conn.Conn, allowBanned bool) (c conn.Conn, failedConns int) { - connCount := len(conns) - if connCount == 0 { - // return for empty list need for prevent panic in fast path - return nil, 0 - } - - // fast path - if c := conns[s.rand.Int(connCount)]; isOkConnection(c, allowBanned) { - return c, 0 - } - - // shuffled indexes slices need for guarantee about every connection will check - indexes := make([]int, connCount) - for index := range indexes { - indexes[index] = index - } - s.rand.Shuffle(connCount, func(i, j int) { - indexes[i], indexes[j] = indexes[j], indexes[i] - }) - - for _, index := range indexes { - c := conns[index] - if isOkConnection(c, allowBanned) { - return c, 0 - } - failedConns++ - } - - return nil, failedConns -} - -func connsToNodeIDMap(conns []conn.Conn) (nodes map[uint32]conn.Conn) { - if len(conns) == 0 { - return nil - } - nodes = make(map[uint32]conn.Conn, len(conns)) - for _, c := range conns { - nodes[c.Endpoint().NodeID()] = c - } - - return nodes -} - -func sortPreferConnections( - conns []conn.Conn, - filter balancerConfig.Filter, - info balancerConfig.Info, - allowFallback bool, -) (prefer, fallback []conn.Conn) { - if filter == nil { - return conns, nil - } - - prefer = make([]conn.Conn, 0, len(conns)) - if allowFallback { - fallback = make([]conn.Conn, 0, len(conns)) - } - - for _, c := range conns { - if filter.Allow(info, c.Endpoint()) { - prefer = append(prefer, c) - } else if allowFallback { - fallback = append(fallback, c) - } - } - - return prefer, fallback -} - -func isOkConnection(c conn.Conn, bannedIsOk bool) bool { - switch c.GetState() { - case conn.Online, conn.Created, conn.Offline: - return true - case conn.Banned: - return bannedIsOk - default: - return false - } -} diff --git a/internal/balancer/connections_state_test.go b/internal/balancer/connections_state_test.go deleted file mode 100644 index c8648ee2a..000000000 --- a/internal/balancer/connections_state_test.go +++ /dev/null @@ -1,464 +0,0 @@ -package balancer - -import ( - "context" - "strings" - "testing" - - 
"github.com/stretchr/testify/require" - - balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/conn" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/mock" -) - -func TestConnsToNodeIDMap(t *testing.T) { - table := []struct { - name string - source []conn.Conn - res map[uint32]conn.Conn - }{ - { - name: "Empty", - source: nil, - res: nil, - }, - { - name: "Zero", - source: []conn.Conn{ - &mock.Conn{NodeIDField: 0}, - }, - res: map[uint32]conn.Conn{ - 0: &mock.Conn{NodeIDField: 0}, - }, - }, - { - name: "NonZero", - source: []conn.Conn{ - &mock.Conn{NodeIDField: 1}, - &mock.Conn{NodeIDField: 10}, - }, - res: map[uint32]conn.Conn{ - 1: &mock.Conn{NodeIDField: 1}, - 10: &mock.Conn{NodeIDField: 10}, - }, - }, - { - name: "Combined", - source: []conn.Conn{ - &mock.Conn{NodeIDField: 1}, - &mock.Conn{NodeIDField: 0}, - &mock.Conn{NodeIDField: 10}, - }, - res: map[uint32]conn.Conn{ - 0: &mock.Conn{NodeIDField: 0}, - 1: &mock.Conn{NodeIDField: 1}, - 10: &mock.Conn{NodeIDField: 10}, - }, - }, - } - - for _, test := range table { - t.Run(test.name, func(t *testing.T) { - require.Equal(t, test.res, connsToNodeIDMap(test.source)) - }) - } -} - -type filterFunc func(info balancerConfig.Info, e endpoint.Info) bool - -func (f filterFunc) Allow(info balancerConfig.Info, e endpoint.Info) bool { - return f(info, e) -} - -func (f filterFunc) String() string { - return "Custom" -} - -func TestSortPreferConnections(t *testing.T) { - table := []struct { - name string - source []conn.Conn - allowFallback bool - filter balancerConfig.Filter - prefer []conn.Conn - fallback []conn.Conn - }{ - { - name: "Empty", - source: nil, - allowFallback: false, - filter: nil, - prefer: nil, - fallback: nil, - }, - { - name: "NilFilter", - source: []conn.Conn{ - &mock.Conn{AddrField: "1"}, - &mock.Conn{AddrField: "2"}, - }, - allowFallback: false, - filter: nil, - prefer: []conn.Conn{ - &mock.Conn{AddrField: "1"}, - &mock.Conn{AddrField: "2"}, - }, - fallback: nil, - }, - { - name: "FilterNoFallback", - source: []conn.Conn{ - &mock.Conn{AddrField: "t1"}, - &mock.Conn{AddrField: "f1"}, - &mock.Conn{AddrField: "t2"}, - &mock.Conn{AddrField: "f2"}, - }, - allowFallback: false, - filter: filterFunc(func(_ balancerConfig.Info, e endpoint.Info) bool { - return strings.HasPrefix(e.Address(), "t") - }), - prefer: []conn.Conn{ - &mock.Conn{AddrField: "t1"}, - &mock.Conn{AddrField: "t2"}, - }, - fallback: nil, - }, - { - name: "FilterWithFallback", - source: []conn.Conn{ - &mock.Conn{AddrField: "t1"}, - &mock.Conn{AddrField: "f1"}, - &mock.Conn{AddrField: "t2"}, - &mock.Conn{AddrField: "f2"}, - }, - allowFallback: true, - filter: filterFunc(func(_ balancerConfig.Info, e endpoint.Info) bool { - return strings.HasPrefix(e.Address(), "t") - }), - prefer: []conn.Conn{ - &mock.Conn{AddrField: "t1"}, - &mock.Conn{AddrField: "t2"}, - }, - fallback: []conn.Conn{ - &mock.Conn{AddrField: "f1"}, - &mock.Conn{AddrField: "f2"}, - }, - }, - } - - for _, test := range table { - t.Run(test.name, func(t *testing.T) { - prefer, fallback := sortPreferConnections(test.source, test.filter, balancerConfig.Info{}, test.allowFallback) - require.Equal(t, test.prefer, prefer) - require.Equal(t, test.fallback, fallback) - }) - } -} - -func TestSelectRandomConnection(t *testing.T) { - s := newConnectionsState(nil, nil, balancerConfig.Info{}, false) - - t.Run("Empty", func(t *testing.T) { - c, failedCount := 
s.selectRandomConnection(nil, false) - require.Nil(t, c) - require.Equal(t, 0, failedCount) - }) - - t.Run("One", func(t *testing.T) { - for _, goodState := range []conn.State{conn.Online, conn.Offline, conn.Created} { - c, failedCount := s.selectRandomConnection([]conn.Conn{&mock.Conn{AddrField: "asd", State: goodState}}, false) - require.Equal(t, &mock.Conn{AddrField: "asd", State: goodState}, c) - require.Equal(t, 0, failedCount) - } - }) - t.Run("OneBanned", func(t *testing.T) { - c, failedCount := s.selectRandomConnection([]conn.Conn{&mock.Conn{AddrField: "asd", State: conn.Banned}}, false) - require.Nil(t, c) - require.Equal(t, 1, failedCount) - - c, failedCount = s.selectRandomConnection([]conn.Conn{&mock.Conn{AddrField: "asd", State: conn.Banned}}, true) - require.Equal(t, &mock.Conn{AddrField: "asd", State: conn.Banned}, c) - require.Equal(t, 0, failedCount) - }) - t.Run("Two", func(t *testing.T) { - conns := []conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online}, - &mock.Conn{AddrField: "2", State: conn.Online}, - } - first := 0 - second := 0 - for i := 0; i < 100; i++ { - c, _ := s.selectRandomConnection(conns, false) - if c.Endpoint().Address() == "1" { - first++ - } else { - second++ - } - } - require.Equal(t, 100, first+second) - require.InDelta(t, 50, first, 21) - require.InDelta(t, 50, second, 21) - }) - t.Run("TwoBanned", func(t *testing.T) { - conns := []conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Banned}, - &mock.Conn{AddrField: "2", State: conn.Banned}, - } - totalFailed := 0 - for i := 0; i < 100; i++ { - c, failed := s.selectRandomConnection(conns, false) - require.Nil(t, c) - totalFailed += failed - } - require.Equal(t, 200, totalFailed) - }) - t.Run("ThreeWithBanned", func(t *testing.T) { - conns := []conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online}, - &mock.Conn{AddrField: "2", State: conn.Online}, - &mock.Conn{AddrField: "3", State: conn.Banned}, - } - first := 0 - second := 0 - failed := 0 - for i := 0; i < 100; i++ { - c, checkFailed := s.selectRandomConnection(conns, false) - failed += checkFailed - switch c.Endpoint().Address() { - case "1": - first++ - case "2": - second++ - default: - t.Errorf(c.Endpoint().Address()) - } - } - require.Equal(t, 100, first+second) - require.InDelta(t, 50, first, 21) - require.InDelta(t, 50, second, 21) - require.Greater(t, 10, failed) - }) -} - -func TestNewState(t *testing.T) { - table := []struct { - name string - state *connectionsState - res *connectionsState - }{ - { - name: "Empty", - state: newConnectionsState(nil, nil, balancerConfig.Info{}, false), - res: &connectionsState{ - connByNodeID: nil, - prefer: nil, - fallback: nil, - all: nil, - }, - }, - { - name: "NoFilter", - state: newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "1", NodeIDField: 1}, - &mock.Conn{AddrField: "2", NodeIDField: 2}, - }, nil, balancerConfig.Info{}, false), - res: &connectionsState{ - connByNodeID: map[uint32]conn.Conn{ - 1: &mock.Conn{AddrField: "1", NodeIDField: 1}, - 2: &mock.Conn{AddrField: "2", NodeIDField: 2}, - }, - prefer: []conn.Conn{ - &mock.Conn{AddrField: "1", NodeIDField: 1}, - &mock.Conn{AddrField: "2", NodeIDField: 2}, - }, - fallback: nil, - all: []conn.Conn{ - &mock.Conn{AddrField: "1", NodeIDField: 1}, - &mock.Conn{AddrField: "2", NodeIDField: 2}, - }, - }, - }, - { - name: "FilterDenyFallback", - state: newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - 
&mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { - return info.SelfLocation == e.Location() - }), balancerConfig.Info{SelfLocation: "t"}, false), - res: &connectionsState{ - connByNodeID: map[uint32]conn.Conn{ - 1: &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - 2: &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - 3: &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - 4: &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, - prefer: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - }, - fallback: nil, - all: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - }, - }, - }, - { - name: "FilterAllowFallback", - state: newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { - return info.SelfLocation == e.Location() - }), balancerConfig.Info{SelfLocation: "t"}, true), - res: &connectionsState{ - connByNodeID: map[uint32]conn.Conn{ - 1: &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - 2: &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - 3: &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - 4: &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, - prefer: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - }, - fallback: []conn.Conn{ - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, - all: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, - }, - }, - { - name: "WithNodeID", - state: newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { - return info.SelfLocation == e.Location() - }), balancerConfig.Info{SelfLocation: "t"}, true), - res: &connectionsState{ - connByNodeID: map[uint32]conn.Conn{ - 1: &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - 2: &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - 3: &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - 4: &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, - prefer: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - }, - fallback: []conn.Conn{ - &mock.Conn{AddrField: 
"f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, - all: []conn.Conn{ - &mock.Conn{AddrField: "t1", NodeIDField: 1, LocationField: "t"}, - &mock.Conn{AddrField: "f1", NodeIDField: 2, LocationField: "f"}, - &mock.Conn{AddrField: "t2", NodeIDField: 3, LocationField: "t"}, - &mock.Conn{AddrField: "f2", NodeIDField: 4, LocationField: "f"}, - }, - }, - }, - } - - for _, test := range table { - t.Run(test.name, func(t *testing.T) { - require.NotNil(t, test.state.rand) - test.state.rand = nil - require.Equal(t, test.res, test.state) - }) - } -} - -func TestConnection(t *testing.T) { - t.Run("Empty", func(t *testing.T) { - s := newConnectionsState(nil, nil, balancerConfig.Info{}, false) - c, failed := s.GetConnection(context.Background()) - require.Nil(t, c) - require.Equal(t, 0, failed) - }) - t.Run("AllGood", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online}, - &mock.Conn{AddrField: "2", State: conn.Online}, - }, nil, balancerConfig.Info{}, false) - c, failed := s.GetConnection(context.Background()) - require.NotNil(t, c) - require.Equal(t, 0, failed) - }) - t.Run("WithBanned", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online}, - &mock.Conn{AddrField: "2", State: conn.Banned}, - }, nil, balancerConfig.Info{}, false) - c, _ := s.GetConnection(context.Background()) - require.Equal(t, &mock.Conn{AddrField: "1", State: conn.Online}, c) - }) - t.Run("AllBanned", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "t1", State: conn.Banned, LocationField: "t"}, - &mock.Conn{AddrField: "f2", State: conn.Banned, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { - return e.Location() == info.SelfLocation - }), balancerConfig.Info{}, true) - preferred := 0 - fallback := 0 - for i := 0; i < 100; i++ { - c, failed := s.GetConnection(context.Background()) - require.NotNil(t, c) - require.Equal(t, 2, failed) - if c.Endpoint().Address() == "t1" { - preferred++ - } else { - fallback++ - } - } - require.Equal(t, 100, preferred+fallback) - require.InDelta(t, 50, preferred, 21) - require.InDelta(t, 50, fallback, 21) - }) - t.Run("PreferBannedWithFallback", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "t1", State: conn.Banned, LocationField: "t"}, - &mock.Conn{AddrField: "f2", State: conn.Online, LocationField: "f"}, - }, filterFunc(func(info balancerConfig.Info, e endpoint.Info) bool { - return e.Location() == info.SelfLocation - }), balancerConfig.Info{SelfLocation: "t"}, true) - c, failed := s.GetConnection(context.Background()) - require.Equal(t, &mock.Conn{AddrField: "f2", State: conn.Online, LocationField: "f"}, c) - require.Equal(t, 1, failed) - }) - t.Run("PreferNodeID", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online, NodeIDField: 1}, - &mock.Conn{AddrField: "2", State: conn.Online, NodeIDField: 2}, - }, nil, balancerConfig.Info{}, false) - c, failed := s.GetConnection(endpoint.WithNodeID(context.Background(), 2)) - require.Equal(t, &mock.Conn{AddrField: "2", State: conn.Online, NodeIDField: 2}, c) - require.Equal(t, 0, failed) - }) - t.Run("PreferNodeIDWithBadState", func(t *testing.T) { - s := newConnectionsState([]conn.Conn{ - &mock.Conn{AddrField: "1", State: conn.Online, NodeIDField: 1}, - &mock.Conn{AddrField: "2", State: conn.Unknown, 
NodeIDField: 2}, - }, nil, balancerConfig.Info{}, false) - c, failed := s.GetConnection(endpoint.WithNodeID(context.Background(), 2)) - require.Equal(t, &mock.Conn{AddrField: "1", State: conn.Online, NodeIDField: 1}, c) - require.Equal(t, 0, failed) - }) -} diff --git a/internal/balancer/local_dc.go b/internal/balancer/local_dc.go index b1ee2e086..88803ccf7 100644 --- a/internal/balancer/local_dc.go +++ b/internal/balancer/local_dc.go @@ -10,6 +10,7 @@ import ( "strings" "sync" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/cluster" "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" @@ -114,7 +115,7 @@ func detectFastestEndpoint(ctx context.Context, endpoints []endpoint.Endpoint) ( func detectLocalDC(ctx context.Context, endpoints []endpoint.Endpoint) (string, error) { if len(endpoints) == 0 { - return "", xerrors.WithStackTrace(ErrNoEndpoints) + return "", xerrors.WithStackTrace(cluster.ErrNoEndpoints) } endpointsByDc := splitEndpointsByLocation(endpoints) diff --git a/internal/balancer/local_dc_test.go b/internal/balancer/local_dc_test.go index 2eab1e9a8..1477a6d31 100644 --- a/internal/balancer/local_dc_test.go +++ b/internal/balancer/local_dc_test.go @@ -151,7 +151,7 @@ func TestLocalDCDiscovery(t *testing.T) { require.NoError(t, err) for i := 0; i < 100; i++ { - conn, _ := r.connections().GetConnection(ctx) + conn, _ := r.getConn(ctx) require.Equal(t, "b:234", conn.Endpoint().Address()) require.Equal(t, "b", conn.Endpoint().Location()) } diff --git a/internal/conn/conn.go b/internal/conn/conn.go index 067971fcf..231cca85a 100644 --- a/internal/conn/conn.go +++ b/internal/conn/conn.go @@ -58,8 +58,7 @@ type conn struct { state atomic.Uint32 childStreams *xcontext.CancelsGuard lastUsage xsync.LastUsage - onClose []func(*conn) - onTransportErrors []func(ctx context.Context, cc Conn, cause error) + onTransportErrors []func(err error) } func (c *conn) Address() string { @@ -220,7 +219,7 @@ func (c *conn) realConn(ctx context.Context) (cc *grpc.ClientConn, err error) { } defer func() { - c.onTransportError(ctx, err) + c.onTransportError(err) }() return nil, xerrors.WithStackTrace( @@ -237,9 +236,9 @@ func (c *conn) realConn(ctx context.Context) (cc *grpc.ClientConn, err error) { return c.grpcConn, nil } -func (c *conn) onTransportError(ctx context.Context, cause error) { +func (c *conn) onTransportError(err error) { for _, onTransportError := range c.onTransportErrors { - onTransportError(ctx, c, cause) + onTransportError(err) } } @@ -291,9 +290,7 @@ func (c *conn) Close(ctx context.Context) (err error) { c.setState(ctx, Destroyed) - for _, onClose := range c.onClose { - onClose(c) - } + c.childStreams.Cancel() onDone(err) }() @@ -310,8 +307,6 @@ func (c *conn) Close(ctx context.Context) (err error) { )) } -var onTransportErrorStub = func(ctx context.Context, err error) {} - func replyWrapper(reply any) (opID string, issues []trace.Issue) { switch t := reply.(type) { case operation.Response: @@ -328,16 +323,23 @@ func replyWrapper(reply any) (opID string, issues []trace.Issue) { return opID, issues } +type ( + invokeSettings struct { + onTransportError func(error) + address string + nodeID uint32 + } + invokeOption func(s *invokeSettings) +) + //nolint:funlen func invoke( ctx context.Context, method string, req, reply any, cc grpc.ClientConnInterface, - onTransportError func(context.Context, error), - address string, - nodeID uint32, - opts 
...grpc.CallOption, + grpcCallOpts []grpc.CallOption, + opts ...invokeOption, ) ( opID string, issues []trace.Issue, @@ -352,17 +354,24 @@ func invoke( ctx, sentMark := markContext(meta.WithTraceID(ctx, traceID)) - if onTransportError == nil { - onTransportError = onTransportErrorStub + s := invokeSettings{} + for _, opt := range opts { + if opt != nil { + opt(&s) + } } - err = cc.Invoke(ctx, method, req, reply, opts...) + err = cc.Invoke(ctx, method, req, reply, grpcCallOpts...) if err != nil { if xerrors.IsContextError(err) { return opID, issues, xerrors.WithStackTrace(err) } - defer onTransportError(ctx, err) + defer func() { + if s.onTransportError != nil { + s.onTransportError(err) + } + }() if !useWrapping { return opID, issues, err @@ -378,8 +387,8 @@ func invoke( } return opID, issues, xerrors.WithStackTrace(xerrors.Transport(err, - xerrors.WithAddress(address), - xerrors.WithNodeID(nodeID), + xerrors.WithAddress(s.address), + xerrors.WithNodeID(s.nodeID), xerrors.WithTraceID(traceID), )) } @@ -400,8 +409,8 @@ func invoke( return opID, issues, xerrors.WithStackTrace( xerrors.Operation( xerrors.FromOperation(t.GetOperation()), - xerrors.WithAddress(address), - xerrors.WithNodeID(nodeID), + xerrors.WithAddress(s.address), + xerrors.WithNodeID(s.nodeID), xerrors.WithTraceID(traceID), ), ) @@ -411,8 +420,8 @@ func invoke( return opID, issues, xerrors.WithStackTrace( xerrors.Operation( xerrors.FromOperation(t), - xerrors.WithAddress(address), - xerrors.WithNodeID(nodeID), + xerrors.WithAddress(s.address), + xerrors.WithNodeID(s.nodeID), xerrors.WithTraceID(traceID), ), ) @@ -453,16 +462,12 @@ func (c *conn) Invoke( stop := c.lastUsage.Start() defer stop() - opID, issues, err = invoke( - ctx, - method, - req, - res, - cc, - c.onTransportError, - c.Address(), - c.NodeID(), - append(opts, grpc.Trailer(&md))..., + opID, issues, err = invoke(ctx, method, req, res, cc, append(opts, grpc.Trailer(&md)), + func(s *invokeSettings) { + s.onTransportError = c.onTransportError + s.address = c.Address() + s.nodeID = c.NodeID() + }, ) return err @@ -526,7 +531,7 @@ func (c *conn) NewStream( } defer func() { - c.onTransportError(ctx, err) + c.onTransportError(err) }() if !useWrapping { @@ -553,15 +558,7 @@ func (c *conn) NewStream( type option func(c *conn) -func withOnClose(onClose func(*conn)) option { - return func(c *conn) { - if onClose != nil { - c.onClose = append(c.onClose, onClose) - } - } -} - -func withOnTransportError(onTransportError func(ctx context.Context, cc Conn, cause error)) option { +func withOnTransportError(onTransportError func(err error)) option { return func(c *conn) { if onTransportError != nil { c.onTransportErrors = append(c.onTransportErrors, onTransportError) @@ -569,18 +566,13 @@ func withOnTransportError(onTransportError func(ctx context.Context, cc Conn, ca } } -func newConn(e endpoint.Endpoint, config Config, opts ...option) *conn { +func dial(_ context.Context, e endpoint.Endpoint, config Config, opts ...option) *conn { c := &conn{ endpoint: e, config: config, done: make(chan struct{}), lastUsage: xsync.NewLastUsage(), childStreams: xcontext.NewCancelsGuard(), - onClose: []func(*conn){ - func(c *conn) { - c.childStreams.Cancel() - }, - }, } c.state.Store(uint32(Created)) for _, opt := range opts { @@ -592,10 +584,6 @@ func newConn(e endpoint.Endpoint, config Config, opts ...option) *conn { return c } -func New(e endpoint.Endpoint, config Config, opts ...option) Conn { - return newConn(e, config, opts...) 
-} - var _ stats.Handler = statsHandler{} type statsHandler struct{} diff --git a/internal/conn/conn_test.go b/internal/conn/conn_test.go index fb78ad212..62ce055ad 100644 --- a/internal/conn/conn_test.go +++ b/internal/conn/conn_test.go @@ -30,7 +30,7 @@ type connMock struct { } func (c connMock) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error { - _, _, err := invoke(ctx, method, args, reply, c.cc, nil, "", 0, opts...) + _, _, err := invoke(ctx, method, args, reply, c.cc, opts) return err } diff --git a/internal/conn/grpc_client_stream.go b/internal/conn/grpc_client_stream.go index 397d1e8d3..b39e63f23 100644 --- a/internal/conn/grpc_client_stream.go +++ b/internal/conn/grpc_client_stream.go @@ -95,7 +95,7 @@ func (s *grpcClientStream) SendMsg(m interface{}) (err error) { } defer func() { - s.parentConn.onTransportError(ctx, err) + s.parentConn.onTransportError(err) }() if !s.wrapping { @@ -155,7 +155,7 @@ func (s *grpcClientStream) RecvMsg(m interface{}) (err error) { //nolint:funlen } defer func() { - s.parentConn.onTransportError(ctx, err) + s.parentConn.onTransportError(err) }() if !s.wrapping { diff --git a/internal/conn/pool.go b/internal/conn/pool.go index 783b7a880..e61aeb4df 100644 --- a/internal/conn/pool.go +++ b/internal/conn/pool.go @@ -7,7 +7,6 @@ import ( "time" "google.golang.org/grpc" - grpcCodes "google.golang.org/grpc/codes" "github.com/ydb-platform/ydb-go-sdk/v3/internal/closer" "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" @@ -18,130 +17,59 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/trace" ) -type connsKey struct { - address string - nodeID uint32 +type connWithCounter struct { + counter atomic.Int32 + errs atomic.Uint32 + conn *conn } type Pool struct { usages int64 config Config - mtx xsync.RWMutex opts []grpc.DialOption - conns map[connsKey]*conn + conns xsync.Map[endpoint.Endpoint, *connWithCounter] done chan struct{} } -func (p *Pool) Get(endpoint endpoint.Endpoint) Conn { - p.mtx.Lock() - defer p.mtx.Unlock() - - var ( - address = endpoint.Address() - cc *conn - has bool - ) - - key := connsKey{address, endpoint.NodeID()} - - if cc, has = p.conns[key]; has { - return cc - } - - cc = newConn( - endpoint, - p.config, - withOnClose(p.remove), - withOnTransportError(p.Ban), - ) - - p.conns[key] = cc - - return cc -} - -func (p *Pool) remove(c *conn) { - p.mtx.Lock() - defer p.mtx.Unlock() - delete(p.conns, connsKey{c.Endpoint().Address(), c.Endpoint().NodeID()}) -} +func (p *Pool) Put(ctx context.Context, c Conn) bool { + if cc, has := p.conns.Get(c.Endpoint()); has { + if cc.counter.Add(-1) == 0 { + cc.conn.Close(ctx) + p.conns.Delete(c.Endpoint()) + } -func (p *Pool) isClosed() bool { - select { - case <-p.done: return true - default: - return false - } -} - -func (p *Pool) Ban(ctx context.Context, cc Conn, cause error) { - if p.isClosed() { - return } - if !xerrors.IsTransportError(cause, - grpcCodes.ResourceExhausted, - grpcCodes.Unavailable, - // grpcCodes.OK, - // grpcCodes.Canceled, - // grpcCodes.Unknown, - // grpcCodes.InvalidArgument, - // grpcCodes.DeadlineExceeded, - // grpcCodes.NotFound, - // grpcCodes.AlreadyExists, - // grpcCodes.PermissionDenied, - // grpcCodes.FailedPrecondition, - // grpcCodes.Aborted, - // grpcCodes.OutOfRange, - // grpcCodes.Unimplemented, - // grpcCodes.Internal, - // grpcCodes.DataLoss, - // grpcCodes.Unauthenticated, - ) { - return - } - - e := cc.Endpoint().Copy() + return false +} - p.mtx.RLock() - defer p.mtx.RUnlock() +func (p *Pool) Get(ctx 
context.Context, e endpoint.Endpoint) (Conn, error) { + if cc, has := p.conns.Get(e); has { + cc.counter.Add(1) - cc, ok := p.conns[connsKey{e.Address(), e.NodeID()}] - if !ok { - return + return cc.conn, nil } - trace.DriverOnConnBan( - p.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*Pool).Ban"), - e, cc.GetState(), cause, - )(cc.SetState(ctx, Banned)) -} - -func (p *Pool) Allow(ctx context.Context, cc Conn) { - if p.isClosed() { - return - } + cc := &connWithCounter{} - e := cc.Endpoint().Copy() + cc.conn = dial(ctx, e, + p.config, + withOnTransportError(func(err error) { + if IsBadConn(err) { + cc.errs.Add(1) + } + }), + ) - p.mtx.RLock() - defer p.mtx.RUnlock() + cc.counter.Add(1) - cc, ok := p.conns[connsKey{e.Address(), e.NodeID()}] - if !ok { - return - } + p.conns.Set(e, cc) - trace.DriverOnConnAllow( - p.config.Trace(), &ctx, - stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/conn.(*Pool).Allow"), - e, cc.GetState(), - )(cc.Unban(ctx)) + return cc.conn, nil } -func (p *Pool) Take(context.Context) error { +func (p *Pool) Take(_ context.Context) error { atomic.AddInt64(&p.usages, 1) return nil @@ -162,11 +90,10 @@ func (p *Pool) Release(ctx context.Context) (finalErr error) { close(p.done) var conns []closer.Closer - p.mtx.WithRLock(func() { - conns = make([]closer.Closer, 0, len(p.conns)) - for _, c := range p.conns { - conns = append(conns, c) - } + p.conns.Range(func(_ endpoint.Endpoint, cc *connWithCounter) bool { + conns = append(conns, cc.conn) + + return true }) var ( @@ -206,29 +133,20 @@ func (p *Pool) connParker(ctx context.Context, ttl, interval time.Duration) { case <-p.done: return case <-ticker.C: - for _, c := range p.collectConns() { - if time.Since(c.LastUsage()) > ttl { - switch c.GetState() { + p.conns.Range(func(key endpoint.Endpoint, c *connWithCounter) bool { + if time.Since(c.conn.LastUsage()) > ttl { + switch c.conn.GetState() { case Online, Banned: - _ = c.park(ctx) + _ = c.conn.park(ctx) default: // nop } } - } - } - } -} -func (p *Pool) collectConns() []*conn { - p.mtx.RLock() - defer p.mtx.RUnlock() - conns := make([]*conn, 0, len(p.conns)) - for _, c := range p.conns { - conns = append(conns, c) + return true + }) + } } - - return conns } func NewPool(ctx context.Context, config Config) *Pool { @@ -241,7 +159,6 @@ func NewPool(ctx context.Context, config Config) *Pool { usages: 1, config: config, opts: config.GrpcDialOptions(), - conns: make(map[connsKey]*conn), done: make(chan struct{}), }
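Reviewer note on the new pool contract: conn.Pool moves from shared conns with Ban/Allow state flags to reference counting. Every Pool.Get(ctx, e) increments the endpoint's counter (dialing lazily on first use), every Pool.Put(ctx, c) decrements it, and the last Put closes and evicts the conn. The sketch below is a minimal, self-contained illustration of that contract, not the SDK code: refPool, refConn, and the mutex-guarded map are stand-ins for the real xsync.Map plus atomic counters.

package main

import (
	"fmt"
	"sync"
)

// conn stands in for the SDK's *conn; only what the sketch needs.
type conn struct {
	addr   string
	closed bool
}

func (c *conn) Close() { c.closed = true }

type refConn struct {
	counter int // guarded by refPool.mu; the PR uses atomic.Int32
	conn    *conn
}

// refPool sketches the reference-counting contract of conn.Pool.
type refPool struct {
	mu    sync.Mutex
	conns map[string]*refConn
}

func newRefPool() *refPool {
	return &refPool{conns: map[string]*refConn{}}
}

// Get returns the shared conn for addr, dialing lazily on first use,
// and increments its reference counter.
func (p *refPool) Get(addr string) *conn {
	p.mu.Lock()
	defer p.mu.Unlock()

	cc, ok := p.conns[addr]
	if !ok {
		cc = &refConn{conn: &conn{addr: addr}} // "dial"
		p.conns[addr] = cc
	}
	cc.counter++

	return cc.conn
}

// Put releases one reference; the last Put closes and evicts the conn.
func (p *refPool) Put(c *conn) bool {
	p.mu.Lock()
	defer p.mu.Unlock()

	cc, ok := p.conns[c.addr]
	if !ok {
		return false
	}

	cc.counter--
	if cc.counter == 0 {
		cc.conn.Close()
		delete(p.conns, c.addr)
	}

	return true
}

func main() {
	p := newRefPool()

	c1 := p.Get("node-1:2135")
	c2 := p.Get("node-1:2135")
	fmt.Println(c1 == c2) // true: one shared conn per endpoint

	p.Put(c1)
	fmt.Println(c2.closed) // false: one reference is still held
	p.Put(c2)
	fmt.Println(c1.closed) // true: the last Put closed and evicted the conn
}

The payoff is that the balancer and the discovery client can now share one conn per endpoint, and the pool itself knows when a conn is safe to close; banning becomes a plain endpoint set on the Balancer instead of per-conn state.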
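On the balancer side, the cluster snapshot is treated as immutable copy-on-write state: getConn loads the current *cluster.Cluster from an atomic.Pointer, excludes a failed endpoint by CAS-ing in the reduced copy built by cluster.Without, and an Availability() below 0.5 forces re-discovery. A stripped-down sketch of the copy-on-write pattern (snapshot and without are illustrative names, not the SDK types):

package main

import (
	"fmt"
	"sync/atomic"
)

// snapshot is an immutable set of endpoints; "mutation" builds a new one.
type snapshot struct {
	endpoints []string
}

// without returns a reduced copy; the original snapshot is never touched.
func without(s *snapshot, bad string) *snapshot {
	next := &snapshot{}
	for _, e := range s.endpoints {
		if e != bad {
			next.endpoints = append(next.endpoints, e)
		}
	}

	return next
}

func main() {
	var cluster atomic.Pointer[snapshot]
	cluster.Store(&snapshot{endpoints: []string{"a:2135", "b:2135", "c:2135"}})

	// On a dial failure the caller swaps in a reduced snapshot. The CAS
	// fails harmlessly if discovery replaced the state in the meantime,
	// in which case the stale exclusion is simply dropped.
	state := cluster.Load()
	if cluster.CompareAndSwap(state, without(state, "b:2135")) {
		fmt.Println("excluded b:2135")
	}

	fmt.Println(cluster.Load().endpoints) // [a:2135 c:2135]
}

Concurrent readers always see either the old or the new snapshot in full, which is why Next and All need no locking.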
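Finally, the invoke() signature change is the usual functional-options move: address, node ID, and the transport-error hook travel as invokeOption closures instead of positional parameters, so call sites set only what they need. A generic illustration of the pattern (settings, option, and call are illustrative names, not the SDK API):

package main

import "fmt"

type settings struct {
	address string
	nodeID  uint32
	onError func(error)
}

type option func(*settings)

// call mirrors the shape of the refactored invoke: fixed arguments first,
// then a variadic tail of options filling in optional per-call data.
func call(method string, opts ...option) {
	s := settings{}
	for _, opt := range opts {
		if opt != nil { // nil options are skipped, as in the PR
			opt(&s)
		}
	}
	fmt.Printf("%s -> %s (node %d)\n", method, s.address, s.nodeID)
}

func main() {
	call("/Ydb.Table.V1.TableService/ExecuteDataQuery", func(s *settings) {
		s.address = "node-1:2135"
		s.nodeID = 1
	})
}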