diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go index 596c31da..271859a1 100644 --- a/proxy/proxy_server.go +++ b/proxy/proxy_server.go @@ -1382,6 +1382,8 @@ func (s *ProxyServer) DeleteSession(id uint64) { func (s *ProxyServer) deleteSessionLocked(id uint64) { if session, found := s.sessions[id]; found { delete(s.sessions, id) + s.sessionsLock.Unlock() + defer s.sessionsLock.Lock() session.Close() statsSessionsCurrent.Dec() } diff --git a/proxy/proxy_server_test.go b/proxy/proxy_server_test.go index 9cb97174..fe2e36af 100644 --- a/proxy/proxy_server_test.go +++ b/proxy/proxy_server_test.go @@ -33,6 +33,7 @@ import ( "net/http/httptest" "os" "strings" + "sync/atomic" "testing" "time" @@ -67,10 +68,12 @@ func WaitForProxyServer(ctx context.Context, t *testing.T, proxy *ProxyServer) { time.Sleep(10 * time.Millisecond) proxy.Stop() for { - proxy.clientsLock.Lock() + proxy.clientsLock.RLock() clients := len(proxy.clients) + proxy.clientsLock.RUnlock() + proxy.sessionsLock.RLock() sessions := len(proxy.sessions) - proxy.clientsLock.Unlock() + proxy.sessionsLock.RUnlock() proxy.remoteConnectionsLock.Lock() remoteConnections := len(proxy.remoteConnections) proxy.remoteConnectionsLock.Unlock() @@ -353,8 +356,11 @@ func TestProxyCreateSession(t *testing.T) { } type HangingTestMCU struct { - t *testing.T - ctx context.Context + t *testing.T + ctx context.Context + creating chan struct{} + created chan struct{} + cancelled atomic.Bool } func NewHangingTestMCU(t *testing.T) *HangingTestMCU { @@ -364,8 +370,10 @@ func NewHangingTestMCU(t *testing.T) *HangingTestMCU { }) return &HangingTestMCU{ - t: t, - ctx: ctx, + t: t, + ctx: ctx, + creating: make(chan struct{}), + created: make(chan struct{}), } } @@ -393,8 +401,14 @@ func (m *HangingTestMCU) NewPublisher(ctx context.Context, listener signaling.Mc ctx2, cancel := context.WithTimeout(m.ctx, testTimeout*2) defer cancel() + m.creating <- struct{}{} + defer func() { + m.created <- struct{}{} + }() + select { case <-ctx.Done(): + m.cancelled.Store(true) return nil, ctx.Err() case <-ctx2.Done(): return nil, errors.New("Should have been cancelled before") @@ -405,8 +419,14 @@ func (m *HangingTestMCU) NewSubscriber(ctx context.Context, listener signaling.M ctx2, cancel := context.WithTimeout(m.ctx, testTimeout*2) defer cancel() + m.creating <- struct{}{} + defer func() { + m.created <- struct{}{} + }() + select { case <-ctx.Done(): + m.cancelled.Store(true) return nil, ctx.Err() case <-ctx2.Done(): return nil, errors.New("Should have been cancelled before") @@ -419,7 +439,8 @@ func TestProxyCancelOnClose(t *testing.T) { require := require.New(t) proxy, key, server := newProxyServerForTest(t) - proxy.mcu = NewHangingTestMCU(t) + mcu := NewHangingTestMCU(t) + proxy.mcu = mcu ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() @@ -436,22 +457,35 @@ func TestProxyCancelOnClose(t *testing.T) { _, err := client.RunUntilLoad(ctx, 0) assert.NoError(err) - require.NoError(client.SendCommand(&signaling.CommandProxyClientMessage{ - Type: "create-publisher", - StreamType: signaling.StreamTypeVideo, + require.NoError(client.WriteJSON(&signaling.ProxyClientMessage{ + Id: "2345", + Type: "command", + Command: &signaling.CommandProxyClientMessage{ + Type: "create-publisher", + StreamType: signaling.StreamTypeVideo, + }, })) // Simulate expired session while request is still being processed. go func() { + <-mcu.creating if session := proxy.GetSession(1); assert.NotNil(session) { session.Close() } }() if message, err := client.RunUntilMessage(ctx); assert.NoError(err) { - if err := checkMessageType(message, "error"); assert.NoError(err) { - assert.Equal("internal_error", message.Error.Code) - assert.Equal(context.Canceled.Error(), message.Error.Message) + if err := checkMessageType(message, "bye"); assert.NoError(err) { + assert.Equal("session_closed", message.Bye.Reason) } } + + if message, err := client.RunUntilMessage(ctx); assert.Error(err) { + assert.True(websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseNoStatusReceived), "expected close error, got %+v", err) + } else { + t.Errorf("expected error, got %+v", message) + } + + <-mcu.created + assert.True(mcu.cancelled.Load()) } diff --git a/proxy/proxy_session.go b/proxy/proxy_session.go index 1fefca69..ed9ac260 100644 --- a/proxy/proxy_session.go +++ b/proxy/proxy_session.go @@ -104,9 +104,24 @@ func (s *ProxySession) MarkUsed() { } func (s *ProxySession) Close() { + prev := s.SetClient(nil) + if prev != nil { + reason := "session_closed" + if s.IsExpired() { + reason = "session_expired" + } + prev.SendMessage(&signaling.ProxyServerMessage{ + Type: "bye", + Bye: &signaling.ByeProxyServerMessage{ + Reason: reason, + }, + }) + } + s.closeFunc() s.clearPublishers() s.clearSubscribers() + s.proxy.DeleteSession(s.Sid()) } func (s *ProxySession) SetClient(client *ProxyClient) *ProxyClient {