From aedc14dda9a5164dccd065693ab9a5e9864e2736 Mon Sep 17 00:00:00 2001 From: Ryan Castellucci Date: Sun, 15 Sep 2024 18:34:00 +0100 Subject: [PATCH] add exception list feature --- api.go | 33 +++++++++++++++++++++++++++++++++ handler.go | 8 ++++---- main.go | 15 +++++++++------ server.go | 3 ++- updater.go | 38 ++++++++++++++++++++++++-------------- 5 files changed, 72 insertions(+), 25 deletions(-) diff --git a/api.go b/api.go index aa8a48d..84d2cf0 100644 --- a/api.go +++ b/api.go @@ -36,6 +36,7 @@ func isRunningInDockerContainer() bool { func StartAPIServer(config *Config, reloadChan chan bool, blockCache *MemoryBlockCache, + exceptCache *MemoryBlockCache, questionCache *MemoryQuestionCache) (*http.Server, error) { if !config.APIDebug { gin.SetMode(gin.ReleaseMode) @@ -174,6 +175,38 @@ func StartAPIServer(config *Config, } }) + router.GET("/exceptcache", func(c *gin.Context) { + c.IndentedJSON(http.StatusOK, gin.H{"length": exceptCache.Length(), "items": exceptCache.Backend}) + }) + + router.GET("/exceptcache/exists/:key", func(c *gin.Context) { + c.IndentedJSON(http.StatusOK, gin.H{"exists": exceptCache.Exists(c.Param("key"))}) + }) + + router.GET("/exceptcache/get/:key", func(c *gin.Context) { + if ok, _ := exceptCache.Get(c.Param("key")); !ok { + c.IndentedJSON(http.StatusOK, gin.H{"error": c.Param("key") + " not found"}) + } else { + c.IndentedJSON(http.StatusOK, gin.H{"success": ok}) + } + }) + + router.GET("/exceptcache/length", func(c *gin.Context) { + c.IndentedJSON(http.StatusOK, gin.H{"length": exceptCache.Length()}) + }) + + router.GET("/exceptcache/remove/:key", func(c *gin.Context) { + // Removes from exceptCache only. If the domain has already been queried and placed into MemoryCache, will need to wait until item is expired. + exceptCache.Remove(c.Param("key")) + c.IndentedJSON(http.StatusOK, gin.H{"success": true}) + }) + + router.GET("/exceptcache/set/:key", func(c *gin.Context) { + // MemoryBlockCache Set() always returns nil, so ignoring response. + _ = exceptCache.Set(c.Param("key"), true) + c.IndentedJSON(http.StatusOK, gin.H{"success": true}) + }) + router.GET("/questioncache", func(c *gin.Context) { highWater, err := strconv.ParseInt(c.DefaultQuery("highWater", "-1"), 10, 64) if err != nil { diff --git a/handler.go b/handler.go index 317bcd9..72c1dee 100644 --- a/handler.go +++ b/handler.go @@ -53,7 +53,7 @@ type DNSOperationData struct { } // NewHandler returns a new DNSHandler -func NewHandler(config *Config, blockCache *MemoryBlockCache, questionCache *MemoryQuestionCache) *DNSHandler { +func NewHandler(config *Config, blockCache *MemoryBlockCache, exceptCache *MemoryBlockCache, questionCache *MemoryQuestionCache) *DNSHandler { var ( clientConfig *dns.ClientConfig resolver *Resolver @@ -80,12 +80,12 @@ func NewHandler(config *Config, blockCache *MemoryBlockCache, questionCache *Mem active: true, } - go handler.do(config, blockCache, questionCache) + go handler.do(config, blockCache, exceptCache, questionCache) return handler } -func (h *DNSHandler) do(config *Config, blockCache *MemoryBlockCache, questionCache *MemoryQuestionCache) { +func (h *DNSHandler) do(config *Config, blockCache *MemoryBlockCache, exceptCache *MemoryBlockCache, questionCache *MemoryQuestionCache) { for { data, ok := <-h.requestChannel if !ok { @@ -154,7 +154,7 @@ func (h *DNSHandler) do(config *Config, blockCache *MemoryBlockCache, questionCa var drblblacklisted bool if IPQuery > 0 { - blacklisted = blockCache.Exists(Q.Qname) + blacklisted = !exceptCache.Exists(Q.Qname) && blockCache.Exists(Q.Qname) if config.UseDrbl > 0 { drblblacklisted = drblCheckHostname(Q.Qname) diff --git a/main.go b/main.go index 3451f51..e491320 100644 --- a/main.go +++ b/main.go @@ -24,6 +24,7 @@ var ( func reloadBlockCache(config *Config, blockCache *MemoryBlockCache, + exceptCache *MemoryBlockCache, questionCache *MemoryQuestionCache, drblPeers *drblpeer.DrblPeers, apiServer *http.Server, @@ -31,15 +32,15 @@ func reloadBlockCache(config *Config, reloadChan chan bool) (*MemoryBlockCache, *http.Server, error) { logger.Debug("Reloading the blockcache") - blockCache = PerformUpdate(config, true) + blockCache, exceptCache = PerformUpdate(config, true) server.Stop() if apiServer != nil { if err := apiServer.Shutdown(context.Background()); err != nil { logger.Debugf("error shutting down api server: %v", err) } } - server.Run(config, blockCache, questionCache) - apiServer, err := StartAPIServer(config, reloadChan, blockCache, questionCache) + server.Run(config, blockCache, exceptCache, questionCache) + apiServer, err := StartAPIServer(config, reloadChan, blockCache, exceptCache, questionCache) if err != nil { logger.Fatal(err) return nil, nil, err @@ -89,6 +90,8 @@ func main() { // BlockCache contains all blocked domains blockCache := &MemoryBlockCache{Backend: make(map[string]bool)} + // ExecptCache contains all exception domains (ignore blocks) + exceptCache := &MemoryBlockCache{Backend: make(map[string]bool)} // QuestionCache contains all queries to the dns server questionCache := makeQuestionCache(config.QuestionCacheCap) @@ -96,11 +99,11 @@ func main() { // The server will start with an empty blockcache soe we can dowload the lists if grimd is the // system's dns server. - server.Run(config, blockCache, questionCache) + server.Run(config, blockCache, exceptCache, questionCache) var apiServer *http.Server // Load the block cache, restart the server with the new context - blockCache, apiServer, err = reloadBlockCache(config, blockCache, questionCache, drblPeers, apiServer, server, reloadChan) + blockCache, apiServer, err = reloadBlockCache(config, blockCache, exceptCache, questionCache, drblPeers, apiServer, server, reloadChan) if err != nil { logger.Fatalf("Cannot start the API server %s", err) @@ -126,7 +129,7 @@ forever: } } case <-reloadChan: - blockCache, apiServer, err = reloadBlockCache(config, blockCache, questionCache, drblPeers, apiServer, server, reloadChan) + blockCache, apiServer, err = reloadBlockCache(config, blockCache, exceptCache, questionCache, drblPeers, apiServer, server, reloadChan) if err != nil { logger.Fatalf("Cannot start the API server %s", err) } diff --git a/server.go b/server.go index 8fdb10c..28c2363 100644 --- a/server.go +++ b/server.go @@ -19,9 +19,10 @@ type Server struct { // Run starts the server func (s *Server) Run(config *Config, blockCache *MemoryBlockCache, + exceptCache *MemoryBlockCache, questionCache *MemoryQuestionCache) { - s.handler = NewHandler(config, blockCache, questionCache) + s.handler = NewHandler(config, blockCache, exceptCache, questionCache) tcpHandler := dns.NewServeMux() tcpHandler.HandleFunc(".", s.handler.DoTCP) diff --git a/updater.go b/updater.go index d1731a4..44148a7 100644 --- a/updater.go +++ b/updater.go @@ -14,10 +14,9 @@ import ( ) var timesSeen = make(map[string]int) -var whitelist = make(map[string]bool) // Update downloads all the blocklists and imports them into the database -func update(blockCache *MemoryBlockCache, wlist []string, blist []string, sources []string) error { +func update(blockCache *MemoryBlockCache, exceptCache *MemoryBlockCache, wlist []string, blist []string, sources []string) error { if _, err := os.Stat("sources"); os.IsNotExist(err) { if err := os.Mkdir("sources", 0700); err != nil { return fmt.Errorf("error creating sources directory: %s", err) @@ -25,7 +24,7 @@ func update(blockCache *MemoryBlockCache, wlist []string, blist []string, source } for _, entry := range wlist { - whitelist[entry] = true + exceptCache.Set(entry, true) } for _, entry := range blist { @@ -100,7 +99,7 @@ func fetchSources(sources []string) error { } // UpdateBlockCache updates the BlockCache -func updateBlockCache(blockCache *MemoryBlockCache, sourceDirs []string) error { +func updateBlockCache(blockCache *MemoryBlockCache, exceptCache *MemoryBlockCache, sourceDirs []string) error { logger.Debugf("loading blocked domains from %d locations...\n", len(sourceDirs)) for _, dir := range sourceDirs { @@ -113,7 +112,7 @@ func updateBlockCache(blockCache *MemoryBlockCache, sourceDirs []string) error { if !f.IsDir() { fileName := filepath.FromSlash(path) - if err := parseHostFile(fileName, blockCache); err != nil { + if err := parseHostFile(fileName, blockCache, exceptCache); err != nil { return fmt.Errorf("error parsing hostfile %s", err) } } @@ -131,7 +130,7 @@ func updateBlockCache(blockCache *MemoryBlockCache, sourceDirs []string) error { return nil } -func parseHostFile(fileName string, blockCache *MemoryBlockCache) error { +func parseHostFile(fileName string, blockCache *MemoryBlockCache, exceptCache *MemoryBlockCache) error { file, err := os.Open(fileName) if err != nil { return fmt.Errorf("error opening file: %s", err) @@ -148,6 +147,7 @@ func parseHostFile(fileName string, blockCache *MemoryBlockCache) error { line := scanner.Text() line = strings.Split(line, "#")[0] line = strings.TrimSpace(line) + isException := strings.HasPrefix(line, "!") if len(line) > 0 { fields := strings.Fields(line) @@ -158,10 +158,19 @@ func parseHostFile(fileName string, blockCache *MemoryBlockCache) error { line = fields[0] } - if !blockCache.Exists(line) && !whitelist[line] { - err := blockCache.Set(line, true) - if err != nil { - logger.Critical(err) + if isException { + if !exceptCache.Exists(line) { + err := exceptCache.Set(line[1:], true) + if err != nil { + logger.Critical(err) + } + } + } else { + if !blockCache.Exists(line) && !exceptCache.Exists(line) { + err := blockCache.Set(line, true) + if err != nil { + logger.Critical(err) + } } } } @@ -176,16 +185,17 @@ func parseHostFile(fileName string, blockCache *MemoryBlockCache) error { // PerformUpdate updates the block cache by building a new one and swapping // it for the old cache. -func PerformUpdate(config *Config, forceUpdate bool) *MemoryBlockCache { +func PerformUpdate(config *Config, forceUpdate bool) (*MemoryBlockCache, *MemoryBlockCache) { newBlockCache := &MemoryBlockCache{Backend: make(map[string]bool), Special: make(map[string]*regexp.Regexp)} + newExceptCache := &MemoryBlockCache{Backend: make(map[string]bool)} if _, err := os.Stat("lists"); os.IsNotExist(err) || forceUpdate { - if err := update(newBlockCache, config.Whitelist, config.Blocklist, config.Sources); err != nil { + if err := update(newBlockCache, newExceptCache, config.Whitelist, config.Blocklist, config.Sources); err != nil { logger.Fatal(err) } } - if err := updateBlockCache(newBlockCache, config.SourceDirs); err != nil { + if err := updateBlockCache(newBlockCache, newExceptCache, config.SourceDirs); err != nil { logger.Fatal(err) } - return newBlockCache + return newBlockCache, newExceptCache }