From 29cfc7ae2a0bbd5ec3205eae3f6f810519787f26 Mon Sep 17 00:00:00 2001 From: Stanislav Chzhen Date: Tue, 20 Jun 2023 20:42:59 +0300 Subject: [PATCH] all: imp err handling --- internal/filtering/blocked.go | 15 +++++++++++---- internal/filtering/filtering.go | 12 +++--------- internal/home/clients.go | 28 ++++++++++++++++++++++------ internal/home/clients_test.go | 13 +++++++------ internal/home/home.go | 6 +++++- 5 files changed, 48 insertions(+), 26 deletions(-) diff --git a/internal/filtering/blocked.go b/internal/filtering/blocked.go index 88686af106f..1d3be758ecd 100644 --- a/internal/filtering/blocked.go +++ b/internal/filtering/blocked.go @@ -2,6 +2,7 @@ package filtering import ( "encoding/json" + "fmt" "net/http" "time" @@ -63,11 +64,17 @@ func (s *BlockedServices) Clone() (c *BlockedServices) { } } -// BlockedSvcKnown returns true if a blocked service ID is known. -func BlockedSvcKnown(s string) (ok bool) { - _, ok = serviceRules[s] +// Validate returns an error if blocked services contain unknown service ID. s +// must not be nil. +func (s *BlockedServices) Validate() (err error) { + for _, id := range s.IDs { + _, ok := serviceRules[id] + if !ok { + return fmt.Errorf("unknown blocked-service %q", id) + } + } - return ok + return nil } // ApplyBlockedServices - set blocked services settings for this DNS request diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index ea6d4bfbde1..7cad6c9939b 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -988,17 +988,11 @@ func New(c *Config, blockFilters []Filter) (d *DNSFilter, err error) { } if d.BlockedServices != nil { - bsvcs := []string{} - for _, s := range d.BlockedServices.IDs { - if !BlockedSvcKnown(s) { - log.Debug("skipping unknown blocked-service %q", s) + err = d.BlockedServices.Validate() - continue - } - - bsvcs = append(bsvcs, s) + if err != nil { + return nil, fmt.Errorf("filtering: %w", err) } - d.BlockedServices.IDs = bsvcs } if blockFilters != nil { diff --git a/internal/home/clients.go b/internal/home/clients.go index 3bf5db83d16..acac0a1850a 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -77,7 +77,7 @@ func (clients *clientsContainer) Init( etcHosts *aghnet.HostsContainer, arpdb aghnet.ARPDB, filteringConf *filtering.Config, -) { +) (err error) { if clients.list != nil { log.Fatal("clients.list != nil") } @@ -91,13 +91,17 @@ func (clients *clientsContainer) Init( clients.dhcpServer = dhcpServer clients.etcHosts = etcHosts clients.arpdb = arpdb - clients.addFromConfig(objects, filteringConf) + err = clients.addFromConfig(objects, filteringConf) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime) if clients.testing { - return + return nil } clients.updateFromDHCP(true) @@ -108,6 +112,8 @@ func (clients *clientsContainer) Init( if clients.etcHosts != nil { go clients.handleHostsUpdates() } + + return nil } func (clients *clientsContainer) handleHostsUpdates() { @@ -168,7 +174,10 @@ type clientObject struct { // addFromConfig initializes the clients container with objects from the // configuration file. -func (clients *clientsContainer) addFromConfig(objects []*clientObject, filteringConf *filtering.Config) { +func (clients *clientsContainer) addFromConfig( + objects []*clientObject, + filteringConf *filtering.Config, +) (err error) { for _, o := range objects { cli := &Client{ Name: o.Name, @@ -189,7 +198,7 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin if o.SafeSearchConf.Enabled { o.SafeSearchConf.CustomResolver = safeSearchResolver{} - err := cli.setSafeSearch( + err = cli.setSafeSearch( o.SafeSearchConf, filteringConf.SafeSearchCacheSize, time.Minute*time.Duration(filteringConf.CacheTime), @@ -201,6 +210,11 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin } } + err = o.BlockedServices.Validate() + if err != nil { + return fmt.Errorf("clients: %w", err) + } + cli.BlockedServices = o.BlockedServices.Clone() for _, t := range o.Tags { @@ -213,11 +227,13 @@ func (clients *clientsContainer) addFromConfig(objects []*clientObject, filterin slices.Sort(cli.Tags) - _, err := clients.Add(cli) + _, err = clients.Add(cli) if err != nil { log.Error("clients: adding clients %s: %s", cli.Name, err) } } + + return nil } // forConfig returns all currently known persistent clients as objects for the diff --git a/internal/home/clients_test.go b/internal/home/clients_test.go index 8361528a256..b2e70e8fc36 100644 --- a/internal/home/clients_test.go +++ b/internal/home/clients_test.go @@ -16,18 +16,19 @@ import ( // newClientsContainer is a helper that creates a new clients container for // tests. -func newClientsContainer() (c *clientsContainer) { +func newClientsContainer(t *testing.T) (c *clientsContainer) { c = &clientsContainer{ testing: true, } - c.Init(nil, nil, nil, nil, &filtering.Config{}) + err := c.Init(nil, nil, nil, nil, &filtering.Config{}) + require.NoError(t, err) return c } func TestClients(t *testing.T) { - clients := newClientsContainer() + clients := newClientsContainer(t) t.Run("add_success", func(t *testing.T) { var ( @@ -198,7 +199,7 @@ func TestClients(t *testing.T) { } func TestClientsWHOIS(t *testing.T) { - clients := newClientsContainer() + clients := newClientsContainer(t) whois := &RuntimeClientWHOISInfo{ Country: "AU", Orgname: "Example Org", @@ -244,7 +245,7 @@ func TestClientsWHOIS(t *testing.T) { } func TestClientsAddExisting(t *testing.T) { - clients := newClientsContainer() + clients := newClientsContainer(t) t.Run("simple", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.1") @@ -316,7 +317,7 @@ func TestClientsAddExisting(t *testing.T) { } func TestClientsCustomUpstream(t *testing.T) { - clients := newClientsContainer() + clients := newClientsContainer(t) // Add client with upstreams. ok, err := clients.Add(&Client{ diff --git a/internal/home/home.go b/internal/home/home.go index 5f1dd6f2044..64a1a9790a0 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -353,13 +353,17 @@ func initContextClients() (err error) { arpdb = aghnet.NewARPDB() } - Context.clients.Init( + err = Context.clients.Init( config.Clients.Persistent, Context.dhcpServer, Context.etcHosts, arpdb, config.DNS.DnsfilterConf, ) + if err != nil { + // Don't wrap the error, because it's informative enough as is. + return err + } return nil }