Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Close session in go client close #435

Merged
merged 2 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions oxia/async_client_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,14 @@ func NewAsyncClient(serviceAddress string, opts ...ClientOption) (AsyncClient, e
}

func (c *clientImpl) Close() error {
c.cancel()

return multierr.Combine(
err := multierr.Combine(
c.sessions.Close(),
c.writeBatchManager.Close(),
c.readBatchManager.Close(),
c.clientPool.Close(),
)
c.cancel()
return err
}

func (c *clientImpl) Put(key string, value []byte, options ...PutOption) <-chan PutResult {
Expand Down
49 changes: 44 additions & 5 deletions oxia/async_client_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
)

func init() {
common.LogJSON = false
common.ConfigureLogger()
}

Expand Down Expand Up @@ -295,9 +296,8 @@ func TestAsyncClientImpl_OverrideEphemeral(t *testing.T) {
}

func TestAsyncClientImpl_ClientIdentity(t *testing.T) {
identity1 := newKey()
client1, err := NewSyncClient(serviceAddress,
WithIdentity(identity1),
WithIdentity("client-1"),
)
assert.NoError(t, err)

Expand All @@ -306,10 +306,11 @@ func TestAsyncClientImpl_ClientIdentity(t *testing.T) {
assert.NoError(t, err)

assert.True(t, version.Ephemeral)
assert.Equal(t, identity1, version.ClientIdentity)
assert.Equal(t, "client-1", version.ClientIdentity)

client2, err := NewSyncClient(serviceAddress,
WithSessionTimeout(2*time.Second),
WithIdentity("client-2"),
)
assert.NoError(t, err)

Expand All @@ -319,14 +320,52 @@ func TestAsyncClientImpl_ClientIdentity(t *testing.T) {
assert.EqualValues(t, 0, version.ModificationsCount)
assert.Equal(t, "v1", string(res))
assert.True(t, version.Ephemeral)
assert.Equal(t, identity1, version.ClientIdentity)
assert.Equal(t, "client-1", version.ClientIdentity)

version, err = client2.Put(context.Background(), k, []byte("v2"), Ephemeral())
assert.NoError(t, err)

assert.True(t, version.Ephemeral)
assert.NotSame(t, "", version.ClientIdentity)
assert.Equal(t, "client-2", version.ClientIdentity)

assert.NoError(t, client1.Close())
assert.NoError(t, client2.Close())
}

func TestSyncClientImpl_SessionNotifications(t *testing.T) {
standaloneServer, err := server.NewStandalone(server.NewTestConfig(t.TempDir()))
assert.NoError(t, err)

serviceAddress := fmt.Sprintf("localhost:%d", standaloneServer.RpcPort())
client1, err := NewSyncClient(serviceAddress, WithIdentity("client-1"))
assert.NoError(t, err)

client2, err := NewSyncClient(serviceAddress, WithIdentity("client-1"))
assert.NoError(t, err)

notifications, err := client2.GetNotifications()
assert.NoError(t, err)

ctx := context.Background()

s1, _ := client1.Put(ctx, "/a", []byte("0"), Ephemeral())

n := <-notifications.Ch()
assert.Equal(t, KeyCreated, n.Type)
assert.Equal(t, "/a", n.Key)
assert.Equal(t, s1.VersionId, n.VersionId)

err = client1.Close()
assert.NoError(t, err)

select {
case n = <-notifications.Ch():
assert.Equal(t, KeyDeleted, n.Type)
assert.Equal(t, "/a", n.Key)
case <-time.After(3 * time.Second):
assert.Fail(t, "read from channel timed out")
}

assert.NoError(t, client2.Close())
assert.NoError(t, standaloneServer.Close())
}
55 changes: 39 additions & 16 deletions oxia/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"sync"
"time"

"go.uber.org/multierr"

"github.com/cenkalti/backoff/v4"
"google.golang.org/grpc/status"

Expand All @@ -40,6 +42,7 @@ func newSessions(ctx context.Context, shardManager internal.ShardManager, pool c
clientOpts: options,
log: slog.With(
slog.String("component", "oxia-session-manager"),
slog.String("client-identity", options.identity),
),
}
return s
Expand Down Expand Up @@ -71,13 +74,15 @@ func (s *sessions) startSession(shardId int64) *clientSession {
cs := &clientSession{
shardId: shardId,
sessions: s,
ctx: s.ctx,
started: make(chan error),
log: slog.With(
slog.String("component", "session"),
slog.Int64("shard", shardId),
),
}

cs.ctx, cs.cancel = context.WithCancel(s.ctx)

cs.log.Debug("Creating session")
go common.DoWithLabels(
cs.ctx,
Expand All @@ -90,6 +95,15 @@ func (s *sessions) startSession(shardId int64) *clientSession {
return cs
}

func (s *sessions) Close() error {
var err error
for _, cs := range s.sessionsByShard {
err = multierr.Append(err, cs.Close())
}

return err
}

type clientSession struct {
sync.Mutex
started chan error
Expand All @@ -98,6 +112,7 @@ type clientSession struct {
log *slog.Logger
sessions *sessions
ctx context.Context
cancel context.CancelFunc
}

func (cs *clientSession) executeWithId(callback func(int64, error)) {
Expand Down Expand Up @@ -163,6 +178,7 @@ func (cs *clientSession) createSession() error {
cs.sessionId = sessionId
cs.log = cs.log.With(
slog.Int64("session-id", sessionId),
slog.String("client-identity", cs.sessions.clientIdentity),
)
close(cs.started)
cs.log.Debug("Successfully created session")
Expand Down Expand Up @@ -200,7 +216,7 @@ func (cs *clientSession) createSession() error {
)
})

if !errors.Is(err, context.Canceled) {
if err != nil && !errors.Is(err, context.Canceled) {
cs.log.Error(
"Failed to keep alive session",
slog.Any("error", err),
Expand All @@ -217,11 +233,30 @@ func (cs *clientSession) getRpc() (proto.OxiaClientClient, error) {
return cs.sessions.pool.GetClientRpc(leader)
}

func (cs *clientSession) Close() error {
cs.cancel()

rpc, err := cs.getRpc()
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(cs.sessions.ctx, cs.sessions.clientOpts.requestTimeout)
defer cancel()

if _, err = rpc.CloseSession(ctx, &proto.CloseSessionRequest{
ShardId: cs.shardId,
SessionId: cs.sessionId,
}); err != nil {
return err
}
return nil
}

func (cs *clientSession) keepAlive() error {
cs.sessions.Lock()
cs.Lock()
timeout := cs.sessions.clientOpts.sessionTimeout
ctx := cs.sessions.ctx
ctx := cs.ctx
shardId := cs.shardId
sessionId := cs.sessionId
cs.Unlock()
Expand All @@ -248,19 +283,7 @@ func (cs *clientSession) keepAlive() error {
return err
}
case <-ctx.Done():
ctx, cancel := context.WithTimeout(context.Background(), cs.sessions.clientOpts.requestTimeout)
rpc, err = cs.getRpc()
if err != nil {
cancel()
return err
}
_, err = rpc.CloseSession(ctx, &proto.CloseSessionRequest{
ShardId: shardId,
SessionId: sessionId,
})

cancel()
return err
return nil
}
}
}
7 changes: 7 additions & 0 deletions server/kv/notifications_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"fmt"
"log/slog"
"math"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -58,6 +59,9 @@ func newNotifications(shardId int64, offset int64, timestamp uint64) *notificati
}

func (n *notifications) Modified(key string, versionId, modificationsCount int64) {
if strings.HasPrefix(key, common.InternalKeyPrefix) {
return
}
nType := proto.NotificationType_KEY_CREATED
if modificationsCount > 0 {
nType = proto.NotificationType_KEY_MODIFIED
Expand All @@ -69,6 +73,9 @@ func (n *notifications) Modified(key string, versionId, modificationsCount int64
}

func (n *notifications) Deleted(key string) {
if strings.HasPrefix(key, common.InternalKeyPrefix) {
return
}
n.batch.Notifications[key] = &proto.Notification{
Type: proto.NotificationType_KEY_DELETED,
}
Expand Down
1 change: 1 addition & 0 deletions server/session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ func (sm *sessionManager) createSession(request *proto.CreateSessionRequest, min

metadata := proto.SessionMetadataFromVTPool()
metadata.TimeoutMs = uint32(timeout.Milliseconds())
metadata.Identity = request.ClientIdentity
defer metadata.ReturnToVTPool()

marshalledMetadata, err := metadata.MarshalVT()
Expand Down
3 changes: 2 additions & 1 deletion server/session_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ package server
import (
"context"
"errors"
"github.com/streamnative/oxia/server/wal"
"io"
"testing"
"time"

"github.com/streamnative/oxia/server/wal"

"github.com/stretchr/testify/assert"
pb "google.golang.org/protobuf/proto"

Expand Down
Loading