diff --git a/api/sys_raft.go b/api/sys_raft.go index cbf3a2020038..685717401dd3 100644 --- a/api/sys_raft.go +++ b/api/sys_raft.go @@ -272,24 +272,93 @@ func (c *Sys) RaftSnapshot(snapWriter io.Writer) error { } // RaftSnapshotRestore reads the snapshot from the io.Reader and installs that -// snapshot, returning the cluster to the state defined by it. +// snapshot, returning the cluster to the state defined by it. This avoids the use of +// RawRequestWithContext which copies the body (leading to possible OOMs) for retrying func (c *Sys) RaftSnapshotRestore(snapReader io.Reader, force bool) error { path := "/v1/sys/storage/raft/snapshot" if force { path = "/v1/sys/storage/raft/snapshot-force" } - r := c.c.NewRequest("POST", path) - r.Body = snapReader + r := c.c.NewRequest(http.MethodPost, path) + r.URL.RawQuery = r.Params.Encode() - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - resp, err := c.c.RawRequestWithContext(ctx, r) + req, err := http.NewRequest(http.MethodPost, r.URL.RequestURI(), snapReader) if err != nil { return err } + + req.URL.User = r.URL.User + req.URL.Scheme = r.URL.Scheme + req.URL.Host = r.URL.Host + req.Host = r.URL.Host + + if r.Headers != nil { + for header, vals := range r.Headers { + for _, val := range vals { + req.Header.Add(header, val) + } + } + } + + if len(r.ClientToken) != 0 { + req.Header.Set(consts.AuthHeaderName, r.ClientToken) + } + + if len(r.WrapTTL) != 0 { + req.Header.Set("X-Vault-Wrap-TTL", r.WrapTTL) + } + + if len(r.MFAHeaderVals) != 0 { + for _, mfaHeaderVal := range r.MFAHeaderVals { + req.Header.Add("X-Vault-MFA", mfaHeaderVal) + } + } + + if r.PolicyOverride { + req.Header.Set("X-Vault-Policy-Override", "true") + } + + var result *Response + resp, err := c.c.config.HttpClient.Do(req) defer resp.Body.Close() + if err != nil { + return err + } + + if resp == nil { + return nil + } + + // Check for a redirect, only allowing for a single redirect + if resp.StatusCode == 301 || resp.StatusCode == 302 || resp.StatusCode == 307 { + // Parse the updated location + respLoc, err := resp.Location() + if err != nil { + return err + } + + // Ensure a protocol downgrade doesn't happen + if req.URL.Scheme == "https" && respLoc.Scheme != "https" { + return fmt.Errorf("redirect would cause protocol downgrade") + } + + // Update the request + req.URL = respLoc + + // Retry the request + resp, err = c.c.config.HttpClient.Do(req) + if err != nil { + return err + } + } + + result = &Response{Response: resp} + if err := result.Error(); err != nil { + return err + } + return nil } diff --git a/changelog/14269.txt b/changelog/14269.txt new file mode 100644 index 000000000000..529b7c626429 --- /dev/null +++ b/changelog/14269.txt @@ -0,0 +1,3 @@ +```release-note:bug + api/sys/raft: Update RaftSnapshotRestore to use net/http client allowing bodies larger than allocated memory to be streamed +```