Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add exception list feature #123

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ func isRunningInDockerContainer() bool {
func StartAPIServer(config *Config,
reloadChan chan bool,
blockCache *MemoryBlockCache,
exceptCache *MemoryBlockCache,
questionCache *MemoryQuestionCache) (*http.Server, error) {
if !config.APIDebug {
gin.SetMode(gin.ReleaseMode)
Expand Down Expand Up @@ -174,6 +175,38 @@ func StartAPIServer(config *Config,
}
})

router.GET("/exceptcache", func(c *gin.Context) {
c.IndentedJSON(http.StatusOK, gin.H{"length": exceptCache.Length(), "items": exceptCache.Backend})
})

router.GET("/exceptcache/exists/:key", func(c *gin.Context) {
c.IndentedJSON(http.StatusOK, gin.H{"exists": exceptCache.Exists(c.Param("key"))})
})

router.GET("/exceptcache/get/:key", func(c *gin.Context) {
if ok, _ := exceptCache.Get(c.Param("key")); !ok {
c.IndentedJSON(http.StatusOK, gin.H{"error": c.Param("key") + " not found"})
} else {
c.IndentedJSON(http.StatusOK, gin.H{"success": ok})
}
})

router.GET("/exceptcache/length", func(c *gin.Context) {
c.IndentedJSON(http.StatusOK, gin.H{"length": exceptCache.Length()})
})

router.GET("/exceptcache/remove/:key", func(c *gin.Context) {
// Removes from exceptCache only. If the domain has already been queried and placed into MemoryCache, will need to wait until item is expired.
exceptCache.Remove(c.Param("key"))
c.IndentedJSON(http.StatusOK, gin.H{"success": true})
})

router.GET("/exceptcache/set/:key", func(c *gin.Context) {
// MemoryBlockCache Set() always returns nil, so ignoring response.
_ = exceptCache.Set(c.Param("key"), true)
c.IndentedJSON(http.StatusOK, gin.H{"success": true})
})

router.GET("/questioncache", func(c *gin.Context) {
highWater, err := strconv.ParseInt(c.DefaultQuery("highWater", "-1"), 10, 64)
if err != nil {
Expand Down
8 changes: 4 additions & 4 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ type DNSOperationData struct {
}

// NewHandler returns a new DNSHandler
func NewHandler(config *Config, blockCache *MemoryBlockCache, questionCache *MemoryQuestionCache) *DNSHandler {
func NewHandler(config *Config, blockCache *MemoryBlockCache, exceptCache *MemoryBlockCache, questionCache *MemoryQuestionCache) *DNSHandler {
var (
clientConfig *dns.ClientConfig
resolver *Resolver
Expand All @@ -80,12 +80,12 @@ func NewHandler(config *Config, blockCache *MemoryBlockCache, questionCache *Mem
active: true,
}

go handler.do(config, blockCache, questionCache)
go handler.do(config, blockCache, exceptCache, questionCache)

return handler
}

func (h *DNSHandler) do(config *Config, blockCache *MemoryBlockCache, questionCache *MemoryQuestionCache) {
func (h *DNSHandler) do(config *Config, blockCache *MemoryBlockCache, exceptCache *MemoryBlockCache, questionCache *MemoryQuestionCache) {
for {
data, ok := <-h.requestChannel
if !ok {
Expand Down Expand Up @@ -154,7 +154,7 @@ func (h *DNSHandler) do(config *Config, blockCache *MemoryBlockCache, questionCa
var drblblacklisted bool

if IPQuery > 0 {
blacklisted = blockCache.Exists(Q.Qname)
blacklisted = !exceptCache.Exists(Q.Qname) && blockCache.Exists(Q.Qname)

if config.UseDrbl > 0 {
drblblacklisted = drblCheckHostname(Q.Qname)
Expand Down
15 changes: 9 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,23 @@ var (

func reloadBlockCache(config *Config,
blockCache *MemoryBlockCache,
exceptCache *MemoryBlockCache,
questionCache *MemoryQuestionCache,
drblPeers *drblpeer.DrblPeers,
apiServer *http.Server,
server *Server,
reloadChan chan bool) (*MemoryBlockCache, *http.Server, error) {

logger.Debug("Reloading the blockcache")
blockCache = PerformUpdate(config, true)
blockCache, exceptCache = PerformUpdate(config, true)
server.Stop()
if apiServer != nil {
if err := apiServer.Shutdown(context.Background()); err != nil {
logger.Debugf("error shutting down api server: %v", err)
}
}
server.Run(config, blockCache, questionCache)
apiServer, err := StartAPIServer(config, reloadChan, blockCache, questionCache)
server.Run(config, blockCache, exceptCache, questionCache)
apiServer, err := StartAPIServer(config, reloadChan, blockCache, exceptCache, questionCache)
if err != nil {
logger.Fatal(err)
return nil, nil, err
Expand Down Expand Up @@ -89,18 +90,20 @@ func main() {

// BlockCache contains all blocked domains
blockCache := &MemoryBlockCache{Backend: make(map[string]bool)}
// ExecptCache contains all exception domains (ignore blocks)
exceptCache := &MemoryBlockCache{Backend: make(map[string]bool)}
// QuestionCache contains all queries to the dns server
questionCache := makeQuestionCache(config.QuestionCacheCap)

reloadChan := make(chan bool)

// The server will start with an empty blockcache soe we can dowload the lists if grimd is the
// system's dns server.
server.Run(config, blockCache, questionCache)
server.Run(config, blockCache, exceptCache, questionCache)

var apiServer *http.Server
// Load the block cache, restart the server with the new context
blockCache, apiServer, err = reloadBlockCache(config, blockCache, questionCache, drblPeers, apiServer, server, reloadChan)
blockCache, apiServer, err = reloadBlockCache(config, blockCache, exceptCache, questionCache, drblPeers, apiServer, server, reloadChan)

if err != nil {
logger.Fatalf("Cannot start the API server %s", err)
Expand All @@ -126,7 +129,7 @@ forever:
}
}
case <-reloadChan:
blockCache, apiServer, err = reloadBlockCache(config, blockCache, questionCache, drblPeers, apiServer, server, reloadChan)
blockCache, apiServer, err = reloadBlockCache(config, blockCache, exceptCache, questionCache, drblPeers, apiServer, server, reloadChan)
if err != nil {
logger.Fatalf("Cannot start the API server %s", err)
}
Expand Down
3 changes: 2 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ type Server struct {
// Run starts the server
func (s *Server) Run(config *Config,
blockCache *MemoryBlockCache,
exceptCache *MemoryBlockCache,
questionCache *MemoryQuestionCache) {

s.handler = NewHandler(config, blockCache, questionCache)
s.handler = NewHandler(config, blockCache, exceptCache, questionCache)

tcpHandler := dns.NewServeMux()
tcpHandler.HandleFunc(".", s.handler.DoTCP)
Expand Down
38 changes: 24 additions & 14 deletions updater.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,17 @@ import (
)

var timesSeen = make(map[string]int)
var whitelist = make(map[string]bool)

// Update downloads all the blocklists and imports them into the database
func update(blockCache *MemoryBlockCache, wlist []string, blist []string, sources []string) error {
func update(blockCache *MemoryBlockCache, exceptCache *MemoryBlockCache, wlist []string, blist []string, sources []string) error {
if _, err := os.Stat("sources"); os.IsNotExist(err) {
if err := os.Mkdir("sources", 0700); err != nil {
return fmt.Errorf("error creating sources directory: %s", err)
}
}

for _, entry := range wlist {
whitelist[entry] = true
exceptCache.Set(entry, true)
}

for _, entry := range blist {
Expand Down Expand Up @@ -100,7 +99,7 @@ func fetchSources(sources []string) error {
}

// UpdateBlockCache updates the BlockCache
func updateBlockCache(blockCache *MemoryBlockCache, sourceDirs []string) error {
func updateBlockCache(blockCache *MemoryBlockCache, exceptCache *MemoryBlockCache, sourceDirs []string) error {
logger.Debugf("loading blocked domains from %d locations...\n", len(sourceDirs))

for _, dir := range sourceDirs {
Expand All @@ -113,7 +112,7 @@ func updateBlockCache(blockCache *MemoryBlockCache, sourceDirs []string) error {
if !f.IsDir() {
fileName := filepath.FromSlash(path)

if err := parseHostFile(fileName, blockCache); err != nil {
if err := parseHostFile(fileName, blockCache, exceptCache); err != nil {
return fmt.Errorf("error parsing hostfile %s", err)
}
}
Expand All @@ -131,7 +130,7 @@ func updateBlockCache(blockCache *MemoryBlockCache, sourceDirs []string) error {
return nil
}

func parseHostFile(fileName string, blockCache *MemoryBlockCache) error {
func parseHostFile(fileName string, blockCache *MemoryBlockCache, exceptCache *MemoryBlockCache) error {
file, err := os.Open(fileName)
if err != nil {
return fmt.Errorf("error opening file: %s", err)
Expand All @@ -148,6 +147,7 @@ func parseHostFile(fileName string, blockCache *MemoryBlockCache) error {
line := scanner.Text()
line = strings.Split(line, "#")[0]
line = strings.TrimSpace(line)
isException := strings.HasPrefix(line, "!")

if len(line) > 0 {
fields := strings.Fields(line)
Expand All @@ -158,10 +158,19 @@ func parseHostFile(fileName string, blockCache *MemoryBlockCache) error {
line = fields[0]
}

if !blockCache.Exists(line) && !whitelist[line] {
err := blockCache.Set(line, true)
if err != nil {
logger.Critical(err)
if isException {
if !exceptCache.Exists(line) {
err := exceptCache.Set(line[1:], true)
if err != nil {
logger.Critical(err)
}
}
} else {
if !blockCache.Exists(line) && !exceptCache.Exists(line) {
err := blockCache.Set(line, true)
if err != nil {
logger.Critical(err)
}
}
}
}
Expand All @@ -176,16 +185,17 @@ func parseHostFile(fileName string, blockCache *MemoryBlockCache) error {

// PerformUpdate updates the block cache by building a new one and swapping
// it for the old cache.
func PerformUpdate(config *Config, forceUpdate bool) *MemoryBlockCache {
func PerformUpdate(config *Config, forceUpdate bool) (*MemoryBlockCache, *MemoryBlockCache) {
newBlockCache := &MemoryBlockCache{Backend: make(map[string]bool), Special: make(map[string]*regexp.Regexp)}
newExceptCache := &MemoryBlockCache{Backend: make(map[string]bool)}
if _, err := os.Stat("lists"); os.IsNotExist(err) || forceUpdate {
if err := update(newBlockCache, config.Whitelist, config.Blocklist, config.Sources); err != nil {
if err := update(newBlockCache, newExceptCache, config.Whitelist, config.Blocklist, config.Sources); err != nil {
logger.Fatal(err)
}
}
if err := updateBlockCache(newBlockCache, config.SourceDirs); err != nil {
if err := updateBlockCache(newBlockCache, newExceptCache, config.SourceDirs); err != nil {
logger.Fatal(err)
}

return newBlockCache
return newBlockCache, newExceptCache
}
Loading