diff --git a/config/crowdsec-blocklist-mirror.yaml b/config/crowdsec-blocklist-mirror.yaml index b4d346c..752ae3a 100644 --- a/config/crowdsec-blocklist-mirror.yaml +++ b/config/crowdsec-blocklist-mirror.yaml @@ -20,6 +20,11 @@ blocklists: - ::1 listen_uri: 127.0.0.1:41412 +# listen_socket: /var/run/crowdsec-blocklist-mirror.sock +# trusted_proxies: +# - 127.0.0.1 +# - 127.0.0.1/32 +# trusted_header: X-Forwarded-For tls: cert_file: key_file: diff --git a/pkg/cfg/config.go b/pkg/cfg/config.go index c7d955a..82ce2cc 100644 --- a/pkg/cfg/config.go +++ b/pkg/cfg/config.go @@ -4,10 +4,11 @@ import ( "errors" "fmt" "io" + "net" "os" "strings" - "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" "golang.org/x/exp/slices" "gopkg.in/yaml.v3" @@ -56,14 +57,18 @@ type TLSConfig struct { } type Config struct { - CrowdsecConfig CrowdsecConfig `yaml:"crowdsec_config"` - Blocklists []*BlockListConfig `yaml:"blocklists"` - ListenURI string `yaml:"listen_uri"` - TLS TLSConfig `yaml:"tls"` - Metrics MetricConfig `yaml:"metrics"` - Logging LoggingConfig `yaml:",inline"` - ConfigVersion string `yaml:"config_version"` - EnableAccessLogs bool `yaml:"enable_access_logs"` + CrowdsecConfig CrowdsecConfig `yaml:"crowdsec_config"` + Blocklists []*BlockListConfig `yaml:"blocklists"` + ListenURI string `yaml:"listen_uri"` + ListenSocket string `yaml:"listen_socket"` + TrustedProxies []string `yaml:"trusted_proxies"` + ParsedTrustedProxies []*net.IPNet `yaml:"-"` + TrustedHeader string `yaml:"trusted_header"` + TLS TLSConfig `yaml:"tls"` + Metrics MetricConfig `yaml:"metrics"` + Logging LoggingConfig `yaml:",inline"` + ConfigVersion string `yaml:"config_version"` + EnableAccessLogs bool `yaml:"enable_access_logs"` } func (cfg *Config) ValidateAndSetDefaults() error { @@ -80,19 +85,19 @@ func (cfg *Config) ValidateAndSetDefaults() error { } if cfg.CrowdsecConfig.UpdateFrequency == "" { - logrus.Warn("update_frequency is not provided") + log.Warn("update_frequency is not provided") cfg.CrowdsecConfig.UpdateFrequency = "10s" } if cfg.ConfigVersion == "" { - logrus.Warn("config version is not provided; assuming v1.0") + log.Warn("config version is not provided; assuming v1.0") cfg.ConfigVersion = "v1.0" } - if cfg.ListenURI == "" { - logrus.Warn("listen_uri is not provided ; assuming 127.0.0.1:41412") + if cfg.ListenURI == "" && cfg.ListenSocket == "" { + log.Warn("listen_uri is not provided ; assuming 127.0.0.1:41412") cfg.ListenURI = "127.0.0.1:41412" } @@ -125,9 +130,50 @@ func (cfg *Config) ValidateAndSetDefaults() error { } } + cfg.ParsedTrustedProxies = make([]*net.IPNet, 0, len(cfg.TrustedProxies)) + for _, ip := range cfg.TrustedProxies { + if !strings.Contains(ip, "/") { + log.Debug("no CIDR provided attempting to add /32 or /128; ", ip) + parsedIP := parseIP(ip) + if parsedIP == nil { + return fmt.Errorf("invalid IP address: %s", ip) + } + switch len(parsedIP) { + case net.IPv4len: + ip += "/32" + case net.IPv6len: + ip += "/128" + } + log.Debug("added CIDR to IP: ", ip) + } + _, ipNet, err := net.ParseCIDR(ip) + if err != nil { + return fmt.Errorf("invalid IP address: %s", ip) + } + log.Info("adding trusted proxy: ", ip) + cfg.ParsedTrustedProxies = append(cfg.ParsedTrustedProxies, ipNet) + } + + if cfg.TrustedHeader == "" { + log.Info("trusted_header is not provided; assuming X-Forwarded-For") + cfg.TrustedHeader = "X-Forwarded-For" + } + + if len(cfg.ParsedTrustedProxies) == 0 { + log.Info("no trusted proxies provided so trusted_header is ignored") + } + return nil } +func parseIP(ip string) net.IP { + parsedIP := net.ParseIP(ip) + if ipv4 := parsedIP.To4(); ipv4 != nil { + return ipv4 + } + return parsedIP +} + func MergedConfig(configPath string) ([]byte, error) { patcher := yamlpatch.NewPatcher(configPath, ".local") diff --git a/pkg/server/server.go b/pkg/server/server.go index 1c5e808..a30c78d 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -29,7 +29,7 @@ func RunServer(ctx context.Context, g *errgroup.Group, config cfg.Config) error return err } - http.HandleFunc(blockListCFG.Endpoint, f) + http.HandleFunc(blockListCFG.Endpoint, globalMiddleware(config, f)) log.Infof("serving blocklist in format %s at endpoint %s", blockListCFG.Format, blockListCFG.Endpoint) } @@ -51,16 +51,37 @@ func RunServer(ctx context.Context, g *errgroup.Group, config cfg.Config) error } server := &http.Server{ - Addr: config.ListenURI, Handler: logHandler, } g.Go(func() error { - err := listenAndServe(server, config) - if err != nil && !errors.Is(err, http.ErrServerClosed) { - return err + if config.ListenSocket != "" { + log.Info("listening on unix socket: ", config.ListenSocket) + listener, err := net.Listen("unix", config.ListenSocket) + if err != nil { + return err + } + defer listener.Close() + if err := listenAndServe(server, listener, config); !errors.Is(err, http.ErrServerClosed) { + return err + } } + return nil + }) + g.Go(func() error { + if config.ListenURI != "" { + log.Info("listening on tcp server: ", config.ListenURI) + listener, err := net.Listen("tcp", config.ListenURI) + if err != nil { + return err + } + defer listener.Close() + + if err := listenAndServe(server, listener, config); !errors.Is(err, http.ErrServerClosed) { + return err + } + } return nil }) @@ -73,15 +94,57 @@ func RunServer(ctx context.Context, g *errgroup.Group, config cfg.Config) error return nil } -func listenAndServe(server *http.Server, config cfg.Config) error { - if config.TLS.CertFile != "" && config.TLS.KeyFile != "" { - log.Infof("Starting server with TLS at %s", config.ListenURI) - return server.ListenAndServeTLS(config.TLS.CertFile, config.TLS.KeyFile) +/* +Global middlewares are middlewares that are applied to all routes and are not specific to a blocklist. +*/ +func globalMiddleware(config cfg.Config, next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + //Parsed unix socket request + if r.RemoteAddr == "@" { + r.RemoteAddr = "127.0.0.1:65535" + } + //Trusted proxies + header := r.Header.Get(config.TrustedHeader) + // If there is no header then we don't need to do anything + if header != "" { + headerSplit := strings.Split(header, ",") + ip, _, err := net.SplitHostPort(strings.TrimSpace(r.RemoteAddr)) + if err != nil { + log.Errorf("error while spliting hostport for %s: %v", r.RemoteAddr, err) + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + //Loop over the parsed trusted proxies + for _, trustedProxy := range config.ParsedTrustedProxies { + //check if the remote address is in the trusted proxies + if trustedProxy.Contains(net.ParseIP(ip)) { + // Loop over the header values in reverse order + for i := len(headerSplit) - 1; i >= 0; i-- { + ipStr := strings.TrimSpace(headerSplit[i]) + ip := net.ParseIP(ipStr) + if ip == nil { + break + } + // If the IP is not in the trusted proxies, set the remote address to the IP + if (i == 0) || (!trustedProxy.Contains(ip)) { + r.RemoteAddr = ipStr + break + } + } + } + } + } + + next.ServeHTTP(w, r) } +} - log.Infof("Starting server at %s", config.ListenURI) +func listenAndServe(server *http.Server, listener net.Listener, config cfg.Config) error { + if config.TLS.CertFile != "" && config.TLS.KeyFile != "" { + return server.ServeTLS(listener, config.TLS.CertFile, config.TLS.KeyFile) + } - return server.ListenAndServe() + return server.Serve(listener) } var RouteHits = prometheus.NewCounterVec( @@ -132,7 +195,7 @@ func toValidCIDR(ip string) string { } func getTrustedIPs(ips []string) ([]net.IPNet, error) { - trustedIPs := make([]net.IPNet, 0) + trustedIPs := make([]net.IPNet, 0, len(ips)) for _, ip := range ips { cidr := toValidCIDR(ip) @@ -183,36 +246,37 @@ func decisionMiddleware(next http.HandlerFunc) func(w http.ResponseWriter, r *ht func authMiddleware(blockListCfg *cfg.BlockListConfig, next http.HandlerFunc) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - log.Errorf("error while spliting hostport for %s: %v", r.RemoteAddr, err) - http.Error(w, "internal error", http.StatusInternalServerError) - - return - } - - trustedIPs, err := getTrustedIPs(blockListCfg.Authentication.TrustedIPs) - if err != nil { - log.Errorf("error while parsing trusted IPs: %v", err) - http.Error(w, "internal error", http.StatusInternalServerError) + authType := strings.ToLower(blockListCfg.Authentication.Type) + + // If auth != none then we implement checks if not bypass them to the next handler + if authType != "none" { + ip, _, err := net.SplitHostPort(r.RemoteAddr) + // If we can't parse the IP, we use the remote address as is as it most likely been set by the trusted proxies middleware + if err != nil { + ip = r.RemoteAddr + } - return - } + trustedIPs, err := getTrustedIPs(blockListCfg.Authentication.TrustedIPs) + if err != nil { + log.Errorf("error while parsing trusted IPs: %v", err) + http.Error(w, "internal error", http.StatusInternalServerError) - switch strings.ToLower(blockListCfg.Authentication.Type) { - case "ip_based": - if !networksContainIP(trustedIPs, ip) { - http.Error(w, "access denied", http.StatusForbidden) return } - case "basic": - if !satisfiesBasicAuth(r, blockListCfg.Authentication.User, blockListCfg.Authentication.Password) { - http.Error(w, "access denied", http.StatusForbidden) - return + + switch authType { + case "ip_based": + if !networksContainIP(trustedIPs, ip) { + http.Error(w, "access denied", http.StatusForbidden) + return + } + case "basic": + if !satisfiesBasicAuth(r, blockListCfg.Authentication.User, blockListCfg.Authentication.Password) { + http.Error(w, "access denied", http.StatusForbidden) + return + } } - case "", "none": } - next.ServeHTTP(w, r) } } diff --git a/test/bouncer/test_tls.py b/test/bouncer/test_tls.py index 1493dd4..8786b8f 100644 --- a/test/bouncer/test_tls.py +++ b/test/bouncer/test_tls.py @@ -42,7 +42,7 @@ def test_tls_server(crowdsec, certs_dir, api_key_factory, bouncer, bm_cfg_factor with bouncer(cfg) as bm: bm.wait_for_lines_fnmatch([ "*Using API key auth*", - "*Starting server at 127.0.0.1:*" + "*listening on tcp server: 127.0.0.1:*" ]) @@ -94,7 +94,7 @@ def test_tls_mutual(crowdsec, certs_dir, bouncer, bm_cfg_factory, bouncer_under_ "*Starting crowdsec-blocklist-mirror*", "*Using CA cert*", "*Using cert auth with cert * and key *", - "*Starting server at 127.0.0.1:*" + "*listening on tcp server: 127.0.0.1:*" ]) # check that the bouncer is registered