Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
asmyasnikov committed Mar 29, 2024
1 parent a0abd92 commit 5455fcc
Show file tree
Hide file tree
Showing 38 changed files with 1,378 additions and 980 deletions.
10 changes: 5 additions & 5 deletions balancers/balancers.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func SingleConn() *balancerConfig.Config {

type filterLocalDC struct{}

func (filterLocalDC) Allow(info balancerConfig.Info, c conn.Conn) bool {
func (filterLocalDC) Allow(info balancerConfig.Info, c conn.Info) bool {
return c.Endpoint().Location() == info.SelfLocation
}

Expand Down Expand Up @@ -56,7 +56,7 @@ func PreferLocalDCWithFallBack(balancer *balancerConfig.Config) *balancerConfig.

type filterLocations []string

func (locations filterLocations) Allow(_ balancerConfig.Info, c conn.Conn) bool {
func (locations filterLocations) Allow(_ balancerConfig.Info, c conn.Info) bool {
location := strings.ToUpper(c.Endpoint().Location())
for _, l := range locations {
if location == l {
Expand Down Expand Up @@ -118,9 +118,9 @@ type Endpoint interface {
LocalDC() bool
}

type filterFunc func(info balancerConfig.Info, c conn.Conn) bool
type filterFunc func(info balancerConfig.Info, c conn.Info) bool

func (p filterFunc) Allow(info balancerConfig.Info, c conn.Conn) bool {
func (p filterFunc) Allow(info balancerConfig.Info, c conn.Info) bool {
return p(info, c)
}

Expand All @@ -131,7 +131,7 @@ func (p filterFunc) String() string {
// Prefer creates balancer which use endpoints by filter
// Balancer "balancer" defines balancing algorithm between endpoints selected with filter
func Prefer(balancer *balancerConfig.Config, filter func(endpoint Endpoint) bool) *balancerConfig.Config {
balancer.Filter = filterFunc(func(_ balancerConfig.Info, c conn.Conn) bool {
balancer.Filter = filterFunc(func(_ balancerConfig.Info, c conn.Info) bool {
return filter(c.Endpoint())
})

Expand Down
46 changes: 23 additions & 23 deletions balancers/balancers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,56 +11,56 @@ import (
)

func TestPreferLocalDC(t *testing.T) {
conns := []conn.Conn{
&mock.Conn{AddrField: "1", LocationField: "1"},
&mock.Conn{AddrField: "2", State: conn.Online, LocationField: "2"},
&mock.Conn{AddrField: "3", State: conn.Online, LocationField: "2"},
conns := []conn.Info{
&mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "1"},
&mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointLocationField: "2"},
&mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Online, EndpointLocationField: "2"},
}
rr := PreferLocalDC(RandomChoice())
require.False(t, rr.AllowFallback)
require.Equal(t, []conn.Conn{conns[1], conns[2]}, applyPreferFilter(balancerConfig.Info{SelfLocation: "2"}, rr, conns))
require.Equal(t, []conn.Info{conns[1], conns[2]}, applyPreferFilter(balancerConfig.Info{SelfLocation: "2"}, rr, conns))
}

func TestPreferLocalDCWithFallBack(t *testing.T) {
conns := []conn.Conn{
&mock.Conn{AddrField: "1", LocationField: "1"},
&mock.Conn{AddrField: "2", State: conn.Online, LocationField: "2"},
&mock.Conn{AddrField: "3", State: conn.Online, LocationField: "2"},
conns := []conn.Info{
&mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "1"},
&mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointLocationField: "2"},
&mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Online, EndpointLocationField: "2"},
}
rr := PreferLocalDCWithFallBack(RandomChoice())
require.True(t, rr.AllowFallback)
require.Equal(t, []conn.Conn{conns[1], conns[2]}, applyPreferFilter(balancerConfig.Info{SelfLocation: "2"}, rr, conns))
require.Equal(t, []conn.Info{conns[1], conns[2]}, applyPreferFilter(balancerConfig.Info{SelfLocation: "2"}, rr, conns))
}

func TestPreferLocations(t *testing.T) {
conns := []conn.Conn{
&mock.Conn{AddrField: "1", LocationField: "zero", State: conn.Online},
&mock.Conn{AddrField: "2", State: conn.Online, LocationField: "one"},
&mock.Conn{AddrField: "3", State: conn.Online, LocationField: "two"},
conns := []conn.Info{
&mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "zero", ConnState: conn.Online},
&mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointLocationField: "one"},
&mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Online, EndpointLocationField: "two"},
}

rr := PreferLocations(RandomChoice(), "zero", "two")
require.False(t, rr.AllowFallback)
require.Equal(t, []conn.Conn{conns[0], conns[2]}, applyPreferFilter(balancerConfig.Info{}, rr, conns))
require.Equal(t, []conn.Info{conns[0], conns[2]}, applyPreferFilter(balancerConfig.Info{}, rr, conns))
}

func TestPreferLocationsWithFallback(t *testing.T) {
conns := []conn.Conn{
&mock.Conn{AddrField: "1", LocationField: "zero", State: conn.Online},
&mock.Conn{AddrField: "2", State: conn.Online, LocationField: "one"},
&mock.Conn{AddrField: "3", State: conn.Online, LocationField: "two"},
conns := []conn.Info{
&mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "zero", ConnState: conn.Online},
&mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointLocationField: "one"},
&mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Online, EndpointLocationField: "two"},
}

rr := PreferLocationsWithFallback(RandomChoice(), "zero", "two")
require.True(t, rr.AllowFallback)
require.Equal(t, []conn.Conn{conns[0], conns[2]}, applyPreferFilter(balancerConfig.Info{}, rr, conns))
require.Equal(t, []conn.Info{conns[0], conns[2]}, applyPreferFilter(balancerConfig.Info{}, rr, conns))
}

func applyPreferFilter(info balancerConfig.Info, b *balancerConfig.Config, conns []conn.Conn) []conn.Conn {
func applyPreferFilter(info balancerConfig.Info, b *balancerConfig.Config, conns []conn.Info) []conn.Info {
if b.Filter == nil {
b.Filter = filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { return true })
b.Filter = filterFunc(func(info balancerConfig.Info, c conn.Info) bool { return true })
}
res := make([]conn.Conn, 0, len(conns))
res := make([]conn.Info, 0, len(conns))
for _, c := range conns {
if b.Filter.Allow(info, c) {
res = append(res, c)
Expand Down
8 changes: 4 additions & 4 deletions balancers/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func TestFromConfig(t *testing.T) {
}`,
res: balancerConfig.Config{
DetectLocalDC: true,
Filter: filterFunc(func(info balancerConfig.Info, c conn.Conn) bool {
Filter: filterFunc(func(info balancerConfig.Info, c conn.Info) bool {
// some non nil func
return false
}),
Expand All @@ -95,7 +95,7 @@ func TestFromConfig(t *testing.T) {
res: balancerConfig.Config{
AllowFallback: true,
DetectLocalDC: true,
Filter: filterFunc(func(info balancerConfig.Info, c conn.Conn) bool {
Filter: filterFunc(func(info balancerConfig.Info, c conn.Info) bool {
// some non nil func
return false
}),
Expand All @@ -109,7 +109,7 @@ func TestFromConfig(t *testing.T) {
"locations": ["AAA", "BBB", "CCC"]
}`,
res: balancerConfig.Config{
Filter: filterFunc(func(info balancerConfig.Info, c conn.Conn) bool {
Filter: filterFunc(func(info balancerConfig.Info, c conn.Info) bool {
// some non nil func
return false
}),
Expand All @@ -125,7 +125,7 @@ func TestFromConfig(t *testing.T) {
}`,
res: balancerConfig.Config{
AllowFallback: true,
Filter: filterFunc(func(info balancerConfig.Info, c conn.Conn) bool {
Filter: filterFunc(func(info balancerConfig.Info, c conn.Info) bool {
// some non nil func
return false
}),
Expand Down
2 changes: 1 addition & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func (d *Driver) Close(ctx context.Context) (finalErr error) {
d.query.Close,
d.topic.Close,
d.balancer.Close,
d.pool.Release,
d.pool.Detach,
)

var issues []error
Expand Down
66 changes: 43 additions & 23 deletions internal/balancer/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ type Balancer struct {
localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error)

mu xsync.RWMutex
connectionsState *connectionsState
connectionsState *connectionsState[conn.Conn]

closed chan struct{}

onApplyDiscoveredEndpoints []func(ctx context.Context, endpoints []endpoint.Info)
}
Expand Down Expand Up @@ -133,7 +135,7 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
return nil
}

func endpointsDiff(newestEndpoints []endpoint.Endpoint, previousConns []conn.Conn) (
func endpointsDiff(newestEndpoints []endpoint.Endpoint, previousConns []conn.Info) (
nodes []trace.EndpointInfo,
added []trace.EndpointInfo,
dropped []trace.EndpointInfo,
Expand Down Expand Up @@ -178,7 +180,7 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
"github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"),
b.config.DetectLocalDC,
)
previousConns []conn.Conn
previousConns []conn.Info
)
defer func() {
nodes, added, dropped := endpointsDiff(endpoints, previousConns)
Expand All @@ -187,7 +189,9 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end

connections := endpointsToConnections(b.pool, endpoints)
for _, c := range connections {
b.pool.Allow(ctx, c)
if c.State() == conn.Banned {
b.pool.Unban(ctx, c)
}
c.Endpoint().Touch()
}

Expand All @@ -201,7 +205,10 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end

b.mu.WithLock(func() {
if b.connectionsState != nil {
previousConns = b.connectionsState.all
previousConns = make([]conn.Info, len(b.connectionsState.all))
for i := range b.connectionsState.all {
previousConns[i] = b.connectionsState.all[i]
}
}
b.connectionsState = state
for _, onApplyDiscoveredEndpoints := range b.onApplyDiscoveredEndpoints {
Expand All @@ -211,6 +218,8 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
}

func (b *Balancer) Close(ctx context.Context) (err error) {
close(b.closed)

onDone := trace.DriverOnBalancerClose(
b.driverConfig.Trace(), &ctx,
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).Close"),
Expand All @@ -223,6 +232,8 @@ func (b *Balancer) Close(ctx context.Context) (err error) {
b.discoveryRepeater.Stop()
}

b.applyDiscoveredEndpoints(ctx, nil, "")

if err = b.discoveryClient.Close(ctx); err != nil {
return xerrors.WithStackTrace(err)
}
Expand Down Expand Up @@ -258,6 +269,7 @@ func New(
driverConfig: driverConfig,
pool: pool,
localDCDetector: detectLocalDC,
closed: make(chan struct{}),
}
d := internalDiscovery.New(ctx, pool.Get(
endpoint.New(driverConfig.Endpoint()),
Expand Down Expand Up @@ -300,9 +312,14 @@ func (b *Balancer) Invoke(
reply interface{},
opts ...grpc.CallOption,
) error {
return b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error {
return cc.Invoke(ctx, method, args, reply, opts...)
})
select {
case <-b.closed:
return xerrors.WithStackTrace(errBalancerClosed)
default:
return b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error {
return cc.Invoke(ctx, method, args, reply, opts...)
})
}
}

func (b *Balancer) NewStream(
Expand All @@ -311,17 +328,22 @@ func (b *Balancer) NewStream(
method string,
opts ...grpc.CallOption,
) (_ grpc.ClientStream, err error) {
var client grpc.ClientStream
err = b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error {
client, err = cc.NewStream(ctx, desc, method, opts...)
select {
case <-b.closed:
return nil, xerrors.WithStackTrace(errBalancerClosed)
default:
var client grpc.ClientStream
err = b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error {
client, err = cc.NewStream(ctx, desc, method, opts...)

return err
})
if err == nil {
return client, nil
}

return err
})
if err == nil {
return client, nil
return nil, err
}

return nil, err
}

func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc conn.Conn) error) (err error) {
Expand All @@ -332,10 +354,8 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc

defer func() {
if err == nil {
if cc.GetState() == conn.Banned {
b.pool.Allow(ctx, cc)
}
} else if xerrors.MustPessimizeEndpoint(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) {
b.pool.Unban(ctx, cc)
} else if xerrors.MustBanConn(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) {
b.pool.Ban(ctx, cc, err)
}
}()
Expand Down Expand Up @@ -363,7 +383,7 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
return nil
}

func (b *Balancer) connections() *connectionsState {
func (b *Balancer) connections() *connectionsState[conn.Conn] {
b.mu.RLock()
defer b.mu.RUnlock()

Expand Down Expand Up @@ -401,7 +421,7 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
c, failedCount = state.GetConnection(ctx)
if c == nil {
return nil, xerrors.WithStackTrace(
fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", ErrNoEndpoints, failedCount),
fmt.Errorf("cannot get connection from Balancer after %d attempts: %w", failedCount, ErrNoEndpoints),
)
}

Expand Down
38 changes: 19 additions & 19 deletions internal/balancer/balancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
func TestEndpointsDiff(t *testing.T) {
for _, tt := range []struct {
newestEndpoints []endpoint.Endpoint
previousConns []conn.Conn
previousConns []conn.Info
nodes []trace.EndpointInfo
added []trace.EndpointInfo
dropped []trace.EndpointInfo
Expand All @@ -27,11 +27,11 @@ func TestEndpointsDiff(t *testing.T) {
&mock.Endpoint{AddrField: "2"},
&mock.Endpoint{AddrField: "0"},
},
previousConns: []conn.Conn{
&mock.Conn{AddrField: "2"},
&mock.Conn{AddrField: "1"},
&mock.Conn{AddrField: "0"},
&mock.Conn{AddrField: "3"},
previousConns: []conn.Info{
&mock.ConnInfo{EndpointAddrField: "2"},
&mock.ConnInfo{EndpointAddrField: "1"},
&mock.ConnInfo{EndpointAddrField: "0"},
&mock.ConnInfo{EndpointAddrField: "3"},
},
nodes: []trace.EndpointInfo{
&mock.Endpoint{AddrField: "0"},
Expand All @@ -49,10 +49,10 @@ func TestEndpointsDiff(t *testing.T) {
&mock.Endpoint{AddrField: "2"},
&mock.Endpoint{AddrField: "0"},
},
previousConns: []conn.Conn{
&mock.Conn{AddrField: "1"},
&mock.Conn{AddrField: "0"},
&mock.Conn{AddrField: "3"},
previousConns: []conn.Info{
&mock.ConnInfo{EndpointAddrField: "1"},
&mock.ConnInfo{EndpointAddrField: "0"},
&mock.ConnInfo{EndpointAddrField: "3"},
},
nodes: []trace.EndpointInfo{
&mock.Endpoint{AddrField: "0"},
Expand All @@ -71,11 +71,11 @@ func TestEndpointsDiff(t *testing.T) {
&mock.Endpoint{AddrField: "3"},
&mock.Endpoint{AddrField: "0"},
},
previousConns: []conn.Conn{
&mock.Conn{AddrField: "1"},
&mock.Conn{AddrField: "2"},
&mock.Conn{AddrField: "0"},
&mock.Conn{AddrField: "3"},
previousConns: []conn.Info{
&mock.ConnInfo{EndpointAddrField: "1"},
&mock.ConnInfo{EndpointAddrField: "2"},
&mock.ConnInfo{EndpointAddrField: "0"},
&mock.ConnInfo{EndpointAddrField: "3"},
},
nodes: []trace.EndpointInfo{
&mock.Endpoint{AddrField: "0"},
Expand All @@ -93,10 +93,10 @@ func TestEndpointsDiff(t *testing.T) {
&mock.Endpoint{AddrField: "3"},
&mock.Endpoint{AddrField: "0"},
},
previousConns: []conn.Conn{
&mock.Conn{AddrField: "4"},
&mock.Conn{AddrField: "7"},
&mock.Conn{AddrField: "8"},
previousConns: []conn.Info{
&mock.ConnInfo{EndpointAddrField: "4"},
&mock.ConnInfo{EndpointAddrField: "7"},
&mock.ConnInfo{EndpointAddrField: "8"},
},
nodes: []trace.EndpointInfo{
&mock.Endpoint{AddrField: "0"},
Expand Down
Loading

0 comments on commit 5455fcc

Please sign in to comment.