Skip to content

Commit

Permalink
Merge pull request #200 from hashicorp/f-dns
Browse files Browse the repository at this point in the history
Adding support for DNS TTLs and stale reads
  • Loading branch information
armon committed Jun 9, 2014
2 parents 71a654a + bd4610b commit 93fb12e
Show file tree
Hide file tree
Showing 9 changed files with 545 additions and 34 deletions.
4 changes: 2 additions & 2 deletions command/agent/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ func (c *Command) setupAgent(config *Config, logOutput io.Writer, logWriter *log
return err
}

server, err := NewDNSServer(agent, logOutput, config.Domain,
dnsAddr.String(), config.DNSRecursor)
server, err := NewDNSServer(agent, &config.DNSConfig, logOutput,
config.Domain, dnsAddr.String(), config.DNSRecursor)
if err != nil {
agent.Shutdown()
c.Ui.Error(fmt.Sprintf("Error starting dns server: %s", err))
Expand Down
82 changes: 82 additions & 0 deletions command/agent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,35 @@ type PortConfig struct {
Server int // Server internal RPC
}

// DNSConfig is used to fine tune the DNS sub-system.
// It can be used to control cache values, and stale
// reads
type DNSConfig struct {
// NodeTTL provides the TTL value for a node query
NodeTTL time.Duration `mapstructure:"-"`
NodeTTLRaw string `mapstructure:"node_ttl" json:"-"`

// ServiceTTL provides the TTL value for a service
// query for given service. The "*" wildcard can be used
// to set a default for all services.
ServiceTTL map[string]time.Duration `mapstructure:"-"`
ServiceTTLRaw map[string]string `mapstructure:"service_ttl" json:"-"`

// AllowStale is used to enable lookups with stale
// data. This gives horizontal read scalability since
// any Consul server can service the query instead of
// only the leader.
AllowStale bool `mapstructure:"allow_stale"`

// MaxStale is used to bound how stale of a result is
// accepted for a DNS lookup. This can be used with
// AllowStale to limit how old of a value is served up.
// If the stale result exceeds this, another non-stale
// stale read is performed.
MaxStale time.Duration `mapstructure:"-"`
MaxStaleRaw string `mapstructure:"max_stale" json:"-"`
}

// Config is the configuration that can be set for an Agent.
// Some of this is configurable as CLI flags, but most must
// be set using a configuration file.
Expand All @@ -50,6 +79,9 @@ type Config struct {
// resolve non-consul domains
DNSRecursor string `mapstructure:"recursor"`

// DNS configuration
DNSConfig DNSConfig `mapstructure:"dns_config"`

// Domain is the DNS domain for the records. Defaults to "consul."
Domain string `mapstructure:"domain"`

Expand Down Expand Up @@ -185,6 +217,9 @@ func DefaultConfig() *Config {
SerfWan: consul.DefaultWANSerfPort,
Server: 8300,
},
DNSConfig: DNSConfig{
MaxStale: 5 * time.Second,
},
Protocol: consul.ProtocolVersionMax,
AEInterval: time.Minute,
}
Expand Down Expand Up @@ -244,6 +279,36 @@ func DecodeConfig(r io.Reader) (*Config, error) {
return nil, err
}

// Handle time conversions
if raw := result.DNSConfig.NodeTTLRaw; raw != "" {
dur, err := time.ParseDuration(raw)
if err != nil {
return nil, fmt.Errorf("NodeTTL invalid: %v", err)
}
result.DNSConfig.NodeTTL = dur
}

if raw := result.DNSConfig.MaxStaleRaw; raw != "" {
dur, err := time.ParseDuration(raw)
if err != nil {
return nil, fmt.Errorf("MaxStale invalid: %v", err)
}
result.DNSConfig.MaxStale = dur
}

if len(result.DNSConfig.ServiceTTLRaw) != 0 {
if result.DNSConfig.ServiceTTL == nil {
result.DNSConfig.ServiceTTL = make(map[string]time.Duration)
}
for service, raw := range result.DNSConfig.ServiceTTLRaw {
dur, err := time.ParseDuration(raw)
if err != nil {
return nil, fmt.Errorf("ServiceTTL %s invalid: %v", service, err)
}
result.DNSConfig.ServiceTTL[service] = dur
}
}

return &result, nil
}

Expand Down Expand Up @@ -454,6 +519,23 @@ func MergeConfig(a, b *Config) *Config {
if b.RejoinAfterLeave {
result.RejoinAfterLeave = true
}
if b.DNSConfig.NodeTTL != 0 {
result.DNSConfig.NodeTTL = b.DNSConfig.NodeTTL
}
if len(b.DNSConfig.ServiceTTL) != 0 {
if result.DNSConfig.ServiceTTL == nil {
result.DNSConfig.ServiceTTL = make(map[string]time.Duration)
}
for service, dur := range b.DNSConfig.ServiceTTL {
result.DNSConfig.ServiceTTL[service] = dur
}
}
if b.DNSConfig.AllowStale {
result.DNSConfig.AllowStale = true
}
if b.DNSConfig.MaxStale != 0 {
result.DNSConfig.MaxStale = b.DNSConfig.MaxStale
}

// Copy the start join addresses
result.StartJoin = make([]string, 0, len(a.StartJoin)+len(b.StartJoin))
Expand Down
50 changes: 46 additions & 4 deletions command/agent/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,40 @@ func TestDecodeConfig(t *testing.T) {
if !config.RejoinAfterLeave {
t.Fatalf("bad: %#v", config)
}

// DNS node ttl, max stale
input = `{"dns_config": {"node_ttl": "5s", "max_stale": "15s", "allow_stale": true}}`
config, err = DecodeConfig(bytes.NewReader([]byte(input)))
if err != nil {
t.Fatalf("err: %s", err)
}

if config.DNSConfig.NodeTTL != 5*time.Second {
t.Fatalf("bad: %#v", config)
}
if config.DNSConfig.MaxStale != 15*time.Second {
t.Fatalf("bad: %#v", config)
}
if !config.DNSConfig.AllowStale {
t.Fatalf("bad: %#v", config)
}

// DNS service ttl
input = `{"dns_config": {"service_ttl": {"*": "1s", "api": "10s", "web": "30s"}}}`
config, err = DecodeConfig(bytes.NewReader([]byte(input)))
if err != nil {
t.Fatalf("err: %s", err)
}

if config.DNSConfig.ServiceTTL["*"] != time.Second {
t.Fatalf("bad: %#v", config)
}
if config.DNSConfig.ServiceTTL["api"] != 10*time.Second {
t.Fatalf("bad: %#v", config)
}
if config.DNSConfig.ServiceTTL["web"] != 30*time.Second {
t.Fatalf("bad: %#v", config)
}
}

func TestDecodeConfig_Service(t *testing.T) {
Expand Down Expand Up @@ -391,10 +425,18 @@ func TestMergeConfig(t *testing.T) {
}

b := &Config{
Bootstrap: true,
Datacenter: "dc2",
DataDir: "/tmp/bar",
DNSRecursor: "127.0.0.2:1001",
Bootstrap: true,
Datacenter: "dc2",
DataDir: "/tmp/bar",
DNSRecursor: "127.0.0.2:1001",
DNSConfig: DNSConfig{
NodeTTL: 10 * time.Second,
ServiceTTL: map[string]time.Duration{
"api": 10 * time.Second,
},
AllowStale: true,
MaxStale: 30 * time.Second,
},
Domain: "other",
LogLevel: "info",
NodeName: "baz",
Expand Down
69 changes: 50 additions & 19 deletions command/agent/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const (
// service discovery endpoints using a DNS interface.
type DNSServer struct {
agent *Agent
config *DNSConfig
dnsHandler *dns.ServeMux
dnsServer *dns.Server
dnsServerTCP *dns.Server
Expand All @@ -32,7 +33,7 @@ type DNSServer struct {
}

// NewDNSServer starts a new DNS server to provide an agent interface
func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind, recursor string) (*DNSServer, error) {
func NewDNSServer(agent *Agent, config *DNSConfig, logOutput io.Writer, domain, bind, recursor string) (*DNSServer, error) {
// Make sure domain is FQDN
domain = dns.Fqdn(domain)

Expand All @@ -55,6 +56,7 @@ func NewDNSServer(agent *Agent, logOutput io.Writer, domain, bind, recursor stri
// Create the server
srv := &DNSServer{
agent: agent,
config: config,
dnsHandler: mux,
dnsServer: server,
dnsServerTCP: serverTCP,
Expand Down Expand Up @@ -306,31 +308,41 @@ func (d *DNSServer) nodeLookup(network, datacenter, node string, req, resp *dns.

// Make an RPC request
args := structs.NodeSpecificRequest{
Datacenter: datacenter,
Node: node,
Datacenter: datacenter,
Node: node,
QueryOptions: structs.QueryOptions{AllowStale: d.config.AllowStale},
}
var out structs.IndexedNodeServices
RPC:
if err := d.agent.RPC("Catalog.NodeServices", &args, &out); err != nil {
d.logger.Printf("[ERR] dns: rpc error: %v", err)
resp.SetRcode(req, dns.RcodeServerFailure)
return
}

// Verify that request is not too stale, redo the request
if args.AllowStale && out.LastContact > d.config.MaxStale {
args.AllowStale = false
d.logger.Printf("[WARN] dns: Query results too stale, re-requesting")
goto RPC
}

// If we have no address, return not found!
if out.NodeServices == nil {
resp.SetRcode(req, dns.RcodeNameError)
return
}

// Add the node record
records := d.formatNodeRecord(&out.NodeServices.Node, req.Question[0].Name, qType)
records := d.formatNodeRecord(&out.NodeServices.Node, req.Question[0].Name,
qType, d.config.NodeTTL)
if records != nil {
resp.Answer = append(resp.Answer, records...)
}
}

// formatNodeRecord takes a Node and returns an A, AAAA, or CNAME record
func (d *DNSServer) formatNodeRecord(node *structs.Node, qName string, qType uint16) (records []dns.RR) {
func (d *DNSServer) formatNodeRecord(node *structs.Node, qName string, qType uint16, ttl time.Duration) (records []dns.RR) {
// Parse the IP
ip := net.ParseIP(node.Address)
var ipv4 net.IP
Expand All @@ -344,7 +356,7 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, qName string, qType uin
Name: qName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 0,
Ttl: uint32(ttl / time.Second),
},
A: ip,
}}
Expand All @@ -355,7 +367,7 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, qName string, qType uin
Name: qName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 0,
Ttl: uint32(ttl / time.Second),
},
AAAA: ip,
}}
Expand All @@ -368,7 +380,7 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, qName string, qType uin
Name: qName,
Rrtype: dns.TypeCNAME,
Class: dns.ClassINET,
Ttl: 0,
Ttl: uint32(ttl / time.Second),
},
Target: dns.Fqdn(node.Address),
}
Expand Down Expand Up @@ -398,24 +410,43 @@ func (d *DNSServer) formatNodeRecord(node *structs.Node, qName string, qType uin
func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, req, resp *dns.Msg) {
// Make an RPC request
args := structs.ServiceSpecificRequest{
Datacenter: datacenter,
ServiceName: service,
ServiceTag: tag,
TagFilter: tag != "",
Datacenter: datacenter,
ServiceName: service,
ServiceTag: tag,
TagFilter: tag != "",
QueryOptions: structs.QueryOptions{AllowStale: d.config.AllowStale},
}
var out structs.IndexedCheckServiceNodes
RPC:
if err := d.agent.RPC("Health.ServiceNodes", &args, &out); err != nil {
d.logger.Printf("[ERR] dns: rpc error: %v", err)
resp.SetRcode(req, dns.RcodeServerFailure)
return
}

// Verify that request is not too stale, redo the request
if args.AllowStale && out.LastContact > d.config.MaxStale {
args.AllowStale = false
d.logger.Printf("[WARN] dns: Query results too stale, re-requesting")
goto RPC
}

// If we have no nodes, return not found!
if len(out.Nodes) == 0 {
resp.SetRcode(req, dns.RcodeNameError)
return
}

// Determine the TTL
var ttl time.Duration
if d.config.ServiceTTL != nil {
var ok bool
ttl, ok = d.config.ServiceTTL[service]
if !ok {
ttl = d.config.ServiceTTL["*"]
}
}

// Filter out any service nodes due to health checks
out.Nodes = d.filterServiceNodes(out.Nodes)

Expand All @@ -429,10 +460,10 @@ func (d *DNSServer) serviceLookup(network, datacenter, service, tag string, req,

// Add various responses depending on the request
qType := req.Question[0].Qtype
d.serviceNodeRecords(out.Nodes, req, resp)
d.serviceNodeRecords(out.Nodes, req, resp, ttl)

if qType == dns.TypeSRV {
d.serviceSRVRecords(datacenter, out.Nodes, req, resp)
d.serviceSRVRecords(datacenter, out.Nodes, req, resp, ttl)
}
}

Expand Down Expand Up @@ -464,7 +495,7 @@ func shuffleServiceNodes(nodes structs.CheckServiceNodes) {
}

// serviceNodeRecords is used to add the node records for a service lookup
func (d *DNSServer) serviceNodeRecords(nodes structs.CheckServiceNodes, req, resp *dns.Msg) {
func (d *DNSServer) serviceNodeRecords(nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
qName := req.Question[0].Name
qType := req.Question[0].Qtype
handled := make(map[string]struct{})
Expand All @@ -478,15 +509,15 @@ func (d *DNSServer) serviceNodeRecords(nodes structs.CheckServiceNodes, req, res
handled[addr] = struct{}{}

// Add the node record
records := d.formatNodeRecord(&node.Node, qName, qType)
records := d.formatNodeRecord(&node.Node, qName, qType, ttl)
if records != nil {
resp.Answer = append(resp.Answer, records...)
}
}
}

// serviceARecords is used to add the SRV records for a service lookup
func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg) {
func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes, req, resp *dns.Msg, ttl time.Duration) {
handled := make(map[string]struct{})
for _, node := range nodes {
// Avoid duplicate entries, possible if a node has
Expand All @@ -503,7 +534,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
Name: req.Question[0].Name,
Rrtype: dns.TypeSRV,
Class: dns.ClassINET,
Ttl: 0,
Ttl: uint32(ttl / time.Second),
},
Priority: 1,
Weight: 1,
Expand All @@ -513,7 +544,7 @@ func (d *DNSServer) serviceSRVRecords(dc string, nodes structs.CheckServiceNodes
resp.Answer = append(resp.Answer, srvRec)

// Add the extra record
records := d.formatNodeRecord(&node.Node, srvRec.Target, dns.TypeANY)
records := d.formatNodeRecord(&node.Node, srvRec.Target, dns.TypeANY, ttl)
if records != nil {
resp.Extra = append(resp.Extra, records...)
}
Expand Down
Loading

0 comments on commit 93fb12e

Please sign in to comment.