diff --git a/http/handler.go b/http/handler.go index 7f5fbb6ee142..4c24949be290 100644 --- a/http/handler.go +++ b/http/handler.go @@ -131,8 +131,10 @@ func Handler(props *vault.HandlerProperties) http.Handler { mux.Handle("/v1/sys/unseal", handleSysUnseal(core)) mux.Handle("/v1/sys/leader", handleSysLeader(core)) mux.Handle("/v1/sys/health", handleSysHealth(core)) - mux.Handle("/v1/sys/generate-root/attempt", handleRequestForwarding(core, handleSysGenerateRootAttempt(core, vault.GenerateStandardRootTokenStrategy))) - mux.Handle("/v1/sys/generate-root/update", handleRequestForwarding(core, handleSysGenerateRootUpdate(core, vault.GenerateStandardRootTokenStrategy))) + mux.Handle("/v1/sys/generate-root/attempt", handleRequestForwarding(core, + handleAuditNonLogical(core, handleSysGenerateRootAttempt(core, vault.GenerateStandardRootTokenStrategy)))) + mux.Handle("/v1/sys/generate-root/update", handleRequestForwarding(core, + handleAuditNonLogical(core, handleSysGenerateRootUpdate(core, vault.GenerateStandardRootTokenStrategy)))) mux.Handle("/v1/sys/rekey/init", handleRequestForwarding(core, handleSysRekeyInit(core, false))) mux.Handle("/v1/sys/rekey/update", handleRequestForwarding(core, handleSysRekeyUpdate(core, false))) mux.Handle("/v1/sys/rekey/verify", handleRequestForwarding(core, handleSysRekeyVerify(core, false))) @@ -181,6 +183,69 @@ func Handler(props *vault.HandlerProperties) http.Handler { return printablePathCheckHandler } +type copyResponseWriter struct { + wrapped http.ResponseWriter + statusCode int + body *bytes.Buffer +} + +// newCopyResponseWriter returns an initialized newCopyResponseWriter +func newCopyResponseWriter(wrapped http.ResponseWriter) *copyResponseWriter { + w := ©ResponseWriter{ + wrapped: wrapped, + body: new(bytes.Buffer), + statusCode: 200, + } + return w +} + +func (w *copyResponseWriter) Header() http.Header { + return w.wrapped.Header() +} + +func (w *copyResponseWriter) Write(buf []byte) (int, error) { + w.body.Write(buf) + return w.wrapped.Write(buf) +} + +func (w *copyResponseWriter) WriteHeader(code int) { + w.statusCode = code + w.wrapped.WriteHeader(code) +} + +func handleAuditNonLogical(core *vault.Core, h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origBody := new(bytes.Buffer) + reader := ioutil.NopCloser(io.TeeReader(r.Body, origBody)) + r.Body = reader + req, _, status, err := buildLogicalRequestNoAuth(core.PerfStandby(), w, r) + if err != nil || status != 0 { + respondError(w, status, err) + return + } + if origBody != nil { + r.Body = ioutil.NopCloser(origBody) + } + input := &logical.LogInput{ + Request: req, + } + + core.AuditLogger().AuditRequest(r.Context(), input) + cw := newCopyResponseWriter(w) + h.ServeHTTP(cw, r) + data := make(map[string]interface{}) + err = jsonutil.DecodeJSON(cw.body.Bytes(), &data) + if err != nil { + // best effort, ignore + } + httpResp := &logical.HTTPResponse{Data: data, Headers: cw.Header()} + input.Response = logical.HTTPResponseToLogicalResponse(httpResp) + core.AuditLogger().AuditResponse(r.Context(), input) + return + }) + +} + // wrapGenericHandler wraps the handler with an extra layer of handler where // tasks that should be commonly handled for all the requests and/or responses // are performed. diff --git a/http/sys_generate_root_test.go b/http/sys_generate_root_test.go index e79abb9db3bd..d77659a877b5 100644 --- a/http/sys_generate_root_test.go +++ b/http/sys_generate_root_test.go @@ -1,17 +1,21 @@ package http import ( + "context" "encoding/base64" "encoding/hex" "encoding/json" "fmt" + "net" "net/http" "reflect" "testing" "github.com/go-test/deep" + "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/pgpkeys" "github.com/hashicorp/vault/helper/xor" + "github.com/hashicorp/vault/sdk/logical" "github.com/hashicorp/vault/vault" ) @@ -196,16 +200,58 @@ func TestSysGenerateRootAttempt_Cancel(t *testing.T) { } } -func TestSysGenerateRoot_badKey(t *testing.T) { - core, _, token := vault.TestCoreUnsealed(t) +func enableNoopAudit(t *testing.T, token string, core *vault.Core) { + t.Helper() + auditReq := &logical.Request{ + Operation: logical.UpdateOperation, + ClientToken: token, + Path: "sys/audit/noop", + Data: map[string]interface{}{ + "type": "noop", + }, + } + resp, err := core.HandleRequest(namespace.RootContext(context.Background()), auditReq) + if err != nil { + t.Fatal(err) + } + + if resp.IsError() { + t.Fatal(err) + } +} + +func testCoreUnsealedWithAudit(t *testing.T, records **[][]byte) (*vault.Core, [][]byte, string) { + conf := &vault.CoreConfig{ + BuiltinRegistry: vault.NewMockBuiltinRegistry(), + } + vault.AddNoopAudit(conf, records) + core, keys, token := vault.TestCoreUnsealedWithConfig(t, conf) + return core, keys, token +} + +func testServerWithAudit(t *testing.T, records **[][]byte) (net.Listener, string, string, [][]byte) { + core, keys, token := testCoreUnsealedWithAudit(t, records) ln, addr := TestServer(t, core) - defer ln.Close() TestServerAuth(t, addr, token) + enableNoopAudit(t, token, core) + return ln, addr, token, keys +} + +func TestSysGenerateRoot_badKey(t *testing.T) { + var records *[][]byte + ln, addr, token, _ := testServerWithAudit(t, &records) + defer ln.Close() resp := testHttpPut(t, token, addr+"/v1/sys/generate-root/update", map[string]interface{}{ "key": "0123", }) testResponseStatus(t, resp, 400) + + if len(*records) < 3 { + // One record for enabling the noop audit device, two for generate root attempt + t.Fatalf("expected at least 3 audit records, got %d", len(*records)) + } + t.Log(string((*records)[2])) } func TestSysGenerateRoot_ReAttemptUpdate(t *testing.T) { @@ -228,10 +274,9 @@ func TestSysGenerateRoot_ReAttemptUpdate(t *testing.T) { } func TestSysGenerateRoot_Update_OTP(t *testing.T) { - core, keys, token := vault.TestCoreUnsealed(t) - ln, addr := TestServer(t, core) + var records *[][]byte + ln, addr, token, keys := testServerWithAudit(t, &records) defer ln.Close() - TestServerAuth(t, addr, token) resp := testHttpPut(t, token, addr+"/v1/sys/generate-root/attempt", map[string]interface{}{}) var rootGenerationStatus map[string]interface{} @@ -317,6 +362,10 @@ func TestSysGenerateRoot_Update_OTP(t *testing.T) { if !reflect.DeepEqual(actual["data"], expected) { t.Fatalf("\nexpected: %#v\nactual: %#v", expected, actual["data"]) } + + for _, r := range *records { + t.Log(string(r)) + } } func TestSysGenerateRoot_Update_PGP(t *testing.T) { diff --git a/vault/audit.go b/vault/audit.go index 3c54789af353..636619413fc0 100644 --- a/vault/audit.go +++ b/vault/audit.go @@ -506,6 +506,23 @@ func defaultAuditTable() *MountTable { return table } +type AuditLogger interface { + AuditRequest(ctx context.Context, input *logical.LogInput) error + AuditResponse(ctx context.Context, input *logical.LogInput) error +} + +type basicAuditor struct { + c *Core +} + +func (b *basicAuditor) AuditRequest(ctx context.Context, input *logical.LogInput) error { + return b.c.auditBroker.LogRequest(ctx, input, b.c.auditedHeaders) +} + +func (b *basicAuditor) AuditResponse(ctx context.Context, input *logical.LogInput) error { + return b.c.auditBroker.LogResponse(ctx, input, b.c.auditedHeaders) +} + type genericAuditor struct { c *Core mountType string diff --git a/vault/core.go b/vault/core.go index 8f283308d7ce..014c1f36bb20 100644 --- a/vault/core.go +++ b/vault/core.go @@ -2251,3 +2251,7 @@ type BuiltinRegistry interface { Get(name string, pluginType consts.PluginType) (func() (interface{}, error), bool) Keys(pluginType consts.PluginType) []string } + +func (c *Core) AuditLogger() AuditLogger { + return &basicAuditor{c: c} +} diff --git a/vault/testing.go b/vault/testing.go index 6389bcf67b44..6b7267b5cd04 100644 --- a/vault/testing.go +++ b/vault/testing.go @@ -677,7 +677,7 @@ func (n *noopAudit) Salt(ctx context.Context) (*salt.Salt, error) { return salt, nil } -func AddNoopAudit(conf *CoreConfig) { +func AddNoopAudit(conf *CoreConfig, records **[][]byte) { conf.AuditBackends = map[string]audit.Factory{ "noop": func(_ context.Context, config *audit.BackendConfig) (audit.Backend, error) { view := &logical.InmemStorage{} @@ -691,6 +691,9 @@ func AddNoopAudit(conf *CoreConfig) { n.formatter.AuditFormatWriter = &audit.JSONFormatWriter{ SaltFunc: n.Salt, } + if records != nil { + *records = &n.records + } return n, nil }, } @@ -1437,7 +1440,7 @@ func NewTestCluster(t testing.T, base *CoreConfig, opts *TestClusterOptions) *Te addAuditBackend := len(coreConfig.AuditBackends) == 0 if addAuditBackend { - AddNoopAudit(coreConfig) + AddNoopAudit(coreConfig, nil) } if coreConfig.Physical == nil && (opts == nil || opts.PhysicalFactory == nil) {