diff --git a/docs/env_config.md b/docs/env_config.md index 04032a3..9dd99fa 100644 --- a/docs/env_config.md +++ b/docs/env_config.md @@ -15,7 +15,7 @@ RTCD_RTC_ICEPORTUDP Integer RTCD_RTC_ICEADDRESSTCP String RTCD_RTC_ICEPORTTCP Integer RTCD_RTC_ICEHOSTOVERRIDE String -RTCD_RTC_ICEHOSTPORTOVERRIDE Integer +RTCD_RTC_ICEHOSTPORTOVERRIDE ICEHostPortOverride RTCD_RTC_ICESERVERS Comma-separated list of RTCD_RTC_TURNCONFIG_STATICAUTHSECRET String RTCD_RTC_TURNCONFIG_CREDENTIALSEXPIRATIONMINUTES Integer diff --git a/service/rtc/config.go b/service/rtc/config.go index 599a4e6..63c9866 100644 --- a/service/rtc/config.go +++ b/service/rtc/config.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "net" + "strconv" "strings" ) @@ -24,7 +25,7 @@ type ServerConfig struct { ICEHostOverride string `toml:"ice_host_override"` // ICEHostPortOverride optionally specifies a port number to override the one // used to listen on when sharing host candidates. - ICEHostPortOverride int `toml:"ice_host_port_override"` + ICEHostPortOverride ICEHostPortOverride `toml:"ice_host_port_override"` // A list of ICE server (STUN/TURN) configurations to use. ICEServers ICEServers `toml:"ice_servers"` TURNConfig TURNConfig `toml:"turn"` @@ -57,8 +58,8 @@ func (c ServerConfig) IsValid() error { return fmt.Errorf("invalid TURNConfig: %w", err) } - if c.ICEHostPortOverride != 0 && (c.ICEHostPortOverride < 80 || c.ICEHostPortOverride > 49151) { - return fmt.Errorf("invalid ICEHostPortOverride value: %d is not in allowed range [80, 49151]", c.ICEHostPortOverride) + if err := c.ICEHostPortOverride.IsValid(); err != nil { + return fmt.Errorf("invalid ICEHostPortOverride value: %w", err) } return nil @@ -156,8 +157,6 @@ func (s ICEServers) getSTUN() string { } func (s *ICEServers) Decode(value string) error { - fmt.Println(value) - var urls []string err := json.Unmarshal([]byte(value), &urls) if err == nil { @@ -206,3 +205,90 @@ func (s *ICEServers) UnmarshalTOML(data interface{}) error { return nil } + +type ICEHostPortOverride string + +func (s *ICEHostPortOverride) SinglePort() int { + if s == nil { + return 0 + } + p, _ := strconv.Atoi(string(*s)) + return p +} + +func (s *ICEHostPortOverride) ParseMap() (map[string]int, error) { + if s == nil { + return nil, fmt.Errorf("should not be nil") + } + + if *s == "" { + return nil, nil + } + + pairs := strings.Split(string(*s), ",") + + m := make(map[string]int, len(pairs)) + ports := make(map[int]bool, len(pairs)) + + for _, p := range pairs { + pair := strings.Split(p, "/") + if len(pair) != 2 { + return nil, fmt.Errorf("invalid map pairing syntax") + } + + port, err := strconv.Atoi(pair[1]) + if err != nil { + return nil, fmt.Errorf("failed to parse port number: %w", err) + } + + if _, ok := m[pair[0]]; ok { + return nil, fmt.Errorf("duplicate mapping found for %s", pair[0]) + } + + if ports[port] { + return nil, fmt.Errorf("duplicate port found for %d", port) + } + + m[pair[0]] = port + ports[port] = true + } + + return m, nil +} + +func (s *ICEHostPortOverride) IsValid() error { + if s == nil { + return fmt.Errorf("should not be nil") + } + + if *s == "" { + return nil + } + + if port := s.SinglePort(); port != 0 { + if port < 80 || port > 49151 { + return fmt.Errorf("%d is not in allowed range [80, 49151]", port) + } + return nil + } + + if _, err := s.ParseMap(); err != nil { + return fmt.Errorf("failed to parse mapping: %w", err) + } + + return nil +} + +func (s *ICEHostPortOverride) UnmarshalTOML(data interface{}) error { + switch t := data.(type) { + case string: + *s = ICEHostPortOverride(data.(string)) + return nil + case int, int32, int64: + *s = ICEHostPortOverride(fmt.Sprintf("%v", data)) + default: + return fmt.Errorf("unknown type %T", t) + } + + return nil +} diff --git a/service/rtc/config_test.go b/service/rtc/config_test.go index d8722e4..2ed8d5b 100644 --- a/service/rtc/config_test.go +++ b/service/rtc/config_test.go @@ -70,17 +70,29 @@ func TestServerConfigIsValid(t *testing.T) { }) t.Run("invalid ICEHostPortOverride", func(t *testing.T) { - var cfg ServerConfig - cfg.ICEPortUDP = 8443 - cfg.ICEPortTCP = 8443 - cfg.ICEHostPortOverride = 45 - err := cfg.IsValid() - require.Error(t, err) - require.Equal(t, "invalid ICEHostPortOverride value: 45 is not in allowed range [80, 49151]", err.Error()) - cfg.ICEHostPortOverride = 65000 - err = cfg.IsValid() - require.Error(t, err) - require.Equal(t, "invalid ICEHostPortOverride value: 65000 is not in allowed range [80, 49151]", err.Error()) + t.Run("single port", func(t *testing.T) { + var cfg ServerConfig + cfg.ICEPortUDP = 8443 + cfg.ICEPortTCP = 8443 + cfg.ICEHostPortOverride = "45" + err := cfg.IsValid() + require.Error(t, err) + require.Equal(t, "invalid ICEHostPortOverride value: 45 is not in allowed range [80, 49151]", err.Error()) + cfg.ICEHostPortOverride = "65000" + err = cfg.IsValid() + require.Error(t, err) + require.Equal(t, "invalid ICEHostPortOverride value: 65000 is not in allowed range [80, 49151]", err.Error()) + }) + + t.Run("mapping", func(t *testing.T) { + var cfg ServerConfig + cfg.ICEPortUDP = 8443 + cfg.ICEPortTCP = 8443 + cfg.ICEHostPortOverride = "127.0.0.1,8443" + err := cfg.IsValid() + require.Error(t, err) + require.Equal(t, "invalid ICEHostPortOverride value: failed to parse mapping: invalid map pairing syntax", err.Error()) + }) }) t.Run("valid", func(t *testing.T) { @@ -262,3 +274,44 @@ func TestICEServerConfigIsValid(t *testing.T) { require.NoError(t, err) }) } + +func TestICEHostPortOverrideParseMap(t *testing.T) { + t.Run("nil", func(t *testing.T) { + var override *ICEHostPortOverride + m, err := override.ParseMap() + require.EqualError(t, err, "should not be nil") + require.Nil(t, m) + }) + + t.Run("empty", func(t *testing.T) { + var override ICEHostPortOverride + m, err := override.ParseMap() + require.NoError(t, err) + require.Nil(t, m) + }) + + t.Run("duplicate addresses", func(t *testing.T) { + override := ICEHostPortOverride("127.0.0.1/8444,127.0.0.1/8445") + m, err := override.ParseMap() + require.EqualError(t, err, "duplicate mapping found for 127.0.0.1") + require.Nil(t, m) + }) + + t.Run("duplicate ports", func(t *testing.T) { + override := ICEHostPortOverride("127.0.0.1/8444,127.0.0.2/8444") + m, err := override.ParseMap() + require.EqualError(t, err, "duplicate port found for 8444") + require.Nil(t, m) + }) + + t.Run("valid mapping", func(t *testing.T) { + override := ICEHostPortOverride("127.0.0.1/8443,127.0.0.2/8445,127.0.0.3/8444") + m, err := override.ParseMap() + require.NoError(t, err) + require.Equal(t, map[string]int{ + "127.0.0.1": 8443, + "127.0.0.2": 8445, + "127.0.0.3": 8444, + }, m) + }) +} diff --git a/service/rtc/server.go b/service/rtc/server.go index 171e7d0..cdb9e26 100644 --- a/service/rtc/server.go +++ b/service/rtc/server.go @@ -105,6 +105,19 @@ func (s *Server) Start() error { s.log.Debug("rtc: found local IPs", mlog.Any("ips", s.localIPs)) + if m, _ := s.cfg.ICEHostPortOverride.ParseMap(); len(m) > 0 { + s.log.Debug("rtc: found ice host port override mappings", mlog.Any("mappings", s.cfg.ICEHostPortOverride)) + + for _, ip := range localIPs { + if port, ok := m[ip.String()]; ok { + s.log.Debug("rtc: found port override for local address", mlog.String("address", ip.String()), mlog.Int("port", port)) + s.cfg.ICEHostPortOverride = ICEHostPortOverride(fmt.Sprintf("%d", port)) + // NOTE: currently not supporting multiple ip/port mappings for the same rtcd instance. + break + } + } + } + // Populate public IP addresses map if override is not set and STUN is provided. if s.cfg.ICEHostOverride == "" && len(s.cfg.ICEServers) > 0 { for _, ip := range localIPs { diff --git a/service/rtc/sfu.go b/service/rtc/sfu.go index 91a5de7..afe3223 100644 --- a/service/rtc/sfu.go +++ b/service/rtc/sfu.go @@ -250,14 +250,17 @@ func (s *Server) InitSession(cfg SessionConfig, closeCb func() error) error { return } - if s.cfg.ICEHostPortOverride != 0 && candidate.Typ == webrtc.ICECandidateTypeHost { - s.log.Debug("overriding host candidate port", - mlog.String("sessionID", cfg.SessionID), - mlog.Uint("port", candidate.Port), - mlog.Int("override", s.cfg.ICEHostPortOverride), - mlog.String("addr", candidate.Address), - mlog.Int("protocol", candidate.Protocol)) - candidate.Port = uint16(s.cfg.ICEHostPortOverride) + if port := s.cfg.ICEHostPortOverride.SinglePort(); port != 0 && candidate.Typ == webrtc.ICECandidateTypeHost { + m := getExternalAddrMapFromHostOverride(s.cfg.ICEHostOverride) + if m[candidate.Address] { + s.log.Debug("overriding host candidate port", + mlog.String("sessionID", cfg.SessionID), + mlog.Uint("port", candidate.Port), + mlog.Int("override", port), + mlog.String("addr", candidate.Address), + mlog.Int("protocol", candidate.Protocol)) + candidate.Port = uint16(port) + } } msg, err := newICEMessage(us, candidate) diff --git a/service/rtc/utils.go b/service/rtc/utils.go index ea4916d..2c2ceeb 100644 --- a/service/rtc/utils.go +++ b/service/rtc/utils.go @@ -127,3 +127,19 @@ func generateAddrsPairs(localIPs []netip.Addr, publicAddrsMap map[netip.Addr]str return pairs, nil } + +func getExternalAddrMapFromHostOverride(override string) map[string]bool { + if override == "" { + return nil + } + + pairs := strings.Split(override, ",") + m := make(map[string]bool, len(pairs)) + + for _, p := range pairs { + pair := strings.Split(p, "/") + m[pair[0]] = true + } + + return m +} diff --git a/service/rtc/utils_test.go b/service/rtc/utils_test.go index 64fa274..2824fdf 100644 --- a/service/rtc/utils_test.go +++ b/service/rtc/utils_test.go @@ -154,3 +154,26 @@ func TestIsValidTrackID(t *testing.T) { }) } } + +func TestGetExternalAddrMapFromHostOverride(t *testing.T) { + t.Run("empty", func(t *testing.T) { + m := getExternalAddrMapFromHostOverride("") + require.Empty(t, m) + }) + + t.Run("single host", func(t *testing.T) { + m := getExternalAddrMapFromHostOverride("10.0.0.1") + require.Equal(t, map[string]bool{ + "10.0.0.1": true, + }, m) + }) + + t.Run("mapping", func(t *testing.T) { + m := getExternalAddrMapFromHostOverride("10.0.0.1/127.0.0.1,10.0.0.3/127.0.0.2,10.0.0.2/127.0.0.3") + require.Equal(t, map[string]bool{ + "10.0.0.1": true, + "10.0.0.2": true, + "10.0.0.3": true, + }, m) + }) +}