Skip to content

Commit

Permalink
Add a quota handler callback
Browse files Browse the repository at this point in the history
  • Loading branch information
rg0now committed Jan 15, 2025
1 parent a4ddb68 commit 28967fa
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 1 deletion.
3 changes: 2 additions & 1 deletion internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ type Request struct {
NonceHash *NonceHash

// User Configuration
AuthHandler func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool)
AuthHandler func(username string, realm string, srcAddr net.Addr) (key []byte, ok bool)
QuotaHandler func(username string, realm string, srcAddr net.Addr) (ok bool)

Log logging.LeveledLogger
Realm string
Expand Down
7 changes: 7 additions & 0 deletions internal/server/turn.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,13 @@ func handleAllocateRequest(req Request, stunMsg *stun.Message) error { //nolint:
// server is free to define this allocation quota any way it wishes,
// but SHOULD define it based on the username used to authenticate
// the request, and not on the client's transport address.
if req.QuotaHandler != nil && !req.QuotaHandler(usernameAttr.String(), realmAttr.String(), req.SrcAddr) {
quotaReachedMsg := buildMsg(stunMsg.TransactionID,
stun.NewType(stun.MethodAllocate, stun.ClassErrorResponse),
&stun.ErrorCodeAttribute{Code: stun.CodeAllocQuotaReached})

return buildAndSend(req.Conn, req.SrcAddr, quotaReachedMsg...)
}

// 8. Also at any point, the server MAY choose to reject the request
// with a 300 (Try Alternate) error if it wishes to redirect the
Expand Down
3 changes: 3 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
type Server struct {
log logging.LeveledLogger
authHandler AuthHandler
quotaHandler QuotaHandler
realm string
channelBindTimeout time.Duration
nonceHash *server.NonceHash
Expand Down Expand Up @@ -59,6 +60,7 @@ func NewServer(config ServerConfig) (*Server, error) { //nolint:gocognit,cyclop
server := &Server{
log: loggerFactory.NewLogger("turn"),
authHandler: config.AuthHandler,
quotaHandler: config.QuotaHandler,
realm: config.Realm,
channelBindTimeout: config.ChannelBindTimeout,
packetConnConfigs: config.PacketConnConfigs,
Expand Down Expand Up @@ -231,6 +233,7 @@ func (s *Server) readLoop(conn net.PacketConn, allocationManager *allocation.Man
Buff: buf[:n],
Log: s.log,
AuthHandler: s.authHandler,
QuotaHandler: s.quotaHandler,
Realm: s.realm,
AllocationManager: allocationManager,
ChannelBindTimeout: s.channelBindTimeout,
Expand Down
9 changes: 9 additions & 0 deletions server_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,11 @@ func genericEventHandler(handlers EventHandlers) allocation.EventHandler { //nol
}
}

// QuotaHandler is a callback allows allocations to be rejected when a per-user quota is
// exceeded. If the callback returns true the allocation request is accepted, otherwise it is
// rejected and a 486 (Allocation Quota Reached) error is returned to the user.
type QuotaHandler func(username, realm string, srcAddr net.Addr) (ok bool)

// ServerConfig configures the Pion TURN Server.
type ServerConfig struct {
// PacketConnConfigs and ListenerConfigs are a list of all the turn listeners
Expand All @@ -207,6 +212,10 @@ type ServerConfig struct {
// allowing users to customize Pion TURN with custom behavior
AuthHandler AuthHandler

// AuthHandler is a callback used to handle incoming auth requests, allowing users to
// customize Pion TURN with custom behavior
QuotaHandler QuotaHandler

// EventHandlers is a set of callbacks for tracking allocation lifecycle.
EventHandlers EventHandlers

Expand Down
52 changes: 52 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,58 @@ func TestSTUNOnly(t *testing.T) {
assert.Equal(t, err.Error(), "Allocate error response (error 400: )")
}

func TestQuotaReached(t *testing.T) {
serverAddr, err := net.ResolveUDPAddr("udp4", "0.0.0.0:3478")
assert.NoError(t, err)

serverConn, err := net.ListenPacket(serverAddr.Network(), serverAddr.String())
assert.NoError(t, err)

defer serverConn.Close() //nolint:errcheck

credMap := map[string][]byte{"user": GenerateAuthKey("user", "pion.ly", "pass")}
server, err := NewServer(ServerConfig{
AuthHandler: func(username, _ string, _ net.Addr) (key []byte, ok bool) {
if pw, ok := credMap[username]; ok {
return pw, true
}
return nil, false //nolint:nlreturn
},
QuotaHandler: func(_, _ string, _ net.Addr) (ok bool) { return false },
Realm: "pion.ly",
PacketConnConfigs: []PacketConnConfig{{
PacketConn: serverConn,
RelayAddressGenerator: &RelayAddressGeneratorStatic{
RelayAddress: net.ParseIP("127.0.0.1"),
Address: "0.0.0.0",
},
}},
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
assert.NoError(t, err)

defer server.Close() //nolint:errcheck

conn, err := net.ListenPacket("udp4", "0.0.0.0:0")
assert.NoError(t, err)

client, err := NewClient(&ClientConfig{
Conn: conn,
STUNServerAddr: "127.0.0.1:3478",
TURNServerAddr: "127.0.0.1:3478",
Username: "user",
Password: "pass",
Realm: "pion.ly",
LoggerFactory: logging.NewDefaultLoggerFactory(),
})
assert.NoError(t, err)
assert.NoError(t, client.Listen())
defer client.Close()

_, err = client.Allocate()
assert.Equal(t, err.Error(), "Allocate error response (error 486: )")
}

func RunBenchmarkServer(b *testing.B, clientNum int) { //nolint:cyclop
b.Helper()

Expand Down

0 comments on commit 28967fa

Please sign in to comment.