From 8a0897a8cccd8f521e034951203d7b8553609b7f Mon Sep 17 00:00:00 2001 From: dmachard <5562930+dmachard@users.noreply.github.com> Date: Mon, 15 May 2023 21:14:30 +0200 Subject: [PATCH 1/2] convert to array --- dnsutils/message.go | 1 + loggers/restapi.go | 82 ++++++++++++++++++--- loggers/restapi_test.go | 155 ++++++++++++++++++++++++++++++++++------ 3 files changed, 206 insertions(+), 32 deletions(-) diff --git a/dnsutils/message.go b/dnsutils/message.go index db6ef9f4..8b152d60 100644 --- a/dnsutils/message.go +++ b/dnsutils/message.go @@ -157,6 +157,7 @@ type TransformSuspicious struct { UnallowedChars bool `json:"unallowed-chars" msgpack:"unallowed-chars"` UncommonQtypes bool `json:"uncommon-qtypes" msgpack:"uncommon-qtypes"` ExcessiveNumberLabels bool `json:"excessive-number-labels" msgpack:"excessive-number-labels"` + Domain string `json:"domain,omitempty" msgpack:"-"` } type TransformPublicSuffix struct { diff --git a/loggers/restapi.go b/loggers/restapi.go index 0050c527..f4ebafbe 100644 --- a/loggers/restapi.go +++ b/loggers/restapi.go @@ -37,6 +37,21 @@ type HitsUniq struct { Suspicious map[string]*dnsutils.TransformSuspicious } +type StreamHit struct { + Stream string `json:"stream"` + Hit int `json:"hit"` +} + +type DomainHit struct { + Domain string `json:"domain"` + Hit int `json:"hit"` +} + +type AddressHit struct { + Address string `json:"address"` + Hit int `json:"hit"` +} + type RestAPI struct { done chan bool done_api chan bool @@ -255,7 +270,14 @@ func (s *RestAPI) GetTLDsHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - json.NewEncoder(w).Encode(s.HitsUniq.PublicSuffixes) + // return as array + dataArray := []DomainHit{} + for tld, hit := range s.HitsUniq.PublicSuffixes { + dataArray = append(dataArray, DomainHit{Domain: tld, Hit: hit}) + } + + // encode + json.NewEncoder(w).Encode(dataArray) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -274,7 +296,14 @@ func (s *RestAPI) GetClientsHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - json.NewEncoder(w).Encode(s.HitsUniq.Clients) + // return as array + dataArray := []AddressHit{} + for address, hit := range s.HitsUniq.Clients { + dataArray = append(dataArray, AddressHit{Address: address, Hit: hit}) + } + + // encode + json.NewEncoder(w).Encode(dataArray) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -293,7 +322,14 @@ func (s *RestAPI) GetDomainsHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - json.NewEncoder(w).Encode(s.HitsUniq.Domains) + // return as array + dataArray := []DomainHit{} + for domain, hit := range s.HitsUniq.Domains { + dataArray = append(dataArray, DomainHit{Domain: domain, Hit: hit}) + } + + // encode + json.NewEncoder(w).Encode(dataArray) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -312,7 +348,15 @@ func (s *RestAPI) GetNxDomainsHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - json.NewEncoder(w).Encode(s.HitsUniq.NxDomains) + // convert to array + dataArray := []DomainHit{} + for domain, hit := range s.HitsUniq.NxDomains { + dataArray = append(dataArray, DomainHit{Domain: domain, Hit: hit}) + } + + // encode + json.NewEncoder(w).Encode(dataArray) + default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -331,7 +375,14 @@ func (s *RestAPI) GetSfDomainsHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - json.NewEncoder(w).Encode(s.HitsUniq.SfDomains) + // return as array + dataArray := []DomainHit{} + for domain, hit := range s.HitsUniq.SfDomains { + dataArray = append(dataArray, DomainHit{Domain: domain, Hit: hit}) + } + + // encode + json.NewEncoder(w).Encode(dataArray) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -350,7 +401,15 @@ func (s *RestAPI) GetSuspiciousHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - json.NewEncoder(w).Encode(s.HitsUniq.Suspicious) + // return as array + dataArray := []*dnsutils.TransformSuspicious{} + for domain, suspicious := range s.HitsUniq.Suspicious { + suspicious.Domain = domain + dataArray = append(dataArray, suspicious) + } + + // encode + json.NewEncoder(w).Encode(dataArray) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -433,7 +492,13 @@ func (s *RestAPI) GetStreamsHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - json.NewEncoder(w).Encode(s.Streams) + + dataArray := []StreamHit{} + for stream, hit := range s.Streams { + dataArray = append(dataArray, StreamHit{Stream: stream, Hit: hit}) + } + + json.NewEncoder(w).Encode(dataArray) default: http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } @@ -555,10 +620,9 @@ func (s *RestAPI) ListenAndServe() { mux.HandleFunc("/clients", s.GetClientsHandler) mux.HandleFunc("/clients/top", s.GetTopClientsHandler) mux.HandleFunc("/domains", s.GetDomainsHandler) + mux.HandleFunc("/domains/servfail", s.GetSfDomainsHandler) mux.HandleFunc("/domains/top", s.GetTopDomainsHandler) - mux.HandleFunc("/domains/nx", s.GetNxDomainsHandler) mux.HandleFunc("/domains/nx/top", s.GetTopNxDomainsHandler) - mux.HandleFunc("/domains/servfail", s.GetSfDomainsHandler) mux.HandleFunc("/domains/servfail/top", s.GetTopSfDomainsHandler) mux.HandleFunc("/suspicious", s.GetSuspiciousHandler) mux.HandleFunc("/search", s.GetSearchHandler) diff --git a/loggers/restapi_test.go b/loggers/restapi_test.go index 677a7e37..c75d6a05 100644 --- a/loggers/restapi_test.go +++ b/loggers/restapi_test.go @@ -11,7 +11,7 @@ import ( "github.com/dmachard/go-logger" ) -func TestRestAPIBadBasicAuth(t *testing.T) { +func TestRestAPI_BadBasicAuth(t *testing.T) { // init the logger config := dnsutils.GetFakeConfig() g := NewRestAPI(config, logger.New(false), "dev", "test") @@ -57,7 +57,7 @@ func TestRestAPIBadBasicAuth(t *testing.T) { } } -func TestWebServerGet(t *testing.T) { +func TestRestAPI_MethodNotAllowed(t *testing.T) { // init the logger config := dnsutils.GetFakeConfig() g := NewRestAPI(config, logger.New(false), "dev", "test") @@ -82,12 +82,12 @@ func TestWebServerGet(t *testing.T) { statusCode int }{ { - name: "get clients", - uri: "/clients", - handler: g.GetClientsHandler, + name: "post streams refused", + uri: "/streams", + handler: g.GetStreamsHandler, method: http.MethodGet, - want: `{"1.2.3.4":1}`, - statusCode: http.StatusOK, + want: `Method not allowed`, + statusCode: http.StatusMethodNotAllowed, }, { name: "post clients refused", @@ -97,14 +97,6 @@ func TestWebServerGet(t *testing.T) { want: "Method not allowed", statusCode: http.StatusMethodNotAllowed, }, - { - name: "get tlds", - uri: "/tlds", - handler: g.GetTLDsHandler, - method: http.MethodGet, - want: `{"collector":1}`, - statusCode: http.StatusOK, - }, { name: "post tlds refused", uri: "/tlds", @@ -113,14 +105,6 @@ func TestWebServerGet(t *testing.T) { want: `Method not allowed`, statusCode: http.StatusMethodNotAllowed, }, - { - name: "get domains", - uri: "/domains", - handler: g.GetDomainsHandler, - method: http.MethodGet, - want: `{"dns.collector":1}`, - statusCode: http.StatusOK, - }, { name: "post domains refused", uri: "/domains", @@ -154,3 +138,128 @@ func TestWebServerGet(t *testing.T) { }) } } + +func TestRestAPI_Get(t *testing.T) { + // init the logger + config := dnsutils.GetFakeConfig() + g := NewRestAPI(config, logger.New(false), "dev", "test") + + tt := []struct { + name string + uri string + handler func(w http.ResponseWriter, r *http.Request) + method string + want string + dm dnsutils.DnsMessage + dmRcode string + statusCode int + }{ + { + name: "streams", + uri: "/streams", + handler: g.GetStreamsHandler, + method: http.MethodGet, + want: `\[\{"stream":"collector","hit":1\}\]`, + statusCode: http.StatusOK, + dm: dnsutils.GetFakeDnsMessage(), + dmRcode: "NOERROR", + }, + { + name: "clients", + uri: "/clients", + handler: g.GetClientsHandler, + method: http.MethodGet, + want: `\[\{"address":"1.2.3.4","hit":2\}\]`, + statusCode: http.StatusOK, + dm: dnsutils.GetFakeDnsMessage(), + dmRcode: "NOERROR", + }, + { + name: "domains", + uri: "/domains", + handler: g.GetDomainsHandler, + method: http.MethodGet, + want: `\[\{"domain":"dns.collector","hit":3\}\]`, + statusCode: http.StatusOK, + dm: dnsutils.GetFakeDnsMessage(), + dmRcode: "NOERROR", + }, + { + name: "nx domains", + uri: "/domains/nx", + handler: g.GetNxDomainsHandler, + method: http.MethodGet, + want: `\[\{"domain":"dns.collector","hit":1\}\]`, + statusCode: http.StatusOK, + dm: dnsutils.GetFakeDnsMessage(), + dmRcode: "NXDOMAIN", + }, + { + name: "servfail domains", + uri: "/domains/servfail", + handler: g.GetSfDomainsHandler, + method: http.MethodGet, + want: `\[\{"domain":"dns.collector","hit":1\}\]`, + statusCode: http.StatusOK, + dm: dnsutils.GetFakeDnsMessage(), + dmRcode: "SERVFAIL", + }, + { + name: "tlds", + uri: "/tlds", + handler: g.GetTLDsHandler, + method: http.MethodGet, + want: `\[\{"domain":".com","hit":1\}\]`, + statusCode: http.StatusOK, + dm: dnsutils.GetFakeDnsMessage(), + dmRcode: "NOERROR", + }, + { + name: "suspicious", + uri: "/suspicious", + handler: g.GetSuspiciousHandler, + method: http.MethodGet, + want: `\[\{"score":1,"malformed-pkt":false,"large-pkt":false,"long-domain":false,"slow-domain":false,"unallowed-chars":false,"uncommon-qtypes":false,"excessive-number-labels":false,"domain":"dns:collector"\}\]`, + statusCode: http.StatusOK, + dm: dnsutils.GetFakeDnsMessage(), + dmRcode: "NOERROR", + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + // record the dns message + dm := tc.dm + dm.DNS.Rcode = tc.dmRcode + if tc.name == "tlds" { + dm.PublicSuffix = &dnsutils.TransformPublicSuffix{} + dm.PublicSuffix.QnamePublicSuffix = ".com" + } + + if tc.name == "suspicious" { + dm.DNS.Qname = "dns:collector" + dm.Suspicious = &dnsutils.TransformSuspicious{Score: 1} + } + g.RecordDnsMessage(dm) + + // init httptest + request := httptest.NewRequest(tc.method, tc.uri, strings.NewReader("")) + request.SetBasicAuth(config.Loggers.RestAPI.BasicAuthLogin, config.Loggers.RestAPI.BasicAuthPwd) + responseRecorder := httptest.NewRecorder() + + // call handler + tc.handler(responseRecorder, request) + + // checking status code + if responseRecorder.Code != tc.statusCode { + t.Errorf("Want status '%d', got '%d'", tc.statusCode, responseRecorder.Code) + } + + // checking content + response := strings.TrimSpace(responseRecorder.Body.String()) + if regexp.MustCompile(tc.want).MatchString(response) != true { + t.Errorf("Want '%s', got '%s'", tc.want, response) + } + }) + } +} From 666ebd73bbb2120662be6fbf183ea4b237039f38 Mon Sep 17 00:00:00 2001 From: dmachard <5562930+dmachard@users.noreply.github.com> Date: Mon, 15 May 2023 22:39:07 +0200 Subject: [PATCH 2/2] be more generic /search --- doc/swagger.yml | 28 ++++------- loggers/restapi.go | 104 ++++++++++++++++------------------------ loggers/restapi_test.go | 54 +++++++++++++++++---- 3 files changed, 95 insertions(+), 91 deletions(-) diff --git a/doc/swagger.yml b/doc/swagger.yml index efb29878..4313b257 100644 --- a/doc/swagger.yml +++ b/doc/swagger.yml @@ -1,13 +1,13 @@ openapi: 3.0.2 info: title: Swagger for DNS-collector tool - version: 0.28.0 + version: 0.32.0 description: This is a swagger for the API of the DNS-collector. contact: email: d.machard@gmail.com license: name: MIT - url: 'https://github.com/dmachard/go-dns-collector/blob/main/LICENSE' + url: 'https://github.com/dmachard/go-dnscollector/blob/main/LICENSE' x-logo: url: '' servers: @@ -16,29 +16,19 @@ paths: /search: get: parameters: - - in: query - name: stream_id + - in: filter + name: filter schema: type: string - description: stream identity name - - in: query - name: query_ip - schema: - type: string - description: query ip to search - - in: query - name: query_name - schema: - type: string - description: query name to search + description: domain or address to search responses: '200': - description: Return list of domains founded + description: Return list of domains or addresses founded content: text/plain: schema: type: string - summary: Return a list of domains + summary: Return a list of domains or addresses /streams: get: responses: @@ -153,7 +143,7 @@ paths: get: responses: '200': - description: Rerurn suspicious domains list + description: Return suspicious domains list content: text/plain: schema: @@ -161,4 +151,4 @@ paths: summary: Return suspicious domains list security: [] externalDocs: - url: 'https://github.com/dmachard/go-dns-collector' \ No newline at end of file + url: 'https://github.com/dmachard/go-dnscollector' \ No newline at end of file diff --git a/loggers/restapi.go b/loggers/restapi.go index f4ebafbe..91e12a4c 100644 --- a/loggers/restapi.go +++ b/loggers/restapi.go @@ -37,19 +37,9 @@ type HitsUniq struct { Suspicious map[string]*dnsutils.TransformSuspicious } -type StreamHit struct { - Stream string `json:"stream"` - Hit int `json:"hit"` -} - -type DomainHit struct { - Domain string `json:"domain"` - Hit int `json:"hit"` -} - -type AddressHit struct { - Address string `json:"address"` - Hit int `json:"hit"` +type KeyHit struct { + Key string `json:"key"` + Hit int `json:"hit"` } type RestAPI struct { @@ -271,9 +261,9 @@ func (s *RestAPI) GetTLDsHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: // return as array - dataArray := []DomainHit{} + dataArray := []KeyHit{} for tld, hit := range s.HitsUniq.PublicSuffixes { - dataArray = append(dataArray, DomainHit{Domain: tld, Hit: hit}) + dataArray = append(dataArray, KeyHit{Key: tld, Hit: hit}) } // encode @@ -297,9 +287,9 @@ func (s *RestAPI) GetClientsHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: // return as array - dataArray := []AddressHit{} + dataArray := []KeyHit{} for address, hit := range s.HitsUniq.Clients { - dataArray = append(dataArray, AddressHit{Address: address, Hit: hit}) + dataArray = append(dataArray, KeyHit{Key: address, Hit: hit}) } // encode @@ -323,9 +313,9 @@ func (s *RestAPI) GetDomainsHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: // return as array - dataArray := []DomainHit{} + dataArray := []KeyHit{} for domain, hit := range s.HitsUniq.Domains { - dataArray = append(dataArray, DomainHit{Domain: domain, Hit: hit}) + dataArray = append(dataArray, KeyHit{Key: domain, Hit: hit}) } // encode @@ -349,9 +339,9 @@ func (s *RestAPI) GetNxDomainsHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: // convert to array - dataArray := []DomainHit{} + dataArray := []KeyHit{} for domain, hit := range s.HitsUniq.NxDomains { - dataArray = append(dataArray, DomainHit{Domain: domain, Hit: hit}) + dataArray = append(dataArray, KeyHit{Key: domain, Hit: hit}) } // encode @@ -376,9 +366,9 @@ func (s *RestAPI) GetSfDomainsHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: // return as array - dataArray := []DomainHit{} + dataArray := []KeyHit{} for domain, hit := range s.HitsUniq.SfDomains { - dataArray = append(dataArray, DomainHit{Domain: domain, Hit: hit}) + dataArray = append(dataArray, KeyHit{Key: domain, Hit: hit}) } // encode @@ -427,53 +417,39 @@ func (s *RestAPI) GetSearchHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - streamId := r.URL.Query()["stream_id"] - queryIp := r.URL.Query()["query_ip"] - queryName := r.URL.Query()["query_name"] - - if len(streamId) == 0 && len(queryIp) == 0 && len(queryName) == 0 { + filter := r.URL.Query()["filter"] + if len(filter) == 0 { http.Error(w, "Arguments are missing", http.StatusBadRequest) } - // search in a stream - if len(streamId) == 1 { - if _, exists := s.HitsStream.Streams[streamId[0]]; exists { - stream := s.HitsStream.Streams[streamId[0]] - - if len(queryIp) == 1 && len(queryName) == 1 { - if _, exists := stream.Clients[queryIp[0]]; exists { - client := stream.Clients[queryIp[0]] - if _, domainExists := client.Hits[queryName[0]]; domainExists { - w.Header().Set("Content-Type", "application/text") - w.Write([]byte(strconv.Itoa(client.Hits[queryName[0]]))) - } else { - http.Error(w, "{\"error\": \"Query Name not found\"}", http.StatusNotFound) - } - } else { - http.Error(w, "{\"error\": \"Query IP not found\"}", http.StatusNotFound) - } + dataArray := []KeyHit{} - } else if len(queryIp) == 1 { - if _, exists := stream.Clients[queryIp[0]]; exists { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(stream.Clients[queryIp[0]]) - } else { - http.Error(w, "{\"error\": \"Query IP not found\"}", http.StatusNotFound) - } + // search by IP + for _, search := range s.HitsStream.Streams { + userHits, clientExists := search.Clients[filter[0]] + if clientExists { + for domain, hit := range userHits.Hits { + dataArray = append(dataArray, KeyHit{Key: domain, Hit: hit}) + } + } + } - } else if len(queryName) == 1 { - if _, exists := stream.Domains[queryName[0]]; exists { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(stream.Domains[queryName[0]]) - } else { - http.Error(w, "{\"error\": \"Query Name not found\"}", http.StatusNotFound) + // search by domain + if len(dataArray) == 0 { + for _, search := range s.HitsStream.Streams { + domainHists, domainExists := search.Domains[filter[0]] + if domainExists { + for addr, hit := range domainHists.Hits { + dataArray = append(dataArray, KeyHit{Key: addr, Hit: hit}) } } - - } else { - http.Error(w, "{\"error\": \"Stream ID not Found\"}", http.StatusNotFound) } } + + // encode to json + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(dataArray) + default: http.Error(w, "{\"error\": \"Method not allowed\"}", http.StatusMethodNotAllowed) } @@ -493,9 +469,9 @@ func (s *RestAPI) GetStreamsHandler(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - dataArray := []StreamHit{} + dataArray := []KeyHit{} for stream, hit := range s.Streams { - dataArray = append(dataArray, StreamHit{Stream: stream, Hit: hit}) + dataArray = append(dataArray, KeyHit{Key: stream, Hit: hit}) } json.NewEncoder(w).Encode(dataArray) @@ -625,7 +601,7 @@ func (s *RestAPI) ListenAndServe() { mux.HandleFunc("/domains/nx/top", s.GetTopNxDomainsHandler) mux.HandleFunc("/domains/servfail/top", s.GetTopSfDomainsHandler) mux.HandleFunc("/suspicious", s.GetSuspiciousHandler) - mux.HandleFunc("/search", s.GetSearchHandler) + mux.HandleFunc("/search/address", s.GetSearchHandler) var err error var listener net.Listener diff --git a/loggers/restapi_test.go b/loggers/restapi_test.go index c75d6a05..bc852ee6 100644 --- a/loggers/restapi_test.go +++ b/loggers/restapi_test.go @@ -85,7 +85,7 @@ func TestRestAPI_MethodNotAllowed(t *testing.T) { name: "post streams refused", uri: "/streams", handler: g.GetStreamsHandler, - method: http.MethodGet, + method: http.MethodPost, want: `Method not allowed`, statusCode: http.StatusMethodNotAllowed, }, @@ -113,6 +113,14 @@ func TestRestAPI_MethodNotAllowed(t *testing.T) { want: `Method not allowed`, statusCode: http.StatusMethodNotAllowed, }, + { + name: "post search refused", + uri: "/search", + handler: g.GetSearchHandler, + method: http.MethodPost, + want: `Method not allowed`, + statusCode: http.StatusMethodNotAllowed, + }, } for _, tc := range tt { @@ -139,7 +147,7 @@ func TestRestAPI_MethodNotAllowed(t *testing.T) { } } -func TestRestAPI_Get(t *testing.T) { +func TestRestAPI_GetMethod(t *testing.T) { // init the logger config := dnsutils.GetFakeConfig() g := NewRestAPI(config, logger.New(false), "dev", "test") @@ -159,7 +167,7 @@ func TestRestAPI_Get(t *testing.T) { uri: "/streams", handler: g.GetStreamsHandler, method: http.MethodGet, - want: `\[\{"stream":"collector","hit":1\}\]`, + want: `\[\{"key":"collector","hit":1\}\]`, statusCode: http.StatusOK, dm: dnsutils.GetFakeDnsMessage(), dmRcode: "NOERROR", @@ -169,7 +177,7 @@ func TestRestAPI_Get(t *testing.T) { uri: "/clients", handler: g.GetClientsHandler, method: http.MethodGet, - want: `\[\{"address":"1.2.3.4","hit":2\}\]`, + want: `\[\{"key":"1.2.3.4","hit":2\}\]`, statusCode: http.StatusOK, dm: dnsutils.GetFakeDnsMessage(), dmRcode: "NOERROR", @@ -179,7 +187,7 @@ func TestRestAPI_Get(t *testing.T) { uri: "/domains", handler: g.GetDomainsHandler, method: http.MethodGet, - want: `\[\{"domain":"dns.collector","hit":3\}\]`, + want: `\[\{"key":"dns.collector","hit":3\}\]`, statusCode: http.StatusOK, dm: dnsutils.GetFakeDnsMessage(), dmRcode: "NOERROR", @@ -189,7 +197,7 @@ func TestRestAPI_Get(t *testing.T) { uri: "/domains/nx", handler: g.GetNxDomainsHandler, method: http.MethodGet, - want: `\[\{"domain":"dns.collector","hit":1\}\]`, + want: `\[\{"key":"dns.collector","hit":1\}\]`, statusCode: http.StatusOK, dm: dnsutils.GetFakeDnsMessage(), dmRcode: "NXDOMAIN", @@ -199,7 +207,7 @@ func TestRestAPI_Get(t *testing.T) { uri: "/domains/servfail", handler: g.GetSfDomainsHandler, method: http.MethodGet, - want: `\[\{"domain":"dns.collector","hit":1\}\]`, + want: `\[\{"key":"dns.collector","hit":1\}\]`, statusCode: http.StatusOK, dm: dnsutils.GetFakeDnsMessage(), dmRcode: "SERVFAIL", @@ -209,7 +217,7 @@ func TestRestAPI_Get(t *testing.T) { uri: "/tlds", handler: g.GetTLDsHandler, method: http.MethodGet, - want: `\[\{"domain":".com","hit":1\}\]`, + want: `\[\{"key":".com","hit":1\}\]`, statusCode: http.StatusOK, dm: dnsutils.GetFakeDnsMessage(), dmRcode: "NOERROR", @@ -224,6 +232,36 @@ func TestRestAPI_Get(t *testing.T) { dm: dnsutils.GetFakeDnsMessage(), dmRcode: "NOERROR", }, + { + name: "search_by_domain", + uri: "/search?filter=dns.collector", + handler: g.GetSearchHandler, + method: http.MethodGet, + want: `\[\{"key":"1.2.3.4","hit":7\}\]`, + statusCode: http.StatusOK, + dm: dnsutils.GetFakeDnsMessage(), + dmRcode: "NOERROR", + }, + { + name: "search_by_ip", + uri: "/search?filter=1.2.3.4", + handler: g.GetSearchHandler, + method: http.MethodGet, + want: `\[\{"key":"dns.collector","hit":8},{"key":"dns:collector","hit":1\}\]`, + statusCode: http.StatusOK, + dm: dnsutils.GetFakeDnsMessage(), + dmRcode: "NOERROR", + }, + { + name: "search_not_found", + uri: "/search?filter=notfound.collector", + handler: g.GetSearchHandler, + method: http.MethodGet, + want: `\[\]`, + statusCode: http.StatusOK, + dm: dnsutils.GetFakeDnsMessage(), + dmRcode: "NOERROR", + }, } for _, tc := range tt {