From 2650e82661009abf05e1c10209b1a7a3e8da4fd6 Mon Sep 17 00:00:00 2001 From: Yutaka Takeda Date: Sun, 17 Nov 2024 18:17:33 -0800 Subject: [PATCH] User-aware relay address generator Resolves #420 --- internal/allocation/allocation_manager.go | 26 +- .../allocation/allocation_manager_test.go | 82 +++-- internal/allocation/allocation_test.go | 10 +- internal/server/turn.go | 35 +- internal/server/turn_test.go | 24 +- internal/server/util.go | 34 +- internal/server/util_test.go | 309 ++++++++++++++++++ lt_cred.go | 2 +- relay_address_generator_none.go | 5 +- relay_address_generator_range.go | 5 +- relay_address_generator_static.go | 5 +- server.go | 4 +- server_config.go | 5 +- 13 files changed, 463 insertions(+), 83 deletions(-) create mode 100644 internal/server/util_test.go diff --git a/internal/allocation/allocation_manager.go b/internal/allocation/allocation_manager.go index 2b765921..caa4b9bc 100644 --- a/internal/allocation/allocation_manager.go +++ b/internal/allocation/allocation_manager.go @@ -12,11 +12,17 @@ import ( "github.com/pion/logging" ) +// Metadata contains contextual information for TURN server allocation tasks. +type Metadata struct { + Realm string + Username string +} + // ManagerConfig a bag of config params for Manager. type ManagerConfig struct { LeveledLogger logging.LeveledLogger - AllocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error) - AllocateConn func(network string, requestedPort int) (net.Conn, net.Addr, error) + AllocatePacketConn func(network string, requestedPort int, metadata Metadata) (net.PacketConn, net.Addr, error) + AllocateConn func(network string, requestedPort int, metadata Metadata) (net.Conn, net.Addr, error) PermissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool } @@ -33,8 +39,8 @@ type Manager struct { allocations map[FiveTupleFingerprint]*Allocation reservations []*reservation - allocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error) - allocateConn func(network string, requestedPort int) (net.Conn, net.Addr, error) + allocatePacketConn func(network string, requestedPort int, metadata Metadata) (net.PacketConn, net.Addr, error) + allocateConn func(network string, requestedPort int, metadata Metadata) (net.Conn, net.Addr, error) permissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool } @@ -86,7 +92,7 @@ func (m *Manager) Close() error { } // CreateAllocation creates a new allocation and starts relaying -func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketConn, requestedPort int, lifetime time.Duration) (*Allocation, error) { +func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketConn, requestedPort int, lifetime time.Duration, metadata Metadata) (*Allocation, error) { switch { case fiveTuple == nil: return nil, errNilFiveTuple @@ -105,7 +111,7 @@ func (m *Manager) CreateAllocation(fiveTuple *FiveTuple, turnSocket net.PacketCo } a := NewAllocation(turnSocket, fiveTuple, m.log) - conn, relayAddr, err := m.allocatePacketConn("udp4", requestedPort) + conn, relayAddr, err := m.allocatePacketConn("udp4", requestedPort, metadata) if err != nil { return nil, err } @@ -180,9 +186,13 @@ func (m *Manager) GetReservation(reservationToken string) (int, bool) { } // GetRandomEvenPort returns a random un-allocated udp4 port -func (m *Manager) GetRandomEvenPort() (int, error) { +func (m *Manager) GetRandomEvenPort(metadata Metadata) (int, error) { for i := 0; i < 128; i++ { - conn, addr, err := m.allocatePacketConn("udp4", 0) + conn, addr, err := m.allocatePacketConn("udp4", 0, metadata) + if err != nil { + return 0, err + } + if err != nil { return 0, err } diff --git a/internal/allocation/allocation_manager_test.go b/internal/allocation/allocation_manager_test.go index 014d85d3..64ce2bec 100644 --- a/internal/allocation/allocation_manager_test.go +++ b/internal/allocation/allocation_manager_test.go @@ -7,6 +7,7 @@ package allocation import ( + "errors" "io" "math/rand" "net" @@ -19,10 +20,15 @@ import ( "github.com/stretchr/testify/assert" ) +var ( + errUnexpectedTestRealm = errors.New("unexpected test realm") + errUnexpectedTestUsername = errors.New("unexpected user name") +) + func TestManager(t *testing.T) { tt := []struct { name string - f func(*testing.T, net.PacketConn) + f func(*testing.T, net.PacketConn, string, string) }{ {"CreateInvalidAllocation", subTestCreateInvalidAllocation}, {"CreateAllocation", subTestCreateAllocation}, @@ -42,34 +48,36 @@ func TestManager(t *testing.T) { for _, tc := range tt { f := tc.f t.Run(tc.name, func(t *testing.T) { - f(t, turnSocket) + f(t, turnSocket, "test_realm_1", "test_user_1") }) } } // Test invalid Allocation creations -func subTestCreateInvalidAllocation(t *testing.T, turnSocket net.PacketConn) { - m, err := newTestManager() +func subTestCreateInvalidAllocation(t *testing.T, turnSocket net.PacketConn, realm, username string) { + m, err := newTestManager(realm, username) assert.NoError(t, err) - if a, err := m.CreateAllocation(nil, turnSocket, 0, proto.DefaultLifetime); a != nil || err == nil { + metadata := Metadata{Realm: realm, Username: username} + if a, err := m.CreateAllocation(nil, turnSocket, 0, proto.DefaultLifetime, metadata); a != nil || err == nil { t.Errorf("Illegally created allocation with nil FiveTuple") } - if a, err := m.CreateAllocation(randomFiveTuple(), nil, 0, proto.DefaultLifetime); a != nil || err == nil { + if a, err := m.CreateAllocation(randomFiveTuple(), nil, 0, proto.DefaultLifetime, metadata); a != nil || err == nil { t.Errorf("Illegally created allocation with nil turnSocket") } - if a, err := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, 0); a != nil || err == nil { + if a, err := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, 0, metadata); a != nil || err == nil { t.Errorf("Illegally created allocation with 0 lifetime") } } // Test valid Allocation creations -func subTestCreateAllocation(t *testing.T, turnSocket net.PacketConn) { - m, err := newTestManager() +func subTestCreateAllocation(t *testing.T, turnSocket net.PacketConn, realm, username string) { + m, err := newTestManager(realm, username) assert.NoError(t, err) fiveTuple := randomFiveTuple() - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { + metadata := Metadata{Realm: realm, Username: username} + if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, metadata); a == nil || err != nil { t.Errorf("Failed to create allocation %v %v", a, err) } @@ -79,26 +87,28 @@ func subTestCreateAllocation(t *testing.T, turnSocket net.PacketConn) { } // Test that two allocations can't be created with the same FiveTuple -func subTestCreateAllocationDuplicateFiveTuple(t *testing.T, turnSocket net.PacketConn) { - m, err := newTestManager() +func subTestCreateAllocationDuplicateFiveTuple(t *testing.T, turnSocket net.PacketConn, realm, username string) { + m, err := newTestManager(realm, username) assert.NoError(t, err) fiveTuple := randomFiveTuple() - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { + metadata := Metadata{Realm: realm, Username: username} + if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, metadata); a == nil || err != nil { t.Errorf("Failed to create allocation %v %v", a, err) } - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a != nil || err == nil { + if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, metadata); a != nil || err == nil { t.Errorf("Was able to create allocation with same FiveTuple twice") } } -func subTestDeleteAllocation(t *testing.T, turnSocket net.PacketConn) { - m, err := newTestManager() +func subTestDeleteAllocation(t *testing.T, turnSocket net.PacketConn, realm, username string) { + m, err := newTestManager(realm, username) assert.NoError(t, err) fiveTuple := randomFiveTuple() - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { + metadata := Metadata{Realm: realm, Username: username} + if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, metadata); a == nil || err != nil { t.Errorf("Failed to create allocation %v %v", a, err) } @@ -113,17 +123,17 @@ func subTestDeleteAllocation(t *testing.T, turnSocket net.PacketConn) { } // Test that allocation should be closed if timeout -func subTestAllocationTimeout(t *testing.T, turnSocket net.PacketConn) { - m, err := newTestManager() +func subTestAllocationTimeout(t *testing.T, turnSocket net.PacketConn, realm, username string) { + m, err := newTestManager(realm, username) assert.NoError(t, err) allocations := make([]*Allocation, 5) lifetime := time.Second - + metadata := Metadata{Realm: realm, Username: username} for index := range allocations { fiveTuple := randomFiveTuple() - a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, lifetime) + a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, lifetime, metadata) if err != nil { t.Errorf("Failed to create allocation with %v", fiveTuple) } @@ -141,15 +151,15 @@ func subTestAllocationTimeout(t *testing.T, turnSocket net.PacketConn) { } // Test for manager close -func subTestManagerClose(t *testing.T, turnSocket net.PacketConn) { - m, err := newTestManager() +func subTestManagerClose(t *testing.T, turnSocket net.PacketConn, realm, username string) { + m, err := newTestManager(realm, username) assert.NoError(t, err) allocations := make([]*Allocation, 2) - - a1, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Second) + metadata := Metadata{Realm: realm, Username: username} + a1, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Second, metadata) allocations[0] = a1 - a2, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Minute) + a2, _ := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Minute, metadata) allocations[1] = a2 // Make a1 timeout @@ -174,12 +184,18 @@ func randomFiveTuple() *FiveTuple { } } -func newTestManager() (*Manager, error) { +func newTestManager(expectedRealm, expectedUsername string) (*Manager, error) { loggerFactory := logging.NewDefaultLoggerFactory() config := ManagerConfig{ LeveledLogger: loggerFactory.NewLogger("test"), - AllocatePacketConn: func(string, int) (net.PacketConn, net.Addr, error) { + AllocatePacketConn: func(_ string, _ int, metadata Metadata) (net.PacketConn, net.Addr, error) { + if metadata.Realm != expectedRealm { + return nil, nil, errUnexpectedTestRealm + } + if metadata.Username != expectedUsername { + return nil, nil, errUnexpectedTestUsername + } conn, err := net.ListenPacket("udp4", "0.0.0.0:0") if err != nil { return nil, nil, err @@ -187,8 +203,9 @@ func newTestManager() (*Manager, error) { return conn, conn.LocalAddr(), nil }, - AllocateConn: func(string, int) (net.Conn, net.Addr, error) { return nil, nil, nil }, + AllocateConn: func(string, int, Metadata) (net.Conn, net.Addr, error) { return nil, nil, nil }, } + return NewManager(config) } @@ -197,11 +214,12 @@ func isClose(conn io.Closer) bool { return closeErr != nil && strings.Contains(closeErr.Error(), "use of closed network connection") } -func subTestGetRandomEvenPort(t *testing.T, _ net.PacketConn) { - m, err := newTestManager() +func subTestGetRandomEvenPort(t *testing.T, _ net.PacketConn, realm, username string) { + m, err := newTestManager(realm, username) assert.NoError(t, err) - port, err := m.GetRandomEvenPort() + metadata := Metadata{Realm: realm, Username: username} + port, err := m.GetRandomEvenPort(metadata) assert.NoError(t, err) assert.True(t, port > 0) assert.True(t, port%2 == 0) diff --git a/internal/allocation/allocation_test.go b/internal/allocation/allocation_test.go index 49269d68..528597a3 100644 --- a/internal/allocation/allocation_test.go +++ b/internal/allocation/allocation_test.go @@ -259,9 +259,13 @@ func subTestAllocationClose(t *testing.T) { } func subTestPacketHandler(t *testing.T) { - network := "udp" + const ( + network = "udp" + testRealm = "test_realm_2" + testUsername = "test_user_2" + ) - m, _ := newTestManager() + m, _ := newTestManager(testRealm, testUsername) // TURN server initialization turnSocket, err := net.ListenPacket(network, "127.0.0.1:0") @@ -292,7 +296,7 @@ func subTestPacketHandler(t *testing.T) { a, err := m.CreateAllocation(&FiveTuple{ SrcAddr: clientListener.LocalAddr(), DstAddr: turnSocket.LocalAddr(), - }, turnSocket, 0, proto.DefaultLifetime) + }, turnSocket, 0, proto.DefaultLifetime, Metadata{Realm: testRealm, Username: testUsername}) assert.Nil(t, err, "should succeed") diff --git a/internal/server/turn.go b/internal/server/turn.go index 46e45ecb..81f1c701 100644 --- a/internal/server/turn.go +++ b/internal/server/turn.go @@ -25,10 +25,14 @@ func handleAllocateRequest(r Request, m *stun.Message) error { // mechanism of [https://tools.ietf.org/html/rfc5389#section-10.2.2] // unless the client and server agree to use another mechanism through // some procedure outside the scope of this document. - messageIntegrity, hasAuth, err := authenticateRequest(r, m, stun.MethodAllocate) - if !hasAuth { + authResult, err := authenticateRequest(r, m, stun.MethodAllocate) + if !authResult.hasAuth { return err } + metadata := allocation.Metadata{ + Realm: authResult.realm, + Username: authResult.username, + } fiveTuple := &allocation.FiveTuple{ SrcAddr: r.SrcAddr, @@ -51,7 +55,7 @@ func handleAllocateRequest(r Request, m *stun.Message) error { return buildAndSendErr(r.Conn, r.SrcAddr, errRelayAlreadyAllocatedForFiveTuple, msg...) } // A retry allocation - msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassSuccessResponse), append(attrs, messageIntegrity)...) + msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassSuccessResponse), append(attrs, authResult.messageIntegrity)...) return buildAndSend(r.Conn, r.SrcAddr, msg...) } @@ -104,7 +108,7 @@ func handleAllocateRequest(r Request, m *stun.Message) error { var evenPort proto.EvenPort if err = evenPort.GetFrom(m); err == nil { var randomPort int - randomPort, err = r.AllocationManager.GetRandomEvenPort() + randomPort, err = r.AllocationManager.GetRandomEvenPort(metadata) if err != nil { return buildAndSendErr(r.Conn, r.SrcAddr, err, insufficientCapacityMsg...) } @@ -131,7 +135,8 @@ func handleAllocateRequest(r Request, m *stun.Message) error { fiveTuple, r.Conn, requestedPort, - lifetimeDuration) + lifetimeDuration, + metadata) if err != nil { return buildAndSendErr(r.Conn, r.SrcAddr, err, insufficientCapacityMsg...) } @@ -177,7 +182,7 @@ func handleAllocateRequest(r Request, m *stun.Message) error { responseAttrs = append(responseAttrs, proto.ReservationToken([]byte(reservationToken))) } - msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassSuccessResponse), append(responseAttrs, messageIntegrity)...) + msg := buildMsg(m.TransactionID, stun.NewType(stun.MethodAllocate, stun.ClassSuccessResponse), append(responseAttrs, authResult.messageIntegrity)...) a.SetResponseCache(m.TransactionID, responseAttrs) return buildAndSend(r.Conn, r.SrcAddr, msg...) } @@ -185,8 +190,8 @@ func handleAllocateRequest(r Request, m *stun.Message) error { func handleRefreshRequest(r Request, m *stun.Message) error { r.Log.Debugf("Received RefreshRequest from %s", r.SrcAddr) - messageIntegrity, hasAuth, err := authenticateRequest(r, m, stun.MethodRefresh) - if !hasAuth { + authResult, err := authenticateRequest(r, m, stun.MethodRefresh) + if !authResult.hasAuth { return err } @@ -212,7 +217,7 @@ func handleRefreshRequest(r Request, m *stun.Message) error { &proto.Lifetime{ Duration: lifetimeDuration, }, - messageIntegrity, + authResult.messageIntegrity, }...)...) } @@ -228,8 +233,8 @@ func handleCreatePermissionRequest(r Request, m *stun.Message) error { return fmt.Errorf("%w %v:%v", errNoAllocationFound, r.SrcAddr, r.Conn.LocalAddr()) } - messageIntegrity, hasAuth, err := authenticateRequest(r, m, stun.MethodCreatePermission) - if !hasAuth { + authResult, err := authenticateRequest(r, m, stun.MethodCreatePermission) + if !authResult.hasAuth { return err } @@ -267,7 +272,7 @@ func handleCreatePermissionRequest(r Request, m *stun.Message) error { respClass = stun.ClassErrorResponse } - return buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(stun.MethodCreatePermission, respClass), []stun.Setter{messageIntegrity}...)...) + return buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(stun.MethodCreatePermission, respClass), []stun.Setter{authResult.messageIntegrity}...)...) } func handleSendIndication(r Request, m *stun.Message) error { @@ -317,8 +322,8 @@ func handleChannelBindRequest(r Request, m *stun.Message) error { badRequestMsg := buildMsg(m.TransactionID, stun.NewType(stun.MethodChannelBind, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: stun.CodeBadRequest}) - messageIntegrity, hasAuth, err := authenticateRequest(r, m, stun.MethodChannelBind) - if !hasAuth { + authResult, err := authenticateRequest(r, m, stun.MethodChannelBind) + if !authResult.hasAuth { return err } @@ -351,7 +356,7 @@ func handleChannelBindRequest(r Request, m *stun.Message) error { return buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) } - return buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(stun.MethodChannelBind, stun.ClassSuccessResponse), []stun.Setter{messageIntegrity}...)...) + return buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(stun.MethodChannelBind, stun.ClassSuccessResponse), []stun.Setter{authResult.messageIntegrity}...)...) } func handleChannelData(r Request, c *proto.ChannelData) error { diff --git a/internal/server/turn_test.go b/internal/server/turn_test.go index e4a3b947..59df7b76 100644 --- a/internal/server/turn_test.go +++ b/internal/server/turn_test.go @@ -7,6 +7,7 @@ package server import ( + "errors" "net" "testing" "time" @@ -18,6 +19,11 @@ import ( "github.com/stretchr/testify/assert" ) +var ( + errUnexpectedTestRealm = errors.New("unexpected test realm") + errUnexpectedTestUsername = errors.New("unexpected user name") +) + func TestAllocationLifeTime(t *testing.T) { t.Run("Parsing", func(t *testing.T) { lifetime := proto.Lifetime{ @@ -63,16 +69,28 @@ func TestAllocationLifeTime(t *testing.T) { logger := logging.NewDefaultLoggerFactory().NewLogger("turn") + const ( + realm = "test" + username = "tester" + ) + allocationManager, err := allocation.NewManager(allocation.ManagerConfig{ - AllocatePacketConn: func(network string, _ int) (net.PacketConn, net.Addr, error) { + AllocatePacketConn: func(network string, _ int, metadata allocation.Metadata) (net.PacketConn, net.Addr, error) { conn, listenErr := net.ListenPacket(network, "0.0.0.0:0") if err != nil { return nil, nil, listenErr } + if metadata.Realm != realm { + return nil, nil, errUnexpectedTestRealm + } + if metadata.Username != username { + return nil, nil, errUnexpectedTestUsername + } + return conn, conn.LocalAddr(), nil }, - AllocateConn: func(string, int) (net.Conn, net.Addr, error) { + AllocateConn: func(string, int, allocation.Metadata) (net.Conn, net.Addr, error) { return nil, nil, nil }, LeveledLogger: logger, @@ -97,7 +115,7 @@ func TestAllocationLifeTime(t *testing.T) { fiveTuple := &allocation.FiveTuple{SrcAddr: r.SrcAddr, DstAddr: r.Conn.LocalAddr(), Protocol: allocation.UDP} - _, err = r.AllocationManager.CreateAllocation(fiveTuple, r.Conn, 0, time.Hour) + _, err = r.AllocationManager.CreateAllocation(fiveTuple, r.Conn, 0, time.Hour, allocation.Metadata{Realm: realm, Username: username}) assert.NoError(t, err) assert.NotNil(t, r.AllocationManager.GetAllocation(fiveTuple)) diff --git a/internal/server/util.go b/internal/server/util.go index 7c01d329..89186572 100644 --- a/internal/server/util.go +++ b/internal/server/util.go @@ -42,14 +42,21 @@ func buildMsg(transactionID [stun.TransactionIDSize]byte, msgType stun.MessageTy return append([]stun.Setter{&stun.Message{TransactionID: transactionID}, msgType}, additional...) } -func authenticateRequest(r Request, m *stun.Message, callingMethod stun.Method) (stun.MessageIntegrity, bool, error) { - respondWithNonce := func(responseCode stun.ErrorCode) (stun.MessageIntegrity, bool, error) { +type authenticationResult struct { + messageIntegrity stun.MessageIntegrity + username string + realm string + hasAuth bool +} + +func authenticateRequest(r Request, m *stun.Message, callingMethod stun.Method) (authenticationResult, error) { + respondWithNonce := func(responseCode stun.ErrorCode) (authenticationResult, error) { nonce, err := r.NonceHash.Generate() if err != nil { - return nil, false, err + return authenticationResult{nil, "", "", false}, err } - return nil, false, buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, + return authenticationResult{nil, "", "", false}, buildAndSend(r.Conn, r.SrcAddr, buildMsg(m.TransactionID, stun.NewType(callingMethod, stun.ClassErrorResponse), &stun.ErrorCodeAttribute{Code: responseCode}, stun.NewNonce(nonce), @@ -70,11 +77,11 @@ func authenticateRequest(r Request, m *stun.Message, callingMethod stun.Method) // Respond with 400 so clients don't retry if r.AuthHandler == nil { sendErr := buildAndSend(r.Conn, r.SrcAddr, badRequestMsg...) - return nil, false, sendErr + return authenticationResult{}, sendErr } if err := nonceAttr.GetFrom(m); err != nil { - return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + return authenticationResult{}, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) } // Assert Nonce is signed and is not expired @@ -83,21 +90,26 @@ func authenticateRequest(r Request, m *stun.Message, callingMethod stun.Method) } if err := realmAttr.GetFrom(m); err != nil { - return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + return authenticationResult{}, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) } else if err := usernameAttr.GetFrom(m); err != nil { - return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + return authenticationResult{}, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) } ourKey, ok := r.AuthHandler(usernameAttr.String(), realmAttr.String(), r.SrcAddr) if !ok { - return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, fmt.Errorf("%w %s", errNoSuchUser, usernameAttr.String()), badRequestMsg...) + return authenticationResult{}, buildAndSendErr(r.Conn, r.SrcAddr, fmt.Errorf("%w %s", errNoSuchUser, usernameAttr.String()), badRequestMsg...) } if err := stun.MessageIntegrity(ourKey).Check(m); err != nil { - return nil, false, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) + return authenticationResult{}, buildAndSendErr(r.Conn, r.SrcAddr, err, badRequestMsg...) } - return stun.MessageIntegrity(ourKey), true, nil + return authenticationResult{ + messageIntegrity: stun.MessageIntegrity(ourKey), + username: usernameAttr.String(), + realm: realmAttr.String(), + hasAuth: true, + }, nil } func allocationLifeTime(m *stun.Message) time.Duration { diff --git a/internal/server/util_test.go b/internal/server/util_test.go new file mode 100644 index 00000000..d8e306e0 --- /dev/null +++ b/internal/server/util_test.go @@ -0,0 +1,309 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package server + +import ( + "net" + "os" + "testing" + "time" + + "github.com/pion/stun/v3" + "github.com/pion/turn/v4/internal/proto" + "github.com/stretchr/testify/require" +) + +func TestAuthenticateRequest(t *testing.T) { + const ( + testUsername = "test-user" + testRealm = "test-realm" + ) + testMsgIntegrity := stun.NewLongTermIntegrity(testUsername, testRealm, "pass") + + var conn net.PacketConn + var nonce string + var r *Request + + type options struct { + noAuthHandler bool + } + + setUp := func(t *testing.T, opts options) func() { + var err error + conn, err = net.ListenPacket("udp", "127.0.0.1:0") + require.NoError(t, err) + srcAddr := conn.LocalAddr() + + nonceHash, err := NewNonceHash() + require.NoError(t, err) + nonce, err = nonceHash.Generate() + require.NoError(t, err) + + r = &Request{ + Conn: conn, + SrcAddr: srcAddr, + AuthHandler: func(username, realm string, _ net.Addr) (key []byte, ok bool) { + return testMsgIntegrity, username == testUsername && realm == testRealm + }, + NonceHash: nonceHash, + } + if opts.noAuthHandler { + r.AuthHandler = nil + } + + return func() { + err = conn.Close() + if err != nil { + t.Errorf("failed to close connection: %v", err) + } + } + } + + checkSTUNAllocateErrorResponse := func(t *testing.T) stun.ErrorCode { + // Set read deadline to avoid blocking for a long time + err := conn.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + require.NoError(t, err) + + // Check the error response + buf := make([]byte, 1024) + n, _, err := conn.ReadFrom(buf) + require.NoError(t, err) + + resp := &stun.Message{} + err = resp.UnmarshalBinary(buf[:n]) + require.NoError(t, err) + require.Equal(t, stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse), resp.Type) + var attrErrorCode stun.ErrorCodeAttribute + err = attrErrorCode.GetFrom(resp) + require.NoError(t, err) + return attrErrorCode.Code + } + + checkNoSTUNResponse := func(t *testing.T) { + // Set read deadline to avoid blocking for a long time + err := conn.SetReadDeadline(time.Now().Add(time.Millisecond)) + require.NoError(t, err) + + // Check the error response + buf := make([]byte, 1024) + _, _, err = conn.ReadFrom(buf) + require.ErrorIs(t, err, os.ErrDeadlineExceeded) + } + + t.Run("auth success", func(t *testing.T) { + tearDown := setUp(t, options{}) + defer tearDown() + + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewUsername(testUsername), + stun.NewRealm(testRealm), + testMsgIntegrity, + stun.NewNonce(nonce), + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.NoError(t, err) + require.True(t, authResult.hasAuth) + require.Equal(t, testRealm, authResult.realm, "Realm value should be present in the result") + require.Equal(t, testUsername, authResult.username, "Username value should be present in the result") + + checkNoSTUNResponse(t) + }) + + t.Run("no message integrity", func(t *testing.T) { + tearDown := setUp(t, options{}) + defer tearDown() + + // Message integrity attribute is missing + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewUsername(testUsername), + stun.NewRealm(testRealm), + stun.NewNonce(nonce), + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.NoError(t, err) + require.False(t, authResult.hasAuth) + + // Check the error response + errorCode := checkSTUNAllocateErrorResponse(t) + require.Equal(t, stun.CodeUnauthorized, errorCode) + }) + + t.Run("no auth handler", func(t *testing.T) { + tearDown := setUp(t, options{noAuthHandler: true}) + defer tearDown() + + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewUsername(testUsername), + stun.NewRealm(testRealm), + testMsgIntegrity, + stun.NewNonce(nonce), + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.NoError(t, err) + require.False(t, authResult.hasAuth) + + // Check the error response + errorCode := checkSTUNAllocateErrorResponse(t) + require.Equal(t, stun.CodeBadRequest, errorCode) + }) + + t.Run("no nonce", func(t *testing.T) { + tearDown := setUp(t, options{}) + defer tearDown() + + // Nonce attribute is missing + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewUsername(testUsername), + stun.NewRealm(testRealm), + testMsgIntegrity, + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.ErrorIs(t, err, stun.ErrAttributeNotFound) + require.False(t, authResult.hasAuth) + + // Check the error response + errorCode := checkSTUNAllocateErrorResponse(t) + require.Equal(t, stun.CodeBadRequest, errorCode) + }) + + t.Run("invalid nonce", func(t *testing.T) { + tearDown := setUp(t, options{}) + defer tearDown() + + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewUsername(testUsername), + stun.NewRealm(testRealm), + testMsgIntegrity, + stun.NewNonce("bad nonce"), // <- bad nonce + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.NoError(t, err) + require.False(t, authResult.hasAuth) + + // Check the error response + errorCode := checkSTUNAllocateErrorResponse(t) + require.Equal(t, stun.CodeStaleNonce, errorCode) + }) + + t.Run("no realm", func(t *testing.T) { + tearDown := setUp(t, options{}) + defer tearDown() + + // Realm attribute is missing + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewUsername(testUsername), + testMsgIntegrity, + stun.NewNonce(nonce), + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.ErrorIs(t, err, stun.ErrAttributeNotFound) + require.False(t, authResult.hasAuth) + + // Check the error response + errorCode := checkSTUNAllocateErrorResponse(t) + require.Equal(t, stun.CodeBadRequest, errorCode) + }) + + t.Run("no username", func(t *testing.T) { + tearDown := setUp(t, options{}) + defer tearDown() + + // Username attribute is missing + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewRealm(testRealm), + testMsgIntegrity, + stun.NewNonce(nonce), + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.ErrorIs(t, err, stun.ErrAttributeNotFound) + require.False(t, authResult.hasAuth) + + // Check the error response + errorCode := checkSTUNAllocateErrorResponse(t) + require.Equal(t, stun.CodeBadRequest, errorCode) + }) + + t.Run("unknown username", func(t *testing.T) { + tearDown := setUp(t, options{}) + defer tearDown() + + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewUsername("bad user"), // <- user name that does not exist + stun.NewRealm(testRealm), + testMsgIntegrity, + stun.NewNonce(nonce), + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.ErrorContains(t, err, "no such user") + require.False(t, authResult.hasAuth) + + // Check the error response + errorCode := checkSTUNAllocateErrorResponse(t) + require.Equal(t, stun.CodeBadRequest, errorCode) + }) + + t.Run("invalid message integrity", func(t *testing.T) { + tearDown := setUp(t, options{}) + defer tearDown() + + m, err := stun.Build( + stun.TransactionID, + stun.NewType(stun.MethodAllocate, stun.ClassRequest), + proto.RequestedTransport{Protocol: proto.ProtoUDP}, + stun.NewUsername(testUsername), + stun.NewRealm(testRealm), + stun.NewLongTermIntegrity(testUsername, testRealm, "bad"), // <- bad message integrity + stun.NewNonce(nonce), + ) + require.NoError(t, err) + + authResult, err := authenticateRequest(*r, m, stun.MethodAllocate) + require.ErrorIs(t, err, stun.ErrIntegrityMismatch) + require.False(t, authResult.hasAuth) + + // Check the error response + errorCode := checkSTUNAllocateErrorResponse(t) + require.Equal(t, stun.CodeBadRequest, errorCode) + }) +} diff --git a/lt_cred.go b/lt_cred.go index 42466c38..bd3197f1 100644 --- a/lt_cred.go +++ b/lt_cred.go @@ -79,7 +79,7 @@ func LongTermTURNRESTAuthHandler(sharedSecret string, l logging.LeveledLogger) A l = logging.NewDefaultLoggerFactory().NewLogger("turn") } return func(username, realm string, srcAddr net.Addr) (key []byte, ok bool) { - l.Tracef("Authentication username=%q realm=%q srcAddr=%v\n", username, realm, srcAddr) + l.Tracef("Authentication username=%q realm=%q srcAddr=%v", username, realm, srcAddr) timestamp := strings.Split(username, ":")[0] t, err := strconv.Atoi(timestamp) if err != nil { diff --git a/relay_address_generator_none.go b/relay_address_generator_none.go index b0974010..da95b0cc 100644 --- a/relay_address_generator_none.go +++ b/relay_address_generator_none.go @@ -10,6 +10,7 @@ import ( "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" + "github.com/pion/turn/v4/internal/allocation" ) // RelayAddressGeneratorNone returns the listener with no modifications @@ -39,7 +40,7 @@ func (r *RelayAddressGeneratorNone) Validate() error { } // AllocatePacketConn generates a new PacketConn to receive traffic on and the IP/Port to populate the allocation response with -func (r *RelayAddressGeneratorNone) AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) { +func (r *RelayAddressGeneratorNone) AllocatePacketConn(network string, requestedPort int, _ allocation.Metadata) (net.PacketConn, net.Addr, error) { conn, err := r.Net.ListenPacket(network, r.Address+":"+strconv.Itoa(requestedPort)) if err != nil { return nil, nil, err @@ -49,6 +50,6 @@ func (r *RelayAddressGeneratorNone) AllocatePacketConn(network string, requested } // AllocateConn generates a new Conn to receive traffic on and the IP/Port to populate the allocation response with -func (r *RelayAddressGeneratorNone) AllocateConn(string, int) (net.Conn, net.Addr, error) { +func (r *RelayAddressGeneratorNone) AllocateConn(string, int, allocation.Metadata) (net.Conn, net.Addr, error) { return nil, nil, errTODO } diff --git a/relay_address_generator_range.go b/relay_address_generator_range.go index d87a57f9..235d7669 100644 --- a/relay_address_generator_range.go +++ b/relay_address_generator_range.go @@ -10,6 +10,7 @@ import ( "github.com/pion/randutil" "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" + "github.com/pion/turn/v4/internal/allocation" ) // RelayAddressGeneratorPortRange can be used to only allocate connections inside a defined port range. @@ -68,7 +69,7 @@ func (r *RelayAddressGeneratorPortRange) Validate() error { } // AllocatePacketConn generates a new PacketConn to receive traffic on and the IP/Port to populate the allocation response with -func (r *RelayAddressGeneratorPortRange) AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) { +func (r *RelayAddressGeneratorPortRange) AllocatePacketConn(network string, requestedPort int, _ allocation.Metadata) (net.PacketConn, net.Addr, error) { if requestedPort != 0 { conn, err := r.Net.ListenPacket(network, fmt.Sprintf("%s:%d", r.Address, requestedPort)) if err != nil { @@ -103,6 +104,6 @@ func (r *RelayAddressGeneratorPortRange) AllocatePacketConn(network string, requ } // AllocateConn generates a new Conn to receive traffic on and the IP/Port to populate the allocation response with -func (r *RelayAddressGeneratorPortRange) AllocateConn(string, int) (net.Conn, net.Addr, error) { +func (r *RelayAddressGeneratorPortRange) AllocateConn(string, int, allocation.Metadata) (net.Conn, net.Addr, error) { return nil, nil, errTODO } diff --git a/relay_address_generator_static.go b/relay_address_generator_static.go index 39c68777..bab3e611 100644 --- a/relay_address_generator_static.go +++ b/relay_address_generator_static.go @@ -10,6 +10,7 @@ import ( "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" + "github.com/pion/turn/v4/internal/allocation" ) // RelayAddressGeneratorStatic can be used to return static IP address each time a relay is created. @@ -45,7 +46,7 @@ func (r *RelayAddressGeneratorStatic) Validate() error { } // AllocatePacketConn generates a new PacketConn to receive traffic on and the IP/Port to populate the allocation response with -func (r *RelayAddressGeneratorStatic) AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) { +func (r *RelayAddressGeneratorStatic) AllocatePacketConn(network string, requestedPort int, _ allocation.Metadata) (net.PacketConn, net.Addr, error) { conn, err := r.Net.ListenPacket(network, r.Address+":"+strconv.Itoa(requestedPort)) if err != nil { return nil, nil, err @@ -63,6 +64,6 @@ func (r *RelayAddressGeneratorStatic) AllocatePacketConn(network string, request } // AllocateConn generates a new Conn to receive traffic on and the IP/Port to populate the allocation response with -func (r *RelayAddressGeneratorStatic) AllocateConn(string, int) (net.Conn, net.Addr, error) { +func (r *RelayAddressGeneratorStatic) AllocateConn(string, int, allocation.Metadata) (net.Conn, net.Addr, error) { return nil, nil, errTODO } diff --git a/server.go b/server.go index 3b58938f..0123f634 100644 --- a/server.go +++ b/server.go @@ -171,11 +171,11 @@ type nilAddressGenerator struct{} func (n *nilAddressGenerator) Validate() error { return errRelayAddressGeneratorNil } -func (n *nilAddressGenerator) AllocatePacketConn(string, int) (net.PacketConn, net.Addr, error) { +func (n *nilAddressGenerator) AllocatePacketConn(string, int, allocation.Metadata) (net.PacketConn, net.Addr, error) { return nil, nil, errRelayAddressGeneratorNil } -func (n *nilAddressGenerator) AllocateConn(string, int) (net.Conn, net.Addr, error) { +func (n *nilAddressGenerator) AllocateConn(string, int, allocation.Metadata) (net.Conn, net.Addr, error) { return nil, nil, errRelayAddressGeneratorNil } diff --git a/server_config.go b/server_config.go index eab2988e..de438cf4 100644 --- a/server_config.go +++ b/server_config.go @@ -11,6 +11,7 @@ import ( "time" "github.com/pion/logging" + "github.com/pion/turn/v4/internal/allocation" ) // RelayAddressGenerator is used to generate a RelayAddress when creating an allocation. @@ -20,10 +21,10 @@ type RelayAddressGenerator interface { Validate() error // Allocate a PacketConn (UDP) RelayAddress - AllocatePacketConn(network string, requestedPort int) (net.PacketConn, net.Addr, error) + AllocatePacketConn(network string, requestedPort int, metadata allocation.Metadata) (net.PacketConn, net.Addr, error) // Allocate a Conn (TCP) RelayAddress - AllocateConn(network string, requestedPort int) (net.Conn, net.Addr, error) + AllocateConn(network string, requestedPort int, metadata allocation.Metadata) (net.Conn, net.Addr, error) } // PermissionHandler is a callback to filter incoming CreatePermission and ChannelBindRequest