diff --git a/internal/allocation/allocation.go b/internal/allocation/allocation.go index 9e2d5352..c2ec7ece 100644 --- a/internal/allocation/allocation.go +++ b/internal/allocation/allocation.go @@ -35,6 +35,8 @@ type Allocation struct { channelBindings []*ChannelBind lifetimeTimer *time.Timer closed chan interface{} + username, realm string + callback EventHandler log logging.LeveledLogger // Some clients (Firefox or others using resiprocate's nICE lib) may retry allocation @@ -45,12 +47,18 @@ type Allocation struct { } // NewAllocation creates a new instance of NewAllocation. -func NewAllocation(turnSocket net.PacketConn, fiveTuple *FiveTuple, log logging.LeveledLogger) *Allocation { +func NewAllocation( + turnSocket net.PacketConn, + fiveTuple *FiveTuple, + callback EventHandler, + log logging.LeveledLogger, +) *Allocation { return &Allocation{ TurnSocket: turnSocket, fiveTuple: fiveTuple, permissions: make(map[string]*Permission, 64), closed: make(chan interface{}), + callback: callback, log: log, } } @@ -82,6 +90,21 @@ func (a *Allocation) AddPermission(perms *Permission) { a.permissions[fingerprint] = perms a.permissionsLock.Unlock() + if a.callback != nil { + if u, ok := perms.Addr.(*net.UDPAddr); ok { + a.callback(EventHandlerArgs{ + Type: OnPermissionCreated, + SrcAddr: a.fiveTuple.SrcAddr, + DstAddr: a.fiveTuple.DstAddr, + Protocol: a.fiveTuple.Protocol, + Username: a.username, + Realm: a.realm, + RelayAddr: a.RelayAddr, + PeerIP: u.IP, + }) + } + } + perms.start(permissionTimeout) } @@ -90,6 +113,33 @@ func (a *Allocation) RemovePermission(addr net.Addr) { a.permissionsLock.Lock() defer a.permissionsLock.Unlock() delete(a.permissions, ipnet.FingerprintAddr(addr)) + + if a.callback != nil { + if u, ok := addr.(*net.UDPAddr); ok { + a.callback(EventHandlerArgs{ + Type: OnPermissionDeleted, + SrcAddr: a.fiveTuple.SrcAddr, + DstAddr: a.fiveTuple.DstAddr, + Protocol: a.fiveTuple.Protocol, + Username: a.username, + Realm: a.realm, + RelayAddr: a.RelayAddr, + PeerIP: u.IP, + }) + } + } +} + +// ListPermissions returns the permissions associated with an allocation. +func (a *Allocation) ListPermissions() []*Permission { + ps := []*Permission{} + a.permissionsLock.RLock() + defer a.permissionsLock.RUnlock() + for _, p := range a.permissions { + ps = append(ps, p) + } + + return ps } // AddChannelBind adds a new ChannelBind to the allocation, it also updates the @@ -114,6 +164,20 @@ func (a *Allocation) AddChannelBind(chanBind *ChannelBind, lifetime time.Duratio // Channel binds also refresh permissions. a.AddPermission(NewPermission(chanBind.Peer, a.log)) + + if a.callback != nil { + a.callback(EventHandlerArgs{ + Type: OnChannelCreated, + SrcAddr: a.fiveTuple.SrcAddr, + DstAddr: a.fiveTuple.DstAddr, + Protocol: a.fiveTuple.Protocol, + Username: a.username, + Realm: a.realm, + RelayAddr: a.RelayAddr, + PeerAddr: chanBind.Peer, + ChannelNumber: uint16(chanBind.Number), + }) + } } else { channelByNumber.refresh(lifetime) @@ -131,6 +195,20 @@ func (a *Allocation) RemoveChannelBind(number proto.ChannelNumber) bool { for i := len(a.channelBindings) - 1; i >= 0; i-- { if a.channelBindings[i].Number == number { + if a.callback != nil { + a.callback(EventHandlerArgs{ + Type: OnChannelDeleted, + SrcAddr: a.fiveTuple.SrcAddr, + DstAddr: a.fiveTuple.DstAddr, + Protocol: a.fiveTuple.Protocol, + Username: a.username, + Realm: a.realm, + RelayAddr: a.RelayAddr, + PeerAddr: a.channelBindings[i].Peer, + ChannelNumber: uint16(a.channelBindings[i].Number), + }) + } + a.channelBindings = append(a.channelBindings[:i], a.channelBindings[i+1:]...) return true @@ -166,6 +244,16 @@ func (a *Allocation) GetChannelByAddr(addr net.Addr) *ChannelBind { return nil } +// ListChannelBindings returns the channel bindings associated with an allocation. +func (a *Allocation) ListChannelBindings() []*ChannelBind { + cs := []*ChannelBind{} + a.channelBindingsLock.RLock() + defer a.channelBindingsLock.RUnlock() + cs = append(cs, a.channelBindings...) + + return cs +} + // Refresh updates the allocations lifetime. func (a *Allocation) Refresh(lifetime time.Duration) { if !a.lifetimeTimer.Reset(lifetime) { @@ -201,17 +289,15 @@ func (a *Allocation) Close() error { a.lifetimeTimer.Stop() - a.permissionsLock.RLock() - for _, p := range a.permissions { + for _, p := range a.ListPermissions() { + a.RemovePermission(p.Addr) p.lifetimeTimer.Stop() } - a.permissionsLock.RUnlock() - a.channelBindingsLock.RLock() - for _, c := range a.channelBindings { + for _, c := range a.ListChannelBindings() { + a.RemoveChannelBind(c.Number) c.lifetimeTimer.Stop() } - a.channelBindingsLock.RUnlock() return a.RelaySocket.Close() } diff --git a/internal/allocation/allocation_manager.go b/internal/allocation/allocation_manager.go index a3b011f4..c190341b 100644 --- a/internal/allocation/allocation_manager.go +++ b/internal/allocation/allocation_manager.go @@ -18,6 +18,7 @@ type ManagerConfig struct { AllocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error) AllocateConn func(network string, requestedPort int) (net.Conn, net.Addr, error) PermissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool + EventHandler EventHandler } type reservation struct { @@ -36,6 +37,7 @@ type Manager struct { allocatePacketConn func(network string, requestedPort int) (net.PacketConn, net.Addr, error) allocateConn func(network string, requestedPort int) (net.Conn, net.Addr, error) permissionHandler func(sourceAddr net.Addr, peerIP net.IP) bool + EventHandler EventHandler } // NewManager creates a new instance of Manager. @@ -55,6 +57,7 @@ func NewManager(config ManagerConfig) (*Manager, error) { allocatePacketConn: config.AllocatePacketConn, allocateConn: config.AllocateConn, permissionHandler: config.PermissionHandler, + EventHandler: config.EventHandler, }, nil } @@ -94,6 +97,7 @@ func (m *Manager) CreateAllocation( turnSocket net.PacketConn, requestedPort int, lifetime time.Duration, + username, realm string, ) (*Allocation, error) { switch { case fiveTuple == nil: @@ -111,7 +115,9 @@ func (m *Manager) CreateAllocation( if alloc := m.GetAllocation(fiveTuple); alloc != nil { return nil, fmt.Errorf("%w: %v", errDupeFiveTuple, fiveTuple) } - alloc := NewAllocation(turnSocket, fiveTuple, m.log) + alloc := NewAllocation(turnSocket, fiveTuple, m.EventHandler, m.log) + alloc.username = username + alloc.realm = realm conn, relayAddr, err := m.allocatePacketConn("udp4", requestedPort) if err != nil { @@ -131,6 +137,19 @@ func (m *Manager) CreateAllocation( m.allocations[fiveTuple.Fingerprint()] = alloc m.lock.Unlock() + if m.EventHandler != nil { + m.EventHandler(EventHandlerArgs{ + Type: OnAllocationCreated, + SrcAddr: fiveTuple.SrcAddr, + DstAddr: fiveTuple.DstAddr, + Protocol: UDP, + Username: username, + Realm: realm, + RelayAddr: relayAddr, + RequestedPort: requestedPort, + }) + } + go alloc.packetHandler(m) return alloc, nil @@ -152,6 +171,17 @@ func (m *Manager) DeleteAllocation(fiveTuple *FiveTuple) { if err := allocation.Close(); err != nil { m.log.Errorf("Failed to close allocation: %v", err) } + + if m.EventHandler != nil { + m.EventHandler(EventHandlerArgs{ + Type: OnAllocationDeleted, + SrcAddr: fiveTuple.SrcAddr, + DstAddr: fiveTuple.DstAddr, + Protocol: UDP, + Username: allocation.username, + Realm: allocation.realm, + }) + } } // CreateReservation stores the reservation for the token+port. diff --git a/internal/allocation/allocation_manager_test.go b/internal/allocation/allocation_manager_test.go index 5a68cda2..502ff55d 100644 --- a/internal/allocation/allocation_manager_test.go +++ b/internal/allocation/allocation_manager_test.go @@ -54,13 +54,13 @@ func subTestCreateInvalidAllocation(t *testing.T, turnSocket net.PacketConn) { m, err := newTestManager() assert.NoError(t, err) - if a, err := m.CreateAllocation(nil, turnSocket, 0, proto.DefaultLifetime); a != nil || err == nil { + if a, err := m.CreateAllocation(nil, turnSocket, 0, proto.DefaultLifetime, "", ""); a != nil || err == nil { t.Errorf("Illegally created allocation with nil FiveTuple") } - if a, err := m.CreateAllocation(randomFiveTuple(), nil, 0, proto.DefaultLifetime); a != nil || err == nil { + if a, err := m.CreateAllocation(randomFiveTuple(), nil, 0, proto.DefaultLifetime, "", ""); a != nil || err == nil { t.Errorf("Illegally created allocation with nil turnSocket") } - if a, err := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, 0); a != nil || err == nil { + if a, err := m.CreateAllocation(randomFiveTuple(), turnSocket, 0, 0, "", ""); a != nil || err == nil { t.Errorf("Illegally created allocation with 0 lifetime") } } @@ -73,7 +73,7 @@ func subTestCreateAllocation(t *testing.T, turnSocket net.PacketConn) { assert.NoError(t, err) fiveTuple := randomFiveTuple() - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { + if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, "", ""); a == nil || err != nil { t.Errorf("Failed to create allocation %v %v", a, err) } @@ -90,11 +90,11 @@ func subTestCreateAllocationDuplicateFiveTuple(t *testing.T, turnSocket net.Pack assert.NoError(t, err) fiveTuple := randomFiveTuple() - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { + if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, "", ""); a == nil || err != nil { t.Errorf("Failed to create allocation %v %v", a, err) } - if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a != nil || err == nil { + if a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, "", ""); a != nil || err == nil { t.Errorf("Was able to create allocation with same FiveTuple twice") } } @@ -106,7 +106,8 @@ func subTestDeleteAllocation(t *testing.T, turnSocket net.PacketConn) { assert.NoError(t, err) fiveTuple := randomFiveTuple() - if a, err := manager.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime); a == nil || err != nil { + if a, err := manager.CreateAllocation(fiveTuple, turnSocket, 0, proto.DefaultLifetime, + "", ""); a == nil || err != nil { t.Errorf("Failed to create allocation %v %v", a, err) } @@ -133,7 +134,7 @@ func subTestAllocationTimeout(t *testing.T, turnSocket net.PacketConn) { for index := range allocations { fiveTuple := randomFiveTuple() - a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, lifetime) + a, err := m.CreateAllocation(fiveTuple, turnSocket, 0, lifetime, "", "") if err != nil { t.Errorf("Failed to create allocation with %v", fiveTuple) } @@ -159,9 +160,9 @@ func subTestManagerClose(t *testing.T, turnSocket net.PacketConn) { allocations := make([]*Allocation, 2) - a1, _ := manager.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Second) + a1, _ := manager.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Second, "", "") allocations[0] = a1 - a2, _ := manager.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Minute) + a2, _ := manager.CreateAllocation(randomFiveTuple(), turnSocket, 0, time.Minute, "", "") allocations[1] = a2 // Make a1 timeout diff --git a/internal/allocation/allocation_test.go b/internal/allocation/allocation_test.go index b2c67717..1eca0881 100644 --- a/internal/allocation/allocation_test.go +++ b/internal/allocation/allocation_test.go @@ -48,7 +48,7 @@ func TestAllocation(t *testing.T) { func subTestGetPermission(t *testing.T) { t.Helper() - alloc := NewAllocation(nil, nil, nil) + alloc := NewAllocation(nil, nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -92,7 +92,7 @@ func subTestGetPermission(t *testing.T) { func subTestAddPermission(t *testing.T) { t.Helper() - alloc := NewAllocation(nil, nil, nil) + alloc := NewAllocation(nil, nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -113,7 +113,7 @@ func subTestAddPermission(t *testing.T) { func subTestRemovePermission(t *testing.T) { t.Helper() - alloc := NewAllocation(nil, nil, nil) + alloc := NewAllocation(nil, nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -138,7 +138,7 @@ func subTestRemovePermission(t *testing.T) { func subTestAddChannelBind(t *testing.T) { t.Helper() - alloc := NewAllocation(nil, nil, nil) + alloc := NewAllocation(nil, nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -164,7 +164,7 @@ func subTestAddChannelBind(t *testing.T) { func subTestGetChannelByNumber(t *testing.T) { t.Helper() - alloc := NewAllocation(nil, nil, nil) + alloc := NewAllocation(nil, nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -185,7 +185,7 @@ func subTestGetChannelByNumber(t *testing.T) { func subTestGetChannelByAddr(t *testing.T) { t.Helper() - alloc := NewAllocation(nil, nil, nil) + alloc := NewAllocation(nil, nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -207,7 +207,7 @@ func subTestGetChannelByAddr(t *testing.T) { func subTestRemoveChannelBind(t *testing.T) { t.Helper() - alloc := NewAllocation(nil, nil, nil) + alloc := NewAllocation(nil, nil, nil, nil) addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3478") if err != nil { @@ -230,7 +230,7 @@ func subTestRemoveChannelBind(t *testing.T) { func subTestAllocationRefresh(t *testing.T) { t.Helper() - alloc := NewAllocation(nil, nil, nil) + alloc := NewAllocation(nil, nil, nil, nil) var wg sync.WaitGroup wg.Add(1) @@ -254,7 +254,7 @@ func subTestAllocationClose(t *testing.T) { panic(err) } - alloc := NewAllocation(nil, nil, nil) + alloc := NewAllocation(nil, nil, nil, nil) alloc.RelaySocket = l // Add mock lifetimeTimer alloc.lifetimeTimer = time.AfterFunc(proto.DefaultLifetime, func() {}) @@ -312,7 +312,7 @@ func subTestPacketHandler(t *testing.T) { alloc, err := manager.CreateAllocation(&FiveTuple{ SrcAddr: clientListener.LocalAddr(), DstAddr: turnSocket.LocalAddr(), - }, turnSocket, 0, proto.DefaultLifetime) + }, turnSocket, 0, proto.DefaultLifetime, "", "") assert.Nil(t, err, "should succeed") @@ -379,16 +379,16 @@ func subTestPacketHandler(t *testing.T) { func subTestResponseCache(t *testing.T) { t.Helper() - a := NewAllocation(nil, nil, nil) + alloc := NewAllocation(nil, nil, nil, nil) transactionID := [stun.TransactionIDSize]byte{1, 2, 3} responseAttrs := []stun.Setter{ &proto.Lifetime{ Duration: proto.DefaultLifetime, }, } - a.SetResponseCache(transactionID, responseAttrs) + alloc.SetResponseCache(transactionID, responseAttrs) - cacheID, cacheAttr := a.GetResponseCache() + cacheID, cacheAttr := alloc.GetResponseCache() assert.Equal(t, transactionID, cacheID) assert.Equal(t, responseAttrs, cacheAttr) } diff --git a/internal/allocation/channel_bind_test.go b/internal/allocation/channel_bind_test.go index 30e3034a..4d72e456 100644 --- a/internal/allocation/channel_bind_test.go +++ b/internal/allocation/channel_bind_test.go @@ -42,7 +42,7 @@ func TestChannelBindReset(t *testing.T) { } func newChannelBind(lifetime time.Duration) *ChannelBind { - a := NewAllocation(nil, nil, nil) + a := NewAllocation(nil, nil, nil, nil) addr, _ := net.ResolveUDPAddr("udp", "0.0.0.0:0") c := &ChannelBind{ diff --git a/internal/allocation/event_handler.go b/internal/allocation/event_handler.go new file mode 100644 index 00000000..b5b718c2 --- /dev/null +++ b/internal/allocation/event_handler.go @@ -0,0 +1,39 @@ +// SPDX-FileCopyrightText: 2023 The Pion community +// SPDX-License-Identifier: MIT + +package allocation + +import ( + "net" +) + +// EventHandlerType is a type for signaling low-level event callbacks to the server. +type EventHandlerType int + +// Event handler types. +const ( + UnknownEvent EventHandlerType = iota + OnAuth + OnAllocationCreated + OnAllocationDeleted + OnAllocationError + OnPermissionCreated + OnPermissionDeleted + OnChannelCreated + OnChannelDeleted +) + +// EventHandlerArgs is a set of arguments passed from the low-level event callbacks to the server. +type EventHandlerArgs struct { + Type EventHandlerType + SrcAddr, DstAddr, RelayAddr, PeerAddr net.Addr + Protocol Protocol + Username, Realm, Method, Message string + Verdict bool + RequestedPort int + PeerIP net.IP + ChannelNumber uint16 +} + +// EventHandler is a callback used by the server to surface allocation lifecycle events. +type EventHandler func(EventHandlerArgs) diff --git a/internal/allocation/five_tuple.go b/internal/allocation/five_tuple.go index 14761611..b9eba872 100644 --- a/internal/allocation/five_tuple.go +++ b/internal/allocation/five_tuple.go @@ -16,6 +16,17 @@ const ( TCP ) +func (p Protocol) String() string { + switch p { + case UDP: + return "UDP" + case TCP: + return "TCP" + default: + return "" + } +} + // FiveTuple is the combination (client IP address and port, server IP // address and port, and transport protocol (currently one of UDP, // TCP, or TLS)) used to communicate between the client and the diff --git a/internal/server/server.go b/internal/server/server.go index 4f5c3148..750cb70f 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -27,7 +27,8 @@ type Request struct { NonceHash *NonceHash // User Configuration - AuthHandler func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool) + AuthHandler func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool) + Log logging.LeveledLogger Realm string ChannelBindTimeout time.Duration diff --git a/internal/server/turn.go b/internal/server/turn.go index d20c5dc5..da2c18e1 100644 --- a/internal/server/turn.go +++ b/internal/server/turn.go @@ -145,6 +145,12 @@ func handleAllocateRequest(req Request, stunMsg *stun.Message) error { //nolint: } } + // Parse realm and username (already checked in authenticateRequest) + realmAttr := &stun.Realm{} + _ = realmAttr.GetFrom(stunMsg) + usernameAttr := &stun.Username{} + _ = usernameAttr.GetFrom(stunMsg) + // 7. At any point, the server MAY choose to reject the request with a // 486 (Allocation Quota Reached) error if it feels the client is // trying to exceed some locally defined allocation quota. The @@ -161,7 +167,10 @@ func handleAllocateRequest(req Request, stunMsg *stun.Message) error { //nolint: fiveTuple, req.Conn, requestedPort, - lifetimeDuration) + lifetimeDuration, + usernameAttr.String(), + realmAttr.String(), + ) if err != nil { return buildAndSendErr(req.Conn, req.SrcAddr, err, insufficientCapacityMsg...) } diff --git a/internal/server/turn_test.go b/internal/server/turn_test.go index 18e0c47d..d3574668 100644 --- a/internal/server/turn_test.go +++ b/internal/server/turn_test.go @@ -97,7 +97,7 @@ func TestAllocationLifeTime(t *testing.T) { fiveTuple := &allocation.FiveTuple{SrcAddr: req.SrcAddr, DstAddr: req.Conn.LocalAddr(), Protocol: allocation.UDP} - _, err = req.AllocationManager.CreateAllocation(fiveTuple, req.Conn, 0, time.Hour) + _, err = req.AllocationManager.CreateAllocation(fiveTuple, req.Conn, 0, time.Hour, "", "") assert.NoError(t, err) assert.NotNil(t, req.AllocationManager.GetAllocation(fiveTuple)) diff --git a/internal/server/util.go b/internal/server/util.go index a16bf439..7df110bf 100644 --- a/internal/server/util.go +++ b/internal/server/util.go @@ -10,6 +10,7 @@ import ( "time" "github.com/pion/stun/v3" + "github.com/pion/turn/v4/internal/allocation" "github.com/pion/turn/v4/internal/proto" ) @@ -84,12 +85,11 @@ func authenticateRequest(req Request, stunMsg *stun.Message, callingMethod stun. // Respond with 400 so clients don't retry. if req.AuthHandler == nil { sendErr := buildAndSend(req.Conn, req.SrcAddr, badRequestMsg...) - - return nil, false, sendErr + return nil, false, sendErr // nolint:nlreturn } if err := nonceAttr.GetFrom(stunMsg); err != nil { - return nil, false, buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) + return nil, false, buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) // nolint:nlreturn } // Assert Nonce is signed and is not expired. @@ -114,12 +114,42 @@ func authenticateRequest(req Request, stunMsg *stun.Message, callingMethod stun. } if err := stun.MessageIntegrity(ourKey).Check(stunMsg); err != nil { - return nil, false, buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) + genAuthEvent(req, stunMsg, callingMethod, false) + return nil, false, buildAndSendErr(req.Conn, req.SrcAddr, err, badRequestMsg...) // nolint:nlreturn } + genAuthEvent(req, stunMsg, callingMethod, true) + return stun.MessageIntegrity(ourKey), true, nil } +func genAuthEvent(req Request, stunMsg *stun.Message, callingMethod stun.Method, verdict bool) { + if req.AllocationManager.EventHandler == nil { + return + } + + realmAttr := &stun.Realm{} + if err := realmAttr.GetFrom(stunMsg); err != nil { + return + } + + usernameAttr := &stun.Username{} + if err := usernameAttr.GetFrom(stunMsg); err != nil { + return + } + + req.AllocationManager.EventHandler(allocation.EventHandlerArgs{ + Type: allocation.OnAuth, + SrcAddr: req.SrcAddr, + DstAddr: req.Conn.LocalAddr(), + Protocol: allocation.UDP, + Username: usernameAttr.String(), + Realm: realmAttr.String(), + Method: callingMethod.String(), + Verdict: verdict, + }) +} + func allocationLifeTime(m *stun.Message) time.Duration { lifetimeDuration := proto.DefaultLifetime diff --git a/server.go b/server.go index 57b7a3b8..10cc2e5b 100644 --- a/server.go +++ b/server.go @@ -27,6 +27,7 @@ type Server struct { realm string channelBindTimeout time.Duration nonceHash *server.NonceHash + eventHandlers EventHandlers packetConnConfigs []PacketConnConfig listenerConfigs []ListenerConfig @@ -64,6 +65,7 @@ func NewServer(config ServerConfig) (*Server, error) { //nolint:gocognit,cyclop listenerConfigs: config.ListenerConfigs, nonceHash: nonceHash, inboundMTU: mtu, + eventHandlers: config.EventHandlers, } if server.channelBindTimeout == 0 { @@ -196,6 +198,7 @@ func (s *Server) createAllocationManager( AllocatePacketConn: addrGenerator.AllocatePacketConn, AllocateConn: addrGenerator.AllocateConn, PermissionHandler: handler, + EventHandler: genericEventHandler(s.eventHandlers), LeveledLogger: s.log, }) if err != nil { @@ -233,6 +236,9 @@ func (s *Server) readLoop(conn net.PacketConn, allocationManager *allocation.Man ChannelBindTimeout: s.channelBindTimeout, NonceHash: s.nonceHash, }); err != nil { + if s.eventHandlers.OnAllocationError != nil { + s.eventHandlers.OnAllocationError(addr, conn.LocalAddr(), allocation.UDP.String(), err.Error()) + } s.log.Errorf("Failed to handle datagram: %v", err) } } diff --git a/server_config.go b/server_config.go index 8f276140..94094343 100644 --- a/server_config.go +++ b/server_config.go @@ -11,6 +11,7 @@ import ( "time" "github.com/pion/logging" + "github.com/pion/turn/v4/internal/allocation" ) // RelayAddressGenerator is used to generate a RelayAddress when creating an allocation. @@ -108,6 +109,87 @@ func GenerateAuthKey(username, realm, password string) []byte { return h.Sum(nil) } +// EventHandlers is a set of callbacks that the server will call at certain hook points during an +// allocation's lifecycle. All events are reported with the context that identifies the allocation +// triggering the event (source and destination address, protocol, username and realm used for +// authenticating the allocation). It is OK to handle only a subset of the callbacks. +type EventHandlers struct { + // OnAuth is called after an authentication request has been processed with the TURN method + // triggering the authentication request (either "Allocate", "Refresh" "CreatePermission", + // or "ChannelBind"), and the verdict is the authentication result. + OnAuth func(srcAddr, dstAddr net.Addr, protocol, username, realm string, method string, verdict bool) + // OnAllocationCreated is called after a new allocation has been made. The relayAddr + // argument specifies the relay address and requestedPort is the port requested by the + // client (if any). + OnAllocationCreated func(srcAddr, dstAddr net.Addr, protocol, username, realm string, + relayAddr net.Addr, requestedPort int) + // OnAllocationDeleted is called after an allocation has been removed. + OnAllocationDeleted func(srcAddr, dstAddr net.Addr, protocol, username, realm string) + // OnAllocationError is called when the readloop hdndling an allocation exits with an + // error with an error message. + OnAllocationError func(srcAddr, dstAddr net.Addr, protocol, message string) + // OnPermissionCreated is called after a new permission has been made to an IP address. + OnPermissionCreated func(srcAddr, dstAddr net.Addr, protocol, username, realm string, + relayAddr net.Addr, peer net.IP) + // OnPermissionDeleted is called after a permission for a given IP address has been + // removed. + OnPermissionDeleted func(srcAddr, dstAddr net.Addr, protocol, username, realm string, + relayAddr net.Addr, peer net.IP) + // OnChannelCreated is called after a new channel has been made. The relay address, the + // peer address and the channel number can be used to uniquely identify the channel + // created. + OnChannelCreated func(srcAddr, dstAddr net.Addr, protocol, username, realm string, + relayAddr, peer net.Addr, channelNumber uint16) + // OnChannelDeleted is called after a channel has been removed from the server. The relay + // address, the peer address and the channel number can be used to uniquely identify the + // channel deleted. + OnChannelDeleted func(srcAddr, dstAddr net.Addr, protocol, username, realm string, + relayAddr, peer net.Addr, channelNumber uint16) +} + +func genericEventHandler(handlers EventHandlers) allocation.EventHandler { //nolint:cyclop + return func(arg allocation.EventHandlerArgs) { + switch arg.Type { + case allocation.OnAuth: + if handlers.OnAuth != nil { + handlers.OnAuth(arg.SrcAddr, arg.DstAddr, arg.Protocol.String(), + arg.Username, arg.Realm, arg.Method, arg.Verdict) + } + case allocation.OnAllocationCreated: + if handlers.OnAllocationCreated != nil { + handlers.OnAllocationCreated(arg.SrcAddr, arg.DstAddr, arg.Protocol.String(), + arg.Username, arg.Realm, arg.RelayAddr, arg.RequestedPort) + } + case allocation.OnAllocationDeleted: + if handlers.OnAllocationDeleted != nil { + handlers.OnAllocationDeleted(arg.SrcAddr, arg.DstAddr, arg.Protocol.String(), + arg.Username, arg.Realm) + } + case allocation.OnPermissionCreated: + if handlers.OnPermissionCreated != nil { + handlers.OnPermissionCreated(arg.SrcAddr, arg.DstAddr, arg.Protocol.String(), + arg.Username, arg.Realm, arg.RelayAddr, arg.PeerIP) + } + case allocation.OnPermissionDeleted: + if handlers.OnPermissionDeleted != nil { + handlers.OnPermissionDeleted(arg.SrcAddr, arg.DstAddr, arg.Protocol.String(), + arg.Username, arg.Realm, arg.RelayAddr, arg.PeerIP) + } + case allocation.OnChannelCreated: + if handlers.OnChannelCreated != nil { + handlers.OnChannelCreated(arg.SrcAddr, arg.DstAddr, arg.Protocol.String(), + arg.Username, arg.Realm, arg.RelayAddr, arg.PeerAddr, arg.ChannelNumber) + } + case allocation.OnChannelDeleted: + if handlers.OnChannelDeleted != nil { + handlers.OnChannelDeleted(arg.SrcAddr, arg.DstAddr, arg.Protocol.String(), + arg.Username, arg.Realm, arg.RelayAddr, arg.PeerAddr, arg.ChannelNumber) + } + default: + } + } +} + // ServerConfig configures the Pion TURN Server. type ServerConfig struct { // PacketConnConfigs and ListenerConfigs are a list of all the turn listeners @@ -125,6 +207,9 @@ type ServerConfig struct { // allowing users to customize Pion TURN with custom behavior AuthHandler AuthHandler + // EventHandlers is a set of callbacks for tracking allocation lifecycle. + EventHandlers EventHandlers + // ChannelBindTimeout sets the lifetime of channel binding. Defaults to 10 minutes. ChannelBindTimeout time.Duration diff --git a/server_test.go b/server_test.go index 0833cae2..81dd3d18 100644 --- a/server_test.go +++ b/server_test.go @@ -9,6 +9,7 @@ package turn import ( "fmt" "net" + "sync/atomic" "syscall" "testing" "time" @@ -21,6 +22,13 @@ import ( "github.com/stretchr/testify/assert" ) +const ( + timeout = 200 * time.Millisecond + interval = 50 * time.Millisecond + stunAddr = "1.2.3.4:3478" + turnAddr = "1.2.3.4:3478" +) + func TestServer(t *testing.T) { //nolint:maintidx lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -381,8 +389,15 @@ func (v *VNet) Close() error { return v.wan.Stop() } -func buildVNet() (*VNet, error) { //nolint:cyclop +func buildVNet() (*VNet, error) { + return buildVNetWithServerEventHandlers(nil) +} + +func buildVNetWithServerEventHandlers(handlers *EventHandlers) (*VNet, error) { //nolint:cyclop loggerFactory := logging.NewDefaultLoggerFactory() + if handlers == nil { + handlers = &EventHandlers{} + } // WAN wan, err := vnet.NewRouter(&vnet.RouterConfig{ @@ -451,7 +466,7 @@ func buildVNet() (*VNet, error) { //nolint:cyclop // Start server... credMap := map[string][]byte{"user": GenerateAuthKey("user", "pion.ly", "pass")} - udpListener, err := net0.ListenPacket("udp4", "0.0.0.0:3478") + udpListener, err := net0.ListenPacket("udp4", "1.2.3.4:3478") if err != nil { return nil, err } @@ -464,7 +479,8 @@ func buildVNet() (*VNet, error) { //nolint:cyclop return nil, false }, - Realm: "pion.ly", + Realm: "pion.ly", + EventHandlers: *handlers, PacketConnConfigs: []PacketConnConfig{ { PacketConn: udpListener, @@ -503,7 +519,16 @@ func buildVNet() (*VNet, error) { //nolint:cyclop }, nil } -func TestServerVNet(t *testing.T) { +func expectEvent(ch chan allocation.EventHandlerArgs) (allocation.EventHandlerArgs, bool) { + select { + case res := <-ch: + return res, true + case <-time.After(timeout): + return allocation.EventHandlerArgs{}, false + } +} + +func TestServerVNet(t *testing.T) { //nolint:maintidx lim := test.TimeOut(time.Second * 30) defer lim.Stop() @@ -526,8 +551,6 @@ func TestServerVNet(t *testing.T) { assert.NoError(t, lconn.Close()) }() - stunAddr := "1.2.3.4:3478" - log.Debug("creating a client.") client, err := NewClient(&ClientConfig{ STUNServerAddr: stunAddr, @@ -549,6 +572,450 @@ func TestServerVNet(t *testing.T) { // to the LAN router. assert.True(t, udpAddr.IP.Equal(net.IPv4(5, 6, 7, 8)), "should match") }) + + t.Run("AllocationLifecycle", func(t *testing.T) { + virtNet, err := buildVNet() + assert.NoError(t, err) + defer func() { + assert.NoError(t, virtNet.Close()) + }() + + // Inject an fake event handler so that we can track the succession of callbacks + events := make(chan allocation.EventHandlerArgs, 5) + defer close(events) + assert.Len(t, virtNet.server.allocationManagers, 1) + virtNet.server.allocationManagers[0].EventHandler = func(arg allocation.EventHandlerArgs) { + log.Info(fmt.Sprintf("%#v", arg)) + events <- arg + } + + lconn, err := virtNet.netL0.ListenPacket("udp4", "0.0.0.0:0") + assert.NoError(t, err, "should succeed") + defer func() { + assert.NoError(t, lconn.Close()) + }() + + log.Debug("creating a client.") + client, err := NewClient(&ClientConfig{ + TURNServerAddr: turnAddr, + Conn: lconn, + Username: "user", + Password: "pass", + Realm: "pion.ly", + LoggerFactory: loggerFactory, + }) + assert.NoError(t, err, "should succeed") + assert.NoError(t, client.Listen(), "should succeed") + defer client.Close() + + log.Debug("sending an allocate request.") + relayConn, err := client.Allocate() + assert.NoError(t, err, "should succeed") + + event, ok := expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnAuth, event.Type, "should receive an OnAuth event") + udpAddr, ok := event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + assert.Equal(t, "Allocate", event.Method) + assert.True(t, event.Verdict) + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnAllocationCreated, event.Type, "should receive an OnAllocationCreated event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + assert.Equal(t, 0, event.RequestedPort) + + relayNetAddr := relayConn.LocalAddr() + log.Debugf("relay-address: %s", relayNetAddr.String()) + relayAddr, ok := relayNetAddr.(*net.UDPAddr) + assert.True(t, ok) + udpAddr, ok = event.RelayAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, relayAddr.IP.Equal(udpAddr.IP)) + assert.Equal(t, relayAddr.Port, udpAddr.Port) + // The transport relay address should have IP address that was assigned to the server. + assert.True(t, udpAddr.IP.Equal(net.IPv4(1, 2, 3, 4)), "should match") + + log.Debug("Sending test packet") + peerAddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.5"), Port: 80} + _, err = relayConn.WriteTo([]byte("test"), peerAddr) + assert.NoError(t, err, "should succeed") + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnAuth, event.Type, "should receive an OnAuth event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + assert.Equal(t, "CreatePermission", event.Method) + assert.True(t, event.Verdict) + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnPermissionCreated, event.Type, "should receive an OnPermissionCreated event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + assert.True(t, net.ParseIP("1.2.3.5").Equal(event.PeerIP)) + + log.Debug("Forcing the creation of a channel") + _, err = relayConn.WriteTo([]byte("test"), peerAddr) + assert.NoError(t, err, "should succeed") + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnAuth, event.Type, "should receive an OnAuth event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + assert.Equal(t, "ChannelBind", event.Method) + assert.True(t, event.Verdict) + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnChannelCreated, event.Type, "should receive an OnChannelCreated event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + udpAddr, ok = event.RelayAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, relayAddr.IP.Equal(udpAddr.IP)) + assert.Equal(t, relayAddr.Port, udpAddr.Port) + + // obtain the channel id + a := virtNet.server.allocationManagers[0].GetAllocation(&allocation.FiveTuple{ + Protocol: allocation.UDP, + SrcAddr: event.SrcAddr, + DstAddr: event.DstAddr, + }) + assert.NotNil(t, a) + channelBind := a.GetChannelByAddr(peerAddr) + assert.NotNil(t, channelBind) + assert.Equal(t, channelBind.Number, proto.ChannelNumber(event.ChannelNumber)) + + log.Debug("Closing relay connection") + assert.NoError(t, relayConn.Close(), "relay conn close should succeed") + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnAuth, event.Type, "should receive an OnAuth event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + assert.Equal(t, "Refresh", event.Method) + assert.True(t, event.Verdict) + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnPermissionDeleted, event.Type, "should receive an OnPermissionDeleted event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + assert.True(t, net.ParseIP("1.2.3.5").Equal(event.PeerIP)) + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnChannelDeleted, event.Type, "should receive an OnChannelDeleted event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + udpAddr, ok = event.RelayAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, relayAddr.IP.Equal(udpAddr.IP)) + assert.Equal(t, relayAddr.Port, udpAddr.Port) + assert.Equal(t, channelBind.Number, proto.ChannelNumber(event.ChannelNumber)) + + event, ok = expectEvent(events) + assert.True(t, ok, "should receive an event") + assert.Equal(t, allocation.OnAllocationDeleted, event.Type, "should receive an OnAllocationDeleted event") + udpAddr, ok = event.SrcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = event.DstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP, event.Protocol) + assert.Equal(t, "user", event.Username) + assert.Equal(t, "pion.ly", event.Realm) + }) + + checkAllocation := func(srcAddr, dstAddr net.Addr, protocol, username, realm string) { + udpAddr, ok := srcAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("5.6.7.8").Equal(udpAddr.IP)) + udpAddr, ok = dstAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, net.ParseIP("1.2.3.4").Equal(udpAddr.IP)) + assert.Equal(t, allocation.UDP.String(), protocol) + assert.Equal(t, "user", username) + assert.Equal(t, "pion.ly", realm) + } + authEventHandler := func(expectedVerdict bool) (*EventHandlers, *atomic.Int32) { + counter := &atomic.Int32{} + + return &EventHandlers{ + OnAuth: func(srcAddr, dstAddr net.Addr, protocol, username, realm string, method string, verdict bool) { + checkAllocation(srcAddr, dstAddr, protocol, username, realm) + assert.True(t, method == "Allocate" || method == "Refresh") // close calls refresh with 0 lifetime + assert.Equal(t, expectedVerdict, verdict) + counter.Add(1) + }, + }, counter + } + + t.Run("AuthEventHandlerSuccess", func(t *testing.T) { + authCallback, counter := authEventHandler(true) + v, err := buildVNetWithServerEventHandlers(authCallback) + assert.NoError(t, err) + defer func() { + assert.NoError(t, v.Close()) + }() + + lconn, err := v.netL0.ListenPacket("udp4", "0.0.0.0:0") + assert.NoError(t, err, "should succeed") + defer func() { + assert.NoError(t, lconn.Close()) + }() + + log.Debug("creating a client.") + client, err := NewClient(&ClientConfig{ + TURNServerAddr: turnAddr, + Conn: lconn, + Username: "user", + Password: "pass", + Realm: "pion.ly", + LoggerFactory: loggerFactory, + }) + assert.NoError(t, err, "should succeed") + assert.NoError(t, client.Listen(), "should succeed") + defer client.Close() + + log.Debug("sending an allocate request.") + relayConn, err := client.Allocate() + assert.NoError(t, err, "should succeed") + + log.Debug("Closing relay connection") + assert.NoError(t, relayConn.Close(), "relay conn close should succeed") + + assert.Eventually(t, func() bool { return counter.Load() == 2 }, timeout, interval) + }) + + t.Run("AuthEventHandlerFailure", func(t *testing.T) { + authCallback, counter := authEventHandler(false) + v, err := buildVNetWithServerEventHandlers(authCallback) + assert.NoError(t, err) + defer func() { + assert.NoError(t, v.Close()) + }() + + lconn, err := v.netL0.ListenPacket("udp4", "0.0.0.0:0") + assert.NoError(t, err, "should succeed") + defer func() { + assert.NoError(t, lconn.Close()) + }() + + log.Debug("creating a client.") + client, err := NewClient(&ClientConfig{ + TURNServerAddr: turnAddr, + Conn: lconn, + Username: "user", + Password: "wrong-pass", + Realm: "pion.ly", + LoggerFactory: loggerFactory, + }) + assert.NoError(t, err, "should succeed") + assert.NoError(t, client.Listen(), "should succeed") + defer client.Close() + + log.Debug("sending an allocate request.") + _, err = client.Allocate() + assert.Error(t, err, "should not succeed") + + assert.Eventually(t, func() bool { return counter.Load() == 1 }, timeout, interval) + }) + + t.Run("AllocationEventHandlers", func(t *testing.T) { + peerAddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.5"), Port: 80} + relayAddrIP := net.ParseIP("1.2.3.4") + allocCreated, allocDeleted := &atomic.Int32{}, &atomic.Int32{} + permissionCreated, permissionDeleted := &atomic.Int32{}, &atomic.Int32{} + channelCreated, channelDeleted := &atomic.Int32{}, &atomic.Int32{} + allocCallback := &EventHandlers{ + OnAllocationCreated: func(srcAddr, dstAddr net.Addr, protocol, username, realm string, + relayAddr net.Addr, requestedPort int, + ) { + checkAllocation(srcAddr, dstAddr, protocol, username, realm) + assert.Equal(t, 0, requestedPort) + udpAddr, ok := relayAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, relayAddrIP.Equal(udpAddr.IP)) + allocCreated.Add(1) + }, + OnAllocationDeleted: func(srcAddr, dstAddr net.Addr, protocol, username, realm string) { + checkAllocation(srcAddr, dstAddr, protocol, username, realm) + allocDeleted.Add(1) + }, + OnPermissionCreated: func(srcAddr, dstAddr net.Addr, protocol, username, realm string, + relayAddr net.Addr, peer net.IP, + ) { + checkAllocation(srcAddr, dstAddr, protocol, username, realm) + udpAddr, ok := relayAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, relayAddrIP.Equal(udpAddr.IP)) + assert.True(t, net.ParseIP("1.2.3.5").Equal(peer)) + permissionCreated.Add(1) + }, + OnPermissionDeleted: func(srcAddr, dstAddr net.Addr, protocol, username, realm string, + relayAddr net.Addr, peer net.IP, + ) { + checkAllocation(srcAddr, dstAddr, protocol, username, realm) + udpAddr, ok := relayAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, relayAddrIP.Equal(udpAddr.IP)) + assert.True(t, net.ParseIP("1.2.3.5").Equal(peer)) + permissionDeleted.Add(1) + }, + OnChannelCreated: func(srcAddr, dstAddr net.Addr, protocol, username, realm string, + relayAddr, peer net.Addr, channelNumber uint16, + ) { + checkAllocation(srcAddr, dstAddr, protocol, username, realm) + addr, ok := peer.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, addr.IP.Equal(peerAddr.IP)) + assert.Equal(t, peerAddr.Port, addr.Port) + udpAddr, ok := relayAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, relayAddrIP.Equal(udpAddr.IP)) + assert.NotZero(t, channelNumber) + channelCreated.Add(1) + }, + OnChannelDeleted: func(srcAddr, dstAddr net.Addr, protocol, username, realm string, + relayAddr, peer net.Addr, channelNumber uint16, + ) { + checkAllocation(srcAddr, dstAddr, protocol, username, realm) + addr, ok := peer.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, addr.IP.Equal(peerAddr.IP)) + assert.Equal(t, peerAddr.Port, addr.Port) + udpAddr, ok := relayAddr.(*net.UDPAddr) + assert.True(t, ok) + assert.True(t, relayAddrIP.Equal(udpAddr.IP)) + assert.NotZero(t, channelNumber) + channelDeleted.Add(1) + }, + } + + v, err := buildVNetWithServerEventHandlers(allocCallback) + assert.NoError(t, err) + defer func() { + assert.NoError(t, v.Close()) + }() + + lconn, err := v.netL0.ListenPacket("udp4", "0.0.0.0:0") + assert.NoError(t, err, "should succeed") + defer func() { + assert.NoError(t, lconn.Close()) + }() + + log.Debug("creating a client.") + client, err := NewClient(&ClientConfig{ + TURNServerAddr: turnAddr, + Conn: lconn, + Username: "user", + Password: "pass", + Realm: "pion.ly", + LoggerFactory: loggerFactory, + }) + assert.NoError(t, err, "should succeed") + assert.NoError(t, client.Listen(), "should succeed") + defer client.Close() + + log.Debug("sending an allocate request.") + relayConn, err := client.Allocate() + assert.NoError(t, err, "should succeed") + + assert.Eventually(t, func() bool { return allocCreated.Load() == 1 }, timeout, interval) + + log.Debug("Sending test packet") + _, err = relayConn.WriteTo([]byte("test"), peerAddr) + assert.NoError(t, err, "should succeed") + + assert.Eventually(t, func() bool { return permissionCreated.Load() == 1 }, timeout, interval) + + log.Debug("Forcing the creation of a channel") + _, err = relayConn.WriteTo([]byte("test"), peerAddr) + assert.NoError(t, err, "should succeed") + + assert.Eventually(t, func() bool { return channelCreated.Load() == 1 }, timeout, interval) + + log.Debug("Closing relay connection") + assert.NoError(t, relayConn.Close(), "relay conn close should succeed") + + assert.Eventually(t, func() bool { return permissionDeleted.Load() == 1 }, timeout, interval) + assert.Eventually(t, func() bool { return allocCreated.Load() == 1 }, timeout, interval) + assert.Eventually(t, func() bool { return allocDeleted.Load() == 1 }, timeout, interval) + assert.Eventually(t, func() bool { return permissionCreated.Load() == 1 }, timeout, interval) + assert.Eventually(t, func() bool { return permissionDeleted.Load() == 1 }, timeout, interval) + assert.Eventually(t, func() bool { return channelCreated.Load() == 1 }, timeout, interval) + assert.Eventually(t, func() bool { return channelDeleted.Load() == 1 }, timeout, interval) + }) } func TestConsumeSingleTURNFrame(t *testing.T) {