Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
arjan-bal committed Aug 23, 2024
1 parent 02267d3 commit a3049fe
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 45 deletions.
2 changes: 1 addition & 1 deletion balancer/pickfirst/pickfirst.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/serviceconfig"

_ "google.golang.org/grpc/balancer/pickfirst_leaf" // For automatically registering the new pickfirst if required.
_ "google.golang.org/grpc/balancer/pickfirstleaf" // For automatically registering the new pickfirst if required.
)

func init() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (pickfirstBuilder) Build(cc balancer.ClientConn, _ balancer.BuildOptions) b
ctx, cancel := context.WithCancel(context.Background())
b := &pickfirstBalancer{
cc: cc,
addressIndex: addressList{},
addressList: addressList{},
subConns: resolver.NewAddressMap(),
serializer: grpcsync.NewCallbackSerializer(ctx),
serializerCancel: cancel,
Expand Down Expand Up @@ -142,7 +142,7 @@ type pickfirstBalancer struct {
serializerCancel func()
state connectivity.State
subConns *resolver.AddressMap // scData for active subonns mapped by address.
addressIndex addressList
addressList addressList
firstPass bool
}

Expand All @@ -166,7 +166,7 @@ func (b *pickfirstBalancer) resolverError(err error) {
}

b.closeSubConns()
b.addressIndex.updateEndpointList(nil)
b.addressList.updateEndpointList(nil)
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: fmt.Errorf("name resolver error: %v", err)},
Expand Down Expand Up @@ -237,10 +237,10 @@ func (b *pickfirstBalancer) updateClientConnState(state balancer.ClientConnState

// If the previous ready subconn exists in new address list,
// keep this connection and don't create new subconns.
prevAddr := b.addressIndex.currentAddress()
prevAddrsCount := b.addressIndex.size()
b.addressIndex.updateEndpointList(newEndpoints)
if b.state == connectivity.Ready && b.addressIndex.seekTo(prevAddr) {
prevAddr := b.addressList.currentAddress()
prevAddrsCount := b.addressList.size()
b.addressList.updateEndpointList(newEndpoints)
if b.state == connectivity.Ready && b.addressList.seekTo(prevAddr) {
return nil
}

Expand Down Expand Up @@ -365,16 +365,15 @@ func (b *pickfirstBalancer) shutdownRemaining(selected *scData) {
b.subConns.Set(selected.addr, selected)
}

// requestConnection requests a connection to the next applicable address'
// subcon, creating one if necessary. Schedules a connection to next address in list as well.
// If the current channel has already attempted a connection, we attempt a connection
// to the next address/subconn in our list. We assume that NewSubConn will never
// return an error.
// requestConnection starts connecting on the subchannel corresponding to the
// current address. If no subchannel exists, one is created. If the current
// subchannel is in TransientFailure, a connection to the next address is
// attempted.
func (b *pickfirstBalancer) requestConnection() {
if !b.addressIndex.isValid() || b.state == connectivity.Shutdown {
if !b.addressList.isValid() || b.state == connectivity.Shutdown {
return
}
curAddr := b.addressIndex.currentAddress()
curAddr := b.addressList.currentAddress()
sd, ok := b.subConns.Get(curAddr)
if !ok {
var err error
Expand All @@ -388,7 +387,7 @@ func (b *pickfirstBalancer) requestConnection() {
// The LB policy remains in TRANSIENT_FAILURE until a new resolver
// update is received.
b.state = connectivity.TransientFailure
b.addressIndex.reset()
b.addressList.reset()
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.TransientFailure,
Picker: &picker{err: fmt.Errorf("failed to create a new subConn: %v", err)},
Expand All @@ -403,7 +402,7 @@ func (b *pickfirstBalancer) requestConnection() {
case connectivity.Idle:
scd.subConn.Connect()
case connectivity.TransientFailure:
if !b.addressIndex.increment() {
if !b.addressList.increment() {
b.endFirstPass(scd.lastErr)
return
}
Expand All @@ -430,7 +429,12 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, state balancer.SubCon

if state.ConnectivityState == connectivity.Ready {
b.shutdownRemaining(sd)
b.addressIndex.seekTo(sd.addr)
if !b.addressList.seekTo(sd.addr) {
// This should not fail as we should have only one subconn after
// entering READY. The subconn should be present in the addressList.
b.logger.Errorf("Address %q not found address list in %v", sd.addr, b.addressList.addresses)
return
}
b.state = connectivity.Ready
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.Ready,
Expand All @@ -442,10 +446,10 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, state balancer.SubCon
// If we are transitioning from READY to IDLE, reset index and re-connect when
// prompted.
if b.state == connectivity.Ready && state.ConnectivityState == connectivity.Idle {
// Once a transport fails, we enter idle and start from the first address
// when the picker is used.
// Once a transport fails, the balancer enters IDLE and starts from
// the first address when the picker is used.
b.state = connectivity.Idle
b.addressIndex.reset()
b.addressList.reset()
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.Idle,
Picker: &idlePicker{exitIdle: b.ExitIdle},
Expand All @@ -456,10 +460,10 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, state balancer.SubCon
if b.firstPass {
switch state.ConnectivityState {
case connectivity.Connecting:
// We can be in either IDLE, CONNECTING or TRANSIENT_FAILURE.
// If we're in TRANSIENT_FAILURE, we stay in TRANSIENT_FAILURE until
// we're READY. See A62.
// If we're already in CONNECTING, no update is needed.
// The balancer can be in either IDLE, CONNECTING or TRANSIENT_FAILURE.
// If it's in TRANSIENT_FAILURE, stay in TRANSIENT_FAILURE until
// it's READY. See A62.
// If the balancer is already in CONNECTING, no update is needed.
if b.state == connectivity.Idle {
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.Connecting,
Expand All @@ -471,11 +475,11 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, state balancer.SubCon
// Since we're re-using common subconns while handling resolver updates,
// we could receive an out of turn TRANSIENT_FAILURE from a pass
// over the previous address list. We ignore such updates.
curAddr := b.addressIndex.currentAddress()
if activeSD, found := b.subConns.Get(curAddr); !found || activeSD != sd {

if curAddr := b.addressList.currentAddress(); !equalAddressIgnoringBalAttributes(&curAddr, &sd.addr) {
return
}
if b.addressIndex.increment() {
if b.addressList.increment() {
b.requestConnection()
return
}
Expand All @@ -486,12 +490,11 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, state balancer.SubCon
// a transport is successfully created, but the connection fails before
// the subconn can send the notification for READY. We treat this
// as a successful connection and transition to IDLE.
curAddr := b.addressIndex.currentAddress()
if activeSD, found := b.subConns.Get(curAddr); !found || activeSD != sd {
if curAddr := b.addressList.currentAddress(); !equalAddressIgnoringBalAttributes(&sd.addr, &curAddr) {
return
}
b.state = connectivity.Idle
b.addressIndex.reset()
b.addressList.reset()
b.cc.UpdateState(balancer.State{
ConnectivityState: connectivity.Idle,
Picker: &idlePicker{exitIdle: b.ExitIdle},
Expand All @@ -510,6 +513,7 @@ func (b *pickfirstBalancer) updateSubConnState(sd *scData, state balancer.SubCon
})
// We don't need to request re-resolution since the subconn already does
// that before reporting TRANSIENT_FAILURE.
// TODO: #7534 - Move re-resolution requests from subconn into pick_first.
case connectivity.Idle:
sd.subConn.Connect()
}
Expand Down Expand Up @@ -569,8 +573,7 @@ func (al *addressList) size() int {
return len(al.addresses)
}

// increment moves to the next index in the address list. If at the last address
// in the address list, moves to the next endpoint in the endpoint list.
// increment moves to the next index in the address list.
// This method returns false if it went off the list, true otherwise.
func (al *addressList) increment() bool {
if !al.isValid() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,38 +72,39 @@ func (s) TestAddressList_Iteration(t *testing.T) {

for i := 0; i < len(addrs); i++ {
if got, want := addressList.isValid(), true; got != want {
t.Errorf("addressList.isValid() = %t, want %t", got, want)
t.Fatalf("addressList.isValid() = %t, want %t", got, want)
}
if got, want := addressList.currentAddress(), addrs[i]; !want.Equal(got) {
t.Errorf("addressList.currentAddress() = %v, want %v", got, want)
}
if got, want := addressList.increment(), i+1 < len(addrs); got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
t.Fatalf("addressList.increment() = %t, want %t", got, want)
}
}

if got, want := addressList.isValid(), false; got != want {
t.Errorf("addressList.isValid() = %t, want %t", got, want)
t.Fatalf("addressList.isValid() = %t, want %t", got, want)
}

// increment an invalid address list.
if got, want := addressList.increment(), false; got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
}

if got, want := addressList.isValid(), false; got != want {
t.Errorf("addressList.isValid() = %t, want %t", got, want)
}

addressList.reset()
for i := 0; i < len(addrs); i++ {
if got, want := addressList.isValid(), true; got != want {
t.Errorf("addressList.isValid() = %t, want %t", got, want)
t.Fatalf("addressList.isValid() = %t, want %t", got, want)
}
if got, want := addressList.currentAddress(), addrs[i]; !want.Equal(got) {
t.Errorf("addressList.currentAddress() = %v, want %v", got, want)
}
if got, want := addressList.increment(), i+1 < len(addrs); got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
t.Fatalf("addressList.increment() = %t, want %t", got, want)
}
}
}
Expand Down Expand Up @@ -159,6 +160,7 @@ func (s) TestAddressList_SeekTo(t *testing.T) {
if got, want := addressList.increment(), true; got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
}

if got, want := addressList.increment(), false; got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
}
Expand All @@ -170,20 +172,17 @@ func (s) TestAddressList_SeekTo(t *testing.T) {

// Seek to a key not in the list.
key = resolver.Address{

Check failure on line 174 in balancer/pickfirstleaf/pickfirstleaf_test.go

View workflow job for this annotation

GitHub Actions / tests (vet, 1.22)

this value of key is never used (SA4006)
Addr: "192.168.1.2",
Addr: "192.168.1.5",
ServerName: "test-host-5",
Attributes: attributes.New("key-5", "val-5"),
BalancerAttributes: attributes.New("ignored", "bal-val-5"),
}
// Seek to the key again, it is behind the pointer now.
if got, want := addressList.seekTo(key), false; got != want {
t.Errorf("addressList.seekTo(%v) = %t, want %t", key, got, want)
}

// It should be possible to increment once more since the pointer has not advanced.
if got, want := addressList.increment(), true; got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
}

if got, want := addressList.increment(), false; got != want {
t.Errorf("addressList.increment() = %t, want %t", got, want)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,17 @@ import (
"fmt"
"sync"
"testing"
"time"

"github.com/google/go-cmp/cmp"

"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
pickfirstleaf "google.golang.org/grpc/balancer/pickfirst_leaf"
"google.golang.org/grpc/balancer/pickfirstleaf"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils/pickfirst"
"google.golang.org/grpc/resolver"
Expand All @@ -42,9 +44,24 @@ import (
testpb "google.golang.org/grpc/interop/grpc_testing"
)

const (
// Default timeout for tests in this package.
defaultTestTimeout = 10 * time.Second
// Default short timeout, to be used when waiting for events which are not
// expected to happen.
defaultTestShortTimeout = 100 * time.Millisecond
stateStoringBalancerName = "state_storing"
)

var stateStoringServiceConfig = fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, stateStoringBalancerName)

const stateStoringBalancerName = "state_storing"
type s struct {
grpctest.Tester
}

func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}

// setupPickFirstLeaf performs steps required for pick_first tests. It starts a
// bunch of backends exporting the TestService, creates a ClientConn to them
Expand Down Expand Up @@ -503,6 +520,16 @@ func (s) TestPickFirstLeaf_EmptyAddressList(t *testing.T) {
}
}

// stubBackendsToResolverAddrs converts from a set of stub server backends to
// resolver addresses. Useful when pushing addresses to the manual resolver.
func stubBackendsToResolverAddrs(backends []*stubserver.StubServer) []resolver.Address {
addrs := make([]resolver.Address, len(backends))
for i, backend := range backends {
addrs[i] = resolver.Address{Addr: backend.Address}
}
return addrs
}

// stateStoringBalancer stores the state of the subconns being created.
type stateStoringBalancer struct {
balancer.Balancer
Expand Down

0 comments on commit a3049fe

Please sign in to comment.