Skip to content

Commit

Permalink
* Added internal/credentials.IsAccessError(err) helper for check ac…
Browse files Browse the repository at this point in the history
…cess errors
  • Loading branch information
asmyasnikov committed Oct 22, 2023
1 parent f56f07b commit 9836629
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 128 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Added `internal/credentials.IsAccessError(err)` helper for check access errors
* Changed period for re-fresh static credentials token from `1/2` to `1/10` to expiration time

## v3.53.4
Expand Down
13 changes: 5 additions & 8 deletions internal/balancer/balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"

"google.golang.org/grpc"
grpcCodes "google.golang.org/grpc/codes"

"github.com/ydb-platform/ydb-go-sdk/v3/config"
balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config"
Expand Down Expand Up @@ -66,8 +65,8 @@ func (b *Balancer) OnUpdate(onApplyDiscoveredEndpoints func(ctx context.Context,
func (b *Balancer) clusterDiscovery(ctx context.Context) (err error) {
if err = retry.Retry(ctx, func(childCtx context.Context) (err error) {
if err = b.clusterDiscoveryAttempt(childCtx); err != nil {
if xerrors.IsTransportError(err, grpcCodes.Unauthenticated) {
return credentials.UnauthenticatedError("cluster discovery failed", err,
if credentials.IsAccessError(err) {
return credentials.AccessError("cluster discovery failed", err,
credentials.WithEndpoint(b.driverConfig.Endpoint()),
credentials.WithDatabase(b.driverConfig.Database()),
credentials.WithCredentials(b.driverConfig.Credentials()),
Expand All @@ -90,9 +89,7 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
var (
address = "ydb:///" + b.driverConfig.Endpoint()
onDone = trace.DriverOnBalancerClusterDiscoveryAttempt(
b.driverConfig.Trace(),
&ctx,
address,
b.driverConfig.Trace(), &ctx, address,
)
endpoints []endpoint.Endpoint
localDC string
Expand Down Expand Up @@ -295,8 +292,8 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc

if err = f(ctx, cc); err != nil {
if conn.UseWrapping(ctx) {
if xerrors.IsTransportError(err, grpcCodes.Unauthenticated) {
err = credentials.UnauthenticatedError("unauthenticated", err,
if credentials.IsAccessError(err) {
err = credentials.AccessError("no access", err,
credentials.WithAddress(cc.Endpoint().String()),
credentials.WithNodeID(cc.Endpoint().NodeID()),
credentials.WithCredentials(b.driverConfig.Credentials()),
Expand Down
119 changes: 119 additions & 0 deletions internal/credentials/access_error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package credentials

import (
"bytes"
"fmt"
"reflect"
"strconv"

"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
grpcCodes "google.golang.org/grpc/codes"

"github.com/ydb-platform/ydb-go-sdk/v3/internal/allocator"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
)

type authErrorOption interface {
applyAuthErrorOption(buffer *bytes.Buffer)
}

var (
_ authErrorOption = nodeIDAuthErrorOption(0)
_ authErrorOption = addressAuthErrorOption("")
_ authErrorOption = endpointAuthErrorOption("")
_ authErrorOption = databaseAuthErrorOption("")
)

type addressAuthErrorOption string

func (address addressAuthErrorOption) applyAuthErrorOption(buffer *bytes.Buffer) {
buffer.WriteString("address:")
fmt.Fprintf(buffer, "%q", address)
}

func WithAddress(address string) addressAuthErrorOption {
return addressAuthErrorOption(address)
}

type endpointAuthErrorOption string

func (endpoint endpointAuthErrorOption) applyAuthErrorOption(buffer *bytes.Buffer) {
buffer.WriteString("endpoint:")
fmt.Fprintf(buffer, "%q", endpoint)
}

func WithEndpoint(endpoint string) endpointAuthErrorOption {
return endpointAuthErrorOption(endpoint)
}

type databaseAuthErrorOption string

func (address databaseAuthErrorOption) applyAuthErrorOption(buffer *bytes.Buffer) {
buffer.WriteString("database:")
fmt.Fprintf(buffer, "%q", address)
}

func WithDatabase(database string) databaseAuthErrorOption {
return databaseAuthErrorOption(database)
}

type nodeIDAuthErrorOption uint32

func (id nodeIDAuthErrorOption) applyAuthErrorOption(buffer *bytes.Buffer) {
buffer.WriteString("nodeID:")
buffer.WriteString(strconv.FormatUint(uint64(id), 10))
}

func WithNodeID(id uint32) authErrorOption {
return nodeIDAuthErrorOption(id)
}

type credentialsUnauthenticatedErrorOption struct {
credentials Credentials
}

func (opt credentialsUnauthenticatedErrorOption) applyAuthErrorOption(buffer *bytes.Buffer) {
buffer.WriteString("credentials:")
if stringer, has := opt.credentials.(fmt.Stringer); has {
fmt.Fprintf(buffer, "%q", stringer.String())
} else {
t := reflect.TypeOf(opt.credentials)
fmt.Fprintf(buffer, "%q", t.PkgPath()+"."+t.Name())
}
}

func WithCredentials(credentials Credentials) credentialsUnauthenticatedErrorOption {
return credentialsUnauthenticatedErrorOption{
credentials: credentials,
}
}

func AccessError(msg string, err error, opts ...authErrorOption) error {
buffer := allocator.Buffers.Get()
defer allocator.Buffers.Put(buffer)
buffer.WriteString(msg)
buffer.WriteString(" (")
for i, opt := range opts {
if i != 0 {
buffer.WriteString(",")
}
opt.applyAuthErrorOption(buffer)
}
buffer.WriteString("): %w")
return xerrors.WithStackTrace(fmt.Errorf(buffer.String(), err), xerrors.WithSkipDepth(1))
}

func IsAccessError(err error) bool {
if xerrors.IsTransportError(err,
grpcCodes.Unauthenticated,
grpcCodes.PermissionDenied,
) {
return true
}
if xerrors.IsOperationError(err,
Ydb.StatusIds_UNAUTHORIZED,
) {
return true
}
return false
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ import (
"testing"

"github.com/stretchr/testify/require"
"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
grpcCodes "google.golang.org/grpc/codes"
grpcStatus "google.golang.org/grpc/status"

"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
)

var _ Credentials = customCredentials{}
Expand All @@ -15,17 +20,17 @@ type customCredentials struct {
token string
}

func (c customCredentials) Token(ctx context.Context) (string, error) {
func (c customCredentials) Token(context.Context) (string, error) {
return c.token, nil
}

func TestUnauthenticatedError(t *testing.T) {
func TestAccessError(t *testing.T) {
for _, tt := range []struct {
err error
errorString string
}{
{
err: UnauthenticatedError(
err: AccessError(
"something went wrong",
errors.New("test"),
WithEndpoint("grps://localhost:2135"),
Expand All @@ -37,10 +42,10 @@ func TestUnauthenticatedError(t *testing.T) {
"database:\"/local\"," +
"credentials:\"Anonymous()\"" +
"): test " +
"at `github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.TestUnauthenticatedError(unauthenticated_error_test.go:28)`", //nolint:lll
"at `github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.TestAccessError(access_error_test.go:33)`", //nolint:lll
},
{
err: UnauthenticatedError(
err: AccessError(
"something went wrong",
errors.New("test"),
WithEndpoint("grps://localhost:2135"),
Expand All @@ -50,12 +55,12 @@ func TestUnauthenticatedError(t *testing.T) {
errorString: "something went wrong (" +
"endpoint:\"grps://localhost:2135\"," +
"database:\"/local\"," +
"credentials:\"Anonymous(from:\\\"TestUnauthenticatedError\\\")\"" +
"credentials:\"Anonymous(from:\\\"TestAccessError\\\")\"" +
"): test " +
"at `github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.TestUnauthenticatedError(unauthenticated_error_test.go:43)`", //nolint:lll
"at `github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.TestAccessError(access_error_test.go:48)`", //nolint:lll
},
{
err: UnauthenticatedError(
err: AccessError(
"something went wrong",
errors.New("test"),
WithEndpoint("grps://localhost:2135"),
Expand All @@ -67,10 +72,10 @@ func TestUnauthenticatedError(t *testing.T) {
"database:\"/local\"," +
"credentials:\"AccessToken(token:\\\"****(CRC-32c: 9B7801F4)\\\")\"" +
"): test " +
"at `github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.TestUnauthenticatedError(unauthenticated_error_test.go:58)`", //nolint:lll
"at `github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.TestAccessError(access_error_test.go:63)`", //nolint:lll
},
{
err: UnauthenticatedError(
err: AccessError(
"something went wrong",
errors.New("test"),
WithEndpoint("grps://localhost:2135"),
Expand All @@ -80,12 +85,12 @@ func TestUnauthenticatedError(t *testing.T) {
errorString: "something went wrong (" +
"endpoint:\"grps://localhost:2135\"," +
"database:\"/local\"," +
"credentials:\"AccessToken(token:\\\"****(CRC-32c: 9B7801F4)\\\",from:\\\"TestUnauthenticatedError\\\")\"" + //nolint:lll
"credentials:\"AccessToken(token:\\\"****(CRC-32c: 9B7801F4)\\\",from:\\\"TestAccessError\\\")\"" +
"): test " +
"at `github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.TestUnauthenticatedError(unauthenticated_error_test.go:73)`", //nolint:lll
"at `github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.TestAccessError(access_error_test.go:78)`", //nolint:lll
},
{
err: UnauthenticatedError(
err: AccessError(
"something went wrong",
errors.New("test"),
WithEndpoint("grps://localhost:2135"),
Expand All @@ -101,10 +106,10 @@ func TestUnauthenticatedError(t *testing.T) {
"database:\"/local\"," +
"credentials:\"Static(user:\\\"USER\\\",password:\\\"SEC**********RD\\\",token:\\\"****(CRC-32c: 00000000)\\\")\"" + //nolint:lll
"): test " +
"at `github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.TestUnauthenticatedError(unauthenticated_error_test.go:88)`", //nolint:lll
"at `github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.TestAccessError(access_error_test.go:93)`", //nolint:lll
},
{
err: UnauthenticatedError(
err: AccessError(
"something went wrong",
errors.New("test"),
WithEndpoint("grps://localhost:2135"),
Expand All @@ -118,12 +123,12 @@ func TestUnauthenticatedError(t *testing.T) {
errorString: "something went wrong (" +
"endpoint:\"grps://localhost:2135\"," +
"database:\"/local\"," +
"credentials:\"Static(user:\\\"USER\\\",password:\\\"SEC**********RD\\\",token:\\\"****(CRC-32c: 00000000)\\\",from:\\\"TestUnauthenticatedError\\\")\"" + //nolint:lll
"credentials:\"Static(user:\\\"USER\\\",password:\\\"SEC**********RD\\\",token:\\\"****(CRC-32c: 00000000)\\\",from:\\\"TestAccessError\\\")\"" + //nolint:lll
"): test " +
"at `github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.TestUnauthenticatedError(unauthenticated_error_test.go:107)`", //nolint:lll
"at `github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.TestAccessError(access_error_test.go:112)`", //nolint:lll
},
{
err: UnauthenticatedError(
err: AccessError(
"something went wrong",
errors.New("test"),
WithEndpoint("grps://localhost:2135"),
Expand All @@ -135,10 +140,10 @@ func TestUnauthenticatedError(t *testing.T) {
"database:\"/local\"," +
"credentials:\"github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.customCredentials\"" +
"): test " +
"at `github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.TestUnauthenticatedError(unauthenticated_error_test.go:126)`", //nolint:lll
"at `github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.TestAccessError(access_error_test.go:131)`", //nolint:lll
},
{
err: UnauthenticatedError(
err: AccessError(
"something went wrong",
errors.New("test"),
WithEndpoint("grps://localhost:2135"),
Expand All @@ -150,7 +155,7 @@ func TestUnauthenticatedError(t *testing.T) {
"database:\"/local\"," +
"credentials:\"Anonymous()\"" +
"): test " +
"at `github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.TestUnauthenticatedError(unauthenticated_error_test.go:141)`", //nolint:lll
"at `github.com/ydb-platform/ydb-go-sdk/v3/internal/credentials.TestAccessError(access_error_test.go:146)`", //nolint:lll
},
} {
t.Run("", func(t *testing.T) {
Expand All @@ -162,3 +167,39 @@ func TestUnauthenticatedError(t *testing.T) {
func TestWrongStringifyCustomCredentials(t *testing.T) {
require.Equal(t, "&{\"SECRET_TOKEN\"}", fmt.Sprintf("%q", &customCredentials{token: "SECRET_TOKEN"}))
}

func TestIsAccessError(t *testing.T) {
for _, tt := range []struct {
error error
is bool
}{
{
error: grpcStatus.Error(grpcCodes.PermissionDenied, ""),
is: true,
},
{
error: grpcStatus.Error(grpcCodes.Unauthenticated, ""),
is: true,
},
{
error: xerrors.Transport(grpcStatus.Error(grpcCodes.PermissionDenied, "")),
is: true,
},
{
error: xerrors.Transport(grpcStatus.Error(grpcCodes.Unauthenticated, "")),
is: true,
},
{
error: xerrors.Operation(xerrors.WithStatusCode(Ydb.StatusIds_UNAUTHORIZED)),
is: true,
},
{
error: errors.New("some error"),
is: false,
},
} {
t.Run("", func(t *testing.T) {
require.Equal(t, tt.is, IsAccessError(tt.error), tt.error)
})
}
}
Loading

0 comments on commit 9836629

Please sign in to comment.