From 669cdecb8309e8efd0796402f1626d42e93a467d Mon Sep 17 00:00:00 2001 From: Tom Thorogood Date: Sat, 30 Oct 2021 13:34:23 +1030 Subject: [PATCH] Support TsigProvider for zone transfers --- xfr.go | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/xfr.go b/xfr.go index 43970e64f..6d328e37a 100644 --- a/xfr.go +++ b/xfr.go @@ -18,6 +18,7 @@ type Transfer struct { ReadTimeout time.Duration // net.Conn.SetReadTimeout value for connections, defaults to 2 seconds WriteTimeout time.Duration // net.Conn.SetWriteTimeout value for connections, defaults to 2 seconds TsigSecret map[string]string // Secret(s) for Tsig map[], zonename must be in canonical form (lowercase, fqdn, see RFC 4034 Section 6.2) + TsigProvider TsigProvider // An implementation of the TsigProvider interface. If defined it replaces TsigSecret and is used for all TSIG operations. tsigTimersOnly bool } @@ -224,12 +225,17 @@ func (t *Transfer) ReadMsg() (*Msg, error) { if err := m.Unpack(p); err != nil { return nil, err } - if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil { - if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { - return m, ErrSecret + if ts := m.IsTsig(); ts != nil && (t.TsigSecret != nil || t.TsigProvider != nil) { + if t.TsigProvider != nil { + // Need to work on the original message p, as that was used to calculate the tsig. + err = tsigVerifyProvider(p, t.TsigProvider, t.tsigRequestMAC, t.tsigTimersOnly) + } else { + if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { + return m, ErrSecret + } + // Need to work on the original message p, as that was used to calculate the tsig. + err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly) } - // Need to work on the original message p, as that was used to calculate the tsig. - err = TsigVerify(p, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly) t.tsigRequestMAC = ts.MAC } return m, err @@ -238,11 +244,15 @@ func (t *Transfer) ReadMsg() (*Msg, error) { // WriteMsg writes a message through the transfer connection t. func (t *Transfer) WriteMsg(m *Msg) (err error) { var out []byte - if ts := m.IsTsig(); ts != nil && t.TsigSecret != nil { - if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { - return ErrSecret + if ts := m.IsTsig(); ts != nil && (t.TsigSecret != nil || t.TsigProvider != nil) { + if t.TsigProvider != nil { + out, t.tsigRequestMAC, err = tsigGenerateProvider(m, t.TsigProvider, t.tsigRequestMAC, t.tsigTimersOnly) + } else { + if _, ok := t.TsigSecret[ts.Hdr.Name]; !ok { + return ErrSecret + } + out, t.tsigRequestMAC, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly) } - out, t.tsigRequestMAC, err = TsigGenerate(m, t.TsigSecret[ts.Hdr.Name], t.tsigRequestMAC, t.tsigTimersOnly) } else { out, err = m.Pack() }