diff --git a/ais/backend/gcp.go b/ais/backend/gcp.go index 97a99285e9..449ca48a1e 100644 --- a/ais/backend/gcp.go +++ b/ais/backend/gcp.go @@ -462,7 +462,7 @@ func readCredFile() (projectID string) { if err != nil { return } - b, err := io.ReadAll(credFile) + b, err := cos.ReadAll(credFile) credFile.Close() if err != nil { return diff --git a/ais/htcommon.go b/ais/htcommon.go index e110cf0791..841368ef1a 100644 --- a/ais/htcommon.go +++ b/ais/htcommon.go @@ -412,8 +412,13 @@ var ( _ cresv = cresBsumm{} ) -func (res *callResult) read(body io.Reader) { res.bytes, res.err = io.ReadAll(body) } -func (res *callResult) jread(body io.Reader) { res.err = jsoniter.NewDecoder(body).Decode(res.v) } +func (res *callResult) read(body io.Reader, size int64) { + res.bytes, res.err = cos.ReadAllN(body, size) +} + +func (res *callResult) jread(body io.Reader) { + res.err = jsoniter.NewDecoder(body).Decode(res.v) +} func (res *callResult) mread(body io.Reader) { vv, ok := res.v.(msgp.Decodable) diff --git a/ais/htrun.go b/ais/htrun.go index 49509665ca..59e3fdc7c4 100644 --- a/ais/htrun.go +++ b/ais/htrun.go @@ -647,7 +647,7 @@ func (h *htrun) call(args *callArgs, smap *smapX) (res *callResult) { if res.err != nil { res.details = fmt.Sprintf("FATAL: failed to create HTTP request %s %s: %v", args.req.Method, args.req.URL(), res.err) - return + return res } req.Header.Set(apc.HdrCallerID, h.SID()) @@ -663,9 +663,19 @@ func (h *htrun) call(args *callArgs, smap *smapX) (res *callResult) { resp, res.err = client.Do(req) if res.err != nil { res.details = dfltDetail // tcp level, e.g.: connection refused - return + return res + } + + _doResp(args, req, resp, res) + resp.Body.Close() + + if sid != unknownDaemonID { + h.keepalive.heardFrom(sid) } - defer resp.Body.Close() + return res +} + +func _doResp(args *callArgs, req *http.Request, resp *http.Response, res *callResult) { res.status = resp.StatusCode res.header = resp.Header @@ -684,25 +694,14 @@ func (h *htrun) call(args *callArgs, smap *smapX) (res *callResult) { return } - // read and decode via call result value (`cresv`), if provided + // read and decode via call-result-value (`cresv`), if provided; // othwerwise, read and return bytes for the caller to unmarshal if args.cresv != nil { res.v = args.cresv.newV() args.cresv.read(res, resp.Body) - if res.err != nil { - return - } } else { - res.read(resp.Body) - if res.err != nil { - return - } + res.read(resp.Body, resp.ContentLength) } - - if sid != unknownDaemonID { - h.keepalive.heardFrom(sid) - } - return } // diff --git a/ais/proxy.go b/ais/proxy.go index f9c68523cb..fa001e3cfa 100644 --- a/ais/proxy.go +++ b/ais/proxy.go @@ -7,7 +7,6 @@ package ais import ( "errors" "fmt" - "io" "net" "net/http" "net/url" @@ -3007,7 +3006,7 @@ func (p *proxy) dsortHandler(w http.ResponseWriter, r *http.Request) { case http.MethodPost: // - validate request, check input_bck and output_bck // - start dsort - body, err := io.ReadAll(r.Body) + body, err := cos.ReadAllN(r.Body, r.ContentLength) if err != nil { p.writeErrStatusf(w, r, http.StatusInternalServerError, "failed to receive dsort request: %v", err) return diff --git a/ais/prxdl.go b/ais/prxdl.go index abb3a30952..426725ca3a 100644 --- a/ais/prxdl.go +++ b/ais/prxdl.go @@ -6,7 +6,6 @@ package ais import ( "fmt" - "io" "net/http" "net/url" "strconv" @@ -87,7 +86,7 @@ func (p *proxy) httpdlpost(w http.ResponseWriter, r *http.Request) { jobID := dload.PrefixJobID + cos.GenUUID() // prefix to visually differentiate vs. xaction IDs - body, err := io.ReadAll(r.Body) + body, err := cos.ReadAllN(r.Body, r.ContentLength) if err != nil { p.writeErrStatusf(w, r, http.StatusInternalServerError, "failed to receive download request: %v", err) return diff --git a/ais/prxetl.go b/ais/prxetl.go index 792683a72f..9a4a405aca 100644 --- a/ais/prxetl.go +++ b/ais/prxetl.go @@ -5,7 +5,6 @@ package ais import ( - "io" "net/http" "net/url" "reflect" @@ -100,7 +99,7 @@ func (p *proxy) handleETLPut(w http.ResponseWriter, r *http.Request) { return } - b, err := io.ReadAll(r.Body) + b, err := cos.ReadAll(r.Body) if err != nil { p.writeErr(w, r, err) return diff --git a/ais/prxnotif_internal_test.go b/ais/prxnotif_internal_test.go index 70a814affa..4bed5dcbec 100644 --- a/ais/prxnotif_internal_test.go +++ b/ais/prxnotif_internal_test.go @@ -6,7 +6,6 @@ package ais import ( "bytes" - "io" "net/http" "net/http/httptest" "time" @@ -138,7 +137,7 @@ var _ = Describe("Notifications xaction test", func() { writer := httptest.NewRecorder() n.handler(writer, req) resp := writer.Result() - respBody, _ := io.ReadAll(resp.Body) + respBody, _ := cos.ReadAllN(resp.Body, resp.ContentLength) resp.Body.Close() Expect(resp.StatusCode).To(BeEquivalentTo(expectedStatus)) return respBody diff --git a/ais/s3/presigned.go b/ais/s3/presigned.go index 3e3d9e5712..519c64a272 100644 --- a/ais/s3/presigned.go +++ b/ais/s3/presigned.go @@ -70,7 +70,7 @@ func (pts *PresignedReq) Do(client *http.Client) (*PresignedResp, error) { } defer resp.BodyR.Close() - output, err := io.ReadAll(resp.BodyR) + output, err := cos.ReadAll(resp.BodyR) if err != nil { return &PresignedResp{StatusCode: http.StatusBadRequest}, fmt.Errorf("failed to read response body: %v", err) } @@ -111,7 +111,7 @@ func (pts *PresignedReq) DoReader(client *http.Client) (*PresignedResp, error) { } if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { - output, _ := io.ReadAll(resp.Body) + output, _ := cos.ReadAll(resp.Body) resp.Body.Close() return &PresignedResp{StatusCode: resp.StatusCode}, fmt.Errorf("invalid status: %d, output: %s", resp.StatusCode, string(output)) } diff --git a/ais/tgtetl.go b/ais/tgtetl.go index c493bcf08d..8da06eec64 100644 --- a/ais/tgtetl.go +++ b/ais/tgtetl.go @@ -6,7 +6,6 @@ package ais import ( "fmt" - "io" "net/http" "net/url" "strconv" @@ -56,7 +55,7 @@ func (t *target) handleETLPut(w http.ResponseWriter, r *http.Request) { return } - b, err := io.ReadAll(r.Body) + b, err := cos.ReadAll(r.Body) if err != nil { t.writeErr(w, r, err) return diff --git a/ais/tgts3mpt.go b/ais/tgts3mpt.go index 70ff39e484..2eb8889b54 100644 --- a/ais/tgts3mpt.go +++ b/ais/tgts3mpt.go @@ -215,7 +215,7 @@ func (t *target) completeMpt(w http.ResponseWriter, r *http.Request, items []str return } - output, err := io.ReadAll(r.Body) + output, err := cos.ReadAllN(r.Body, r.ContentLength) if err != nil { s3.WriteErr(w, r, err, http.StatusBadRequest) return diff --git a/api/client.go b/api/client.go index 51e1a2fae5..4decf75b3e 100644 --- a/api/client.go +++ b/api/client.go @@ -279,7 +279,7 @@ func (reqParams *ReqParams) readStr(resp *http.Response, out *string) error { if err := reqParams.checkResp(resp); err != nil { return err } - b, err := io.ReadAll(resp.Body) + b, err := cos.ReadAllN(resp.Body, resp.ContentLength) if err != nil { return fmt.Errorf("failed to read response: %w", err) } @@ -351,7 +351,7 @@ func (reqParams *ReqParams) checkResp(resp *http.Response) error { } } - b, _ := io.ReadAll(resp.Body) + b, _ := cos.ReadAllN(resp.Body, resp.ContentLength) if len(b) == 0 { if resp.StatusCode == http.StatusServiceUnavailable { msg := fmt.Sprintf("[%s]: starting up, please try again later...", http.StatusText(http.StatusServiceUnavailable)) diff --git a/api/etl.go b/api/etl.go index 8ccf813572..cf7233e19d 100644 --- a/api/etl.go +++ b/api/etl.go @@ -62,7 +62,7 @@ func ETLGetInitMsg(params BaseParams, etlName string) (etl.InitMsg, error) { } defer cos.Close(r) - b, err := io.ReadAll(r) + b, err := cos.ReadAll(r) if err != nil { return nil, fmt.Errorf("failed to read response: %w", err) } diff --git a/api/stats.go b/api/stats.go index ee9cacf5f9..4cf5efc844 100644 --- a/api/stats.go +++ b/api/stats.go @@ -5,7 +5,6 @@ package api import ( - "io" "net/http" "net/url" @@ -112,7 +111,7 @@ func GetAnyStats(bp BaseParams, sid, what string) (out []byte, err error) { if err != nil { return nil, err } - out, err = io.ReadAll(resp.Body) + out, err = cos.ReadAllN(resp.Body, resp.ContentLength) cos.DrainReader(resp.Body) resp.Body.Close() FreeRp(reqParams) diff --git a/bench/tools/aisloader/client.go b/bench/tools/aisloader/client.go index 66c8b18da6..806ec48c2d 100644 --- a/bench/tools/aisloader/client.go +++ b/bench/tools/aisloader/client.go @@ -557,7 +557,7 @@ func readDiscard(r *http.Response, tag, cksumType string) (int64, string, error) cksumValue string ) if r.StatusCode >= http.StatusBadRequest { - bytes, err := io.ReadAll(r.Body) + bytes, err := cos.ReadAll(r.Body) if err == nil { return 0, "", fmt.Errorf("bad status %d from %s, response: %s", r.StatusCode, tag, string(bytes)) } diff --git a/bench/tools/aisloader/run.go b/bench/tools/aisloader/run.go index f4691c5e6c..e4bd5bbd3f 100644 --- a/bench/tools/aisloader/run.go +++ b/bench/tools/aisloader/run.go @@ -27,7 +27,6 @@ import ( "errors" "flag" "fmt" - "io" "math" "math/rand/v2" "os" @@ -816,7 +815,7 @@ func _init(p *params) (err error) { if err != nil { return err } - etlSpec, err := io.ReadAll(fh) + etlSpec, err := cos.ReadAll(fh) fh.Close() if err != nil { return err diff --git a/cmd/authn/aisreq.go b/cmd/authn/aisreq.go index 6b8b8582aa..b62c39a856 100644 --- a/cmd/authn/aisreq.go +++ b/cmd/authn/aisreq.go @@ -7,7 +7,6 @@ package main import ( "bytes" "fmt" - "io" "net/http" "sync" "time" @@ -125,7 +124,7 @@ func (m *mgr) call(method, proxyURL, path string, injson []byte, tag string) err resp, err := client.Do(req) if resp != nil { if resp.Body != nil { - msg, _ = io.ReadAll(resp.Body) + msg, _ = cos.ReadAll(resp.Body) resp.Body.Close() } } diff --git a/cmn/cos/io.go b/cmn/cos/io.go index d0892d2752..5327eff050 100644 --- a/cmn/cos/io.go +++ b/cmn/cos/io.go @@ -5,20 +5,13 @@ package cos import ( - "bufio" "bytes" - cryptorand "crypto/rand" - "errors" - "fmt" "io" "math" "os" - "os/user" "path/filepath" - "strconv" "github.com/NVIDIA/aistore/cmn/debug" - "github.com/NVIDIA/aistore/cmn/nlog" ) // POSIX permissions @@ -180,13 +173,6 @@ var ( _ ReadOpenCloser = (*ByteHandle)(nil) ) -// including "unexpecting EOF" to accommodate unsized streaming and -// early termination of the other side (prior to sending the first byte) -func IsEOF(err error) bool { - return err == io.EOF || err == io.ErrUnexpectedEOF || - errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) -} - /////////////// // nopReader // /////////////// @@ -448,326 +434,3 @@ func (w *Buffer) WriteTo2(dst io.Writer) (err error) { _, err = w.b.WriteTo(dst) return err } - -/////////////////////// -// misc file and dir // -/////////////////////// - -// ExpandPath replaces common abbreviations in file path (eg. `~` with absolute -// path to the current user home directory) and cleans the path. -func ExpandPath(path string) string { - if path == "" || path[0] != '~' { - return filepath.Clean(path) - } - if len(path) > 1 && path[1] != '/' { - return filepath.Clean(path) - } - - currentUser, err := user.Current() - if err != nil { - return filepath.Clean(path) - } - return filepath.Clean(filepath.Join(currentUser.HomeDir, path[1:])) -} - -// CreateDir creates directory if does not exist. -// If the directory already exists returns nil. -func CreateDir(dir string) error { - return os.MkdirAll(dir, configDirMode) -} - -// CreateFile creates a new write-only (O_WRONLY) file with default cos.PermRWR permissions. -// NOTE: if the file pathname doesn't exist it'll be created. -// NOTE: if the file already exists it'll be also silently truncated. -func CreateFile(fqn string) (*os.File, error) { - if err := CreateDir(filepath.Dir(fqn)); err != nil { - return nil, err - } - return os.OpenFile(fqn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, PermRWR) -} - -// (creates destination directory if doesn't exist) -func Rename(src, dst string) (err error) { - err = os.Rename(src, dst) - if err == nil { - return nil - } - if !os.IsNotExist(err) { - if os.IsExist(err) { - if finfo, errN := os.Stat(dst); errN == nil && finfo.IsDir() { - // [design tradeoff] keeping objects under (e.g.) their respective sha256 - // would eliminate this one, in part - return fmt.Errorf("destination %q is a (virtual) directory", dst) - } - } - return err - } - // create and retry (slow path) - err = CreateDir(filepath.Dir(dst)) - if err == nil { - err = os.Rename(src, dst) - } - return err -} - -// RemoveFile removes path; returns nil upon success or if the path does not exist. -func RemoveFile(path string) (err error) { - err = os.Remove(path) - if os.IsNotExist(err) { - err = nil - } - return -} - -// and computes checksum if requested -func CopyFile(src, dst string, buf []byte, cksumType string) (written int64, cksum *CksumHash, err error) { - var srcFile, dstFile *os.File - if srcFile, err = os.Open(src); err != nil { - return - } - if dstFile, err = CreateFile(dst); err != nil { - nlog.Errorln("Failed to create", dst+":", err) - Close(srcFile) - return - } - written, cksum, err = CopyAndChecksum(dstFile, srcFile, buf, cksumType) - Close(srcFile) - defer func() { - if err == nil { - return - } - if nestedErr := RemoveFile(dst); nestedErr != nil { - nlog.Errorf("Nested (%v): failed to remove %s, err: %v", err, dst, nestedErr) - } - }() - if err != nil { - nlog.Errorln("Failed to copy", src, "=>", dst+":", err) - Close(dstFile) - return - } - if err = FlushClose(dstFile); err != nil { - nlog.Errorln("Failed to flush and close", dst+":", err) - } - return -} - -func SaveReaderSafe(tmpfqn, fqn string, reader io.Reader, buf []byte, cksumType string, size int64) (cksum *CksumHash, - err error) { - if cksum, err = SaveReader(tmpfqn, reader, buf, cksumType, size); err != nil { - return - } - if err = Rename(tmpfqn, fqn); err != nil { - os.Remove(tmpfqn) - } - return -} - -// Saves the reader directly to `fqn`, checksums if requested -func SaveReader(fqn string, reader io.Reader, buf []byte, cksumType string, size int64) (cksum *CksumHash, err error) { - var ( - written int64 - file, erc = CreateFile(fqn) - writer = WriterOnly{file} // Hiding `ReadFrom` for `*os.File` introduced in Go1.15. - ) - if erc != nil { - return nil, erc - } - defer func() { - if err != nil { - os.Remove(fqn) - } - }() - - if size >= 0 { - reader = io.LimitReader(reader, size) - } - written, cksum, err = CopyAndChecksum(writer, reader, buf, cksumType) - erc = file.Close() - - if err != nil { - err = fmt.Errorf("failed to save to %q: %w", fqn, err) - return - } - if size >= 0 && written != size { - err = fmt.Errorf("wrong size when saving to %q: expected %d, got %d", fqn, size, written) - return - } - if erc != nil { - err = fmt.Errorf("failed to close %q: %w", fqn, erc) - return - } - return -} - -// a slightly modified excerpt from https://github.com/golang/go/blob/master/src/io/io.go#L407 -// - regular streaming copy with `io.WriteTo` and `io.ReaderFrom` not checked and not used -// - buffer _must_ be provided -// - see also: WriterOnly comment (above) -func CopyBuffer(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) { - for { - nr, er := src.Read(buf) - if nr > 0 { - nw, ew := dst.Write(buf[0:nr]) - if ew != nil { - if nw > 0 && nw <= nr { - written += int64(nw) - } - err = ew - break - } - if nw < 0 || nw > nr { - err = errors.New("cos.CopyBuffer: invalid write") - break - } - written += int64(nw) - if nr != nw { - err = io.ErrShortWrite - break - } - } - if er != nil { - if er != io.EOF { - err = er - } - break - } - } - return written, err -} - -// Read only the first line of a file. -// Do not use for big files: it reads all the content and then extracts the first -// line. Use for files that may contains a few lines with trailing EOL -func ReadOneLine(filename string) (string, error) { - var line string - err := ReadLines(filename, func(l string) error { - line = l - return io.EOF - }) - return line, err -} - -// Read only the first line of a file and return it as uint64 -// Do not use for big files: it reads all the content and then extracts the first -// line. Use for files that may contains a few lines with trailing EOL -func ReadOneUint64(filename string) (uint64, error) { - line, err := ReadOneLine(filename) - if err != nil { - return 0, err - } - val, err := strconv.ParseUint(line, 10, 64) - return val, err -} - -// Read only the first line of a file and return it as int64 -// Do not use for big files: it reads all the content and then extracts the first -// line. Use for files that may contains a few lines with trailing EOL -func ReadOneInt64(filename string) (int64, error) { - line, err := ReadOneLine(filename) - if err != nil { - return 0, err - } - val, err := strconv.ParseInt(line, 10, 64) - return val, err -} - -// Read a file line by line and call a callback for each line until the file -// ends or a callback returns io.EOF -func ReadLines(filename string, cb func(string) error) error { - b, err := os.ReadFile(filename) - if err != nil { - return err - } - - lineReader := bufio.NewReader(bytes.NewBuffer(b)) - for { - line, _, err := lineReader.ReadLine() - if err != nil { - if err == io.EOF { - err = nil - } - return err - } - - if err := cb(string(line)); err != nil { - if err != io.EOF { - return err - } - break - } - } - return nil -} - -// CopyAndChecksum reads from `r` and writes to `w`; returns num bytes copied and checksum, or error -func CopyAndChecksum(w io.Writer, r io.Reader, buf []byte, cksumType string) (n int64, cksum *CksumHash, err error) { - debug.Assert(w != io.Discard || buf == nil) // io.Discard is io.ReaderFrom - - if cksumType == ChecksumNone || cksumType == "" { - n, err = io.CopyBuffer(w, r, buf) - return n, nil, err - } - - cksum = NewCksumHash(cksumType) - var mw io.Writer = cksum.H - if w != io.Discard { - mw = NewWriterMulti(cksum.H, w) - } - n, err = io.CopyBuffer(mw, r, buf) - cksum.Finalize() - return n, cksum, err -} - -// ChecksumBytes computes checksum of given bytes using additional buffer. -func ChecksumBytes(b []byte, cksumType string) (cksum *Cksum, err error) { - _, hash, err := CopyAndChecksum(io.Discard, bytes.NewReader(b), nil, cksumType) - if err != nil { - return nil, err - } - return &hash.Cksum, nil -} - -// DrainReader reads and discards all the data from a reader. -// No need for `io.CopyBuffer` as `io.Discard` has efficient `io.ReaderFrom` implementation. -func DrainReader(r io.Reader) { - _, err := io.Copy(io.Discard, r) - if err == nil || IsEOF(err) { - return - } - debug.AssertNoErr(err) -} - -// FloodWriter writes `n` random bytes to provided writer. -func FloodWriter(w io.Writer, n int64) error { - _, err := io.CopyN(w, cryptorand.Reader, n) - return err -} - -func Close(closer io.Closer) { - err := closer.Close() - debug.AssertNoErr(err) -} - -func FlushClose(file *os.File) (err error) { - err = fflush(file) - debug.AssertNoErr(err) - err = file.Close() - debug.AssertNoErr(err) - return -} - -// NOTE: -// - file.Close() is implementation dependent as far as flushing dirty buffers; -// - journaling filesystems, such as xfs, generally provide better guarantees but, again, not 100% -// - see discussion at https://lwn.net/Articles/788938; -// - going forward, some sort of `rename_barrier()` would be a much better alternative -// - doesn't work in testing environment - currently disabled, see #1141 and comments - -const fsyncDisabled = true - -func fflush(file *os.File) (err error) { - if fsyncDisabled { - return - } - return file.Sync() -} diff --git a/cmn/cos/ioutils.go b/cmn/cos/ioutils.go new file mode 100644 index 0000000000..38e91de50e --- /dev/null +++ b/cmn/cos/ioutils.go @@ -0,0 +1,378 @@ +// Package cos provides common low-level types and utilities for all aistore projects +/* + * Copyright (c) 2018-2024, NVIDIA CORPORATION. All rights reserved. + */ +package cos + +import ( + "bufio" + "bytes" + cryptorand "crypto/rand" + "errors" + "fmt" + "io" + "os" + "os/user" + "path/filepath" + "strconv" + + "github.com/NVIDIA/aistore/cmn/debug" + "github.com/NVIDIA/aistore/cmn/nlog" +) + +// instead of os.ReadAll +func ReadAllN(r io.Reader, size int64) (b []byte, err error) { + switch size { + case 0: + case ContentLengthUnknown: + buf := bytes.NewBuffer(nil) + _, err = io.Copy(buf, r) + b = buf.Bytes() + default: + buf := bytes.NewBuffer(make([]byte, 0, size)) + _, err = io.Copy(buf, r) + b = buf.Bytes() + } + debug.Func(func() { + n, _ := io.Copy(io.Discard, r) + debug.Assert(n == 0) + }) + return b, err +} + +func ReadAll(r io.Reader) ([]byte, error) { + buf := &bytes.Buffer{} + _, err := io.Copy(buf, r) + + // DEBUG + // b := buf.Bytes() + // nlog.ErrorDepth(1, ">>>>>> len =", len(b)) + + return buf.Bytes(), err +} + +// including "unexpecting EOF" to accommodate unsized streaming and +// early termination of the other side (prior to sending the first byte) +func IsEOF(err error) bool { + return err == io.EOF || err == io.ErrUnexpectedEOF || + errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) +} + +// ExpandPath replaces common abbreviations in file path (eg. `~` with absolute +// path to the current user home directory) and cleans the path. +func ExpandPath(path string) string { + if path == "" || path[0] != '~' { + return filepath.Clean(path) + } + if len(path) > 1 && path[1] != '/' { + return filepath.Clean(path) + } + + currentUser, err := user.Current() + if err != nil { + return filepath.Clean(path) + } + return filepath.Clean(filepath.Join(currentUser.HomeDir, path[1:])) +} + +// CreateDir creates directory if does not exist. +// If the directory already exists returns nil. +func CreateDir(dir string) error { + return os.MkdirAll(dir, configDirMode) +} + +// CreateFile creates a new write-only (O_WRONLY) file with default cos.PermRWR permissions. +// NOTE: if the file pathname doesn't exist it'll be created. +// NOTE: if the file already exists it'll be also silently truncated. +func CreateFile(fqn string) (*os.File, error) { + if err := CreateDir(filepath.Dir(fqn)); err != nil { + return nil, err + } + return os.OpenFile(fqn, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, PermRWR) +} + +// (creates destination directory if doesn't exist) +func Rename(src, dst string) (err error) { + err = os.Rename(src, dst) + if err == nil { + return nil + } + if !os.IsNotExist(err) { + if os.IsExist(err) { + if finfo, errN := os.Stat(dst); errN == nil && finfo.IsDir() { + // [design tradeoff] keeping objects under (e.g.) their respective sha256 + // would eliminate this one, in part + return fmt.Errorf("destination %q is a (virtual) directory", dst) + } + } + return err + } + // create and retry (slow path) + err = CreateDir(filepath.Dir(dst)) + if err == nil { + err = os.Rename(src, dst) + } + return err +} + +// RemoveFile removes path; returns nil upon success or if the path does not exist. +func RemoveFile(path string) (err error) { + err = os.Remove(path) + if os.IsNotExist(err) { + err = nil + } + return +} + +// and computes checksum if requested +func CopyFile(src, dst string, buf []byte, cksumType string) (written int64, cksum *CksumHash, err error) { + var srcFile, dstFile *os.File + if srcFile, err = os.Open(src); err != nil { + return + } + if dstFile, err = CreateFile(dst); err != nil { + nlog.Errorln("Failed to create", dst+":", err) + Close(srcFile) + return + } + written, cksum, err = CopyAndChecksum(dstFile, srcFile, buf, cksumType) + Close(srcFile) + defer func() { + if err == nil { + return + } + if nestedErr := RemoveFile(dst); nestedErr != nil { + nlog.Errorf("Nested (%v): failed to remove %s, err: %v", err, dst, nestedErr) + } + }() + if err != nil { + nlog.Errorln("Failed to copy", src, "=>", dst+":", err) + Close(dstFile) + return + } + if err = FlushClose(dstFile); err != nil { + nlog.Errorln("Failed to flush and close", dst+":", err) + } + return +} + +func SaveReaderSafe(tmpfqn, fqn string, reader io.Reader, buf []byte, cksumType string, size int64) (cksum *CksumHash, + err error) { + if cksum, err = SaveReader(tmpfqn, reader, buf, cksumType, size); err != nil { + return + } + if err = Rename(tmpfqn, fqn); err != nil { + os.Remove(tmpfqn) + } + return +} + +// Saves the reader directly to `fqn`, checksums if requested +func SaveReader(fqn string, reader io.Reader, buf []byte, cksumType string, size int64) (cksum *CksumHash, err error) { + var ( + written int64 + file, erc = CreateFile(fqn) + writer = WriterOnly{file} // Hiding `ReadFrom` for `*os.File` introduced in Go1.15. + ) + if erc != nil { + return nil, erc + } + defer func() { + if err != nil { + os.Remove(fqn) + } + }() + + if size >= 0 { + reader = io.LimitReader(reader, size) + } + written, cksum, err = CopyAndChecksum(writer, reader, buf, cksumType) + erc = file.Close() + + if err != nil { + err = fmt.Errorf("failed to save to %q: %w", fqn, err) + return + } + if size >= 0 && written != size { + err = fmt.Errorf("wrong size when saving to %q: expected %d, got %d", fqn, size, written) + return + } + if erc != nil { + err = fmt.Errorf("failed to close %q: %w", fqn, erc) + return + } + return +} + +// a slightly modified excerpt from https://github.com/golang/go/blob/master/src/io/io.go#L407 +// - regular streaming copy with `io.WriteTo` and `io.ReaderFrom` not checked and not used +// - buffer _must_ be provided +// - see also: WriterOnly comment (above) +func CopyBuffer(dst io.Writer, src io.Reader, buf []byte) (written int64, err error) { + for { + nr, er := src.Read(buf) + if nr > 0 { + nw, ew := dst.Write(buf[0:nr]) + if ew != nil { + if nw > 0 && nw <= nr { + written += int64(nw) + } + err = ew + break + } + if nw < 0 || nw > nr { + err = errors.New("cos.CopyBuffer: invalid write") + break + } + written += int64(nw) + if nr != nw { + err = io.ErrShortWrite + break + } + } + if er != nil { + if er != io.EOF { + err = er + } + break + } + } + return written, err +} + +// Read only the first line of a file. +// Do not use for big files: it reads all the content and then extracts the first +// line. Use for files that may contains a few lines with trailing EOL +func ReadOneLine(filename string) (string, error) { + var line string + err := ReadLines(filename, func(l string) error { + line = l + return io.EOF + }) + return line, err +} + +// Read only the first line of a file and return it as uint64 +// Do not use for big files: it reads all the content and then extracts the first +// line. Use for files that may contains a few lines with trailing EOL +func ReadOneUint64(filename string) (uint64, error) { + line, err := ReadOneLine(filename) + if err != nil { + return 0, err + } + val, err := strconv.ParseUint(line, 10, 64) + return val, err +} + +// Read only the first line of a file and return it as int64 +// Do not use for big files: it reads all the content and then extracts the first +// line. Use for files that may contains a few lines with trailing EOL +func ReadOneInt64(filename string) (int64, error) { + line, err := ReadOneLine(filename) + if err != nil { + return 0, err + } + val, err := strconv.ParseInt(line, 10, 64) + return val, err +} + +// Read a file line by line and call a callback for each line until the file +// ends or a callback returns io.EOF +func ReadLines(filename string, cb func(string) error) error { + b, err := os.ReadFile(filename) + if err != nil { + return err + } + + lineReader := bufio.NewReader(bytes.NewBuffer(b)) + for { + line, _, err := lineReader.ReadLine() + if err != nil { + if err == io.EOF { + err = nil + } + return err + } + + if err := cb(string(line)); err != nil { + if err != io.EOF { + return err + } + break + } + } + return nil +} + +// CopyAndChecksum reads from `r` and writes to `w`; returns num bytes copied and checksum, or error +func CopyAndChecksum(w io.Writer, r io.Reader, buf []byte, cksumType string) (n int64, cksum *CksumHash, err error) { + debug.Assert(w != io.Discard || buf == nil) // io.Discard is io.ReaderFrom + + if cksumType == ChecksumNone || cksumType == "" { + n, err = io.CopyBuffer(w, r, buf) + return n, nil, err + } + + cksum = NewCksumHash(cksumType) + var mw io.Writer = cksum.H + if w != io.Discard { + mw = NewWriterMulti(cksum.H, w) + } + n, err = io.CopyBuffer(mw, r, buf) + cksum.Finalize() + return n, cksum, err +} + +// ChecksumBytes computes checksum of given bytes using additional buffer. +func ChecksumBytes(b []byte, cksumType string) (cksum *Cksum, err error) { + _, hash, err := CopyAndChecksum(io.Discard, bytes.NewReader(b), nil, cksumType) + if err != nil { + return nil, err + } + return &hash.Cksum, nil +} + +// DrainReader reads and discards all the data from a reader. +// No need for `io.CopyBuffer` as `io.Discard` has efficient `io.ReaderFrom` implementation. +func DrainReader(r io.Reader) { + _, err := io.Copy(io.Discard, r) + if err == nil || IsEOF(err) { + return + } + debug.AssertNoErr(err) +} + +// FloodWriter writes `n` random bytes to provided writer. +func FloodWriter(w io.Writer, n int64) error { + _, err := io.CopyN(w, cryptorand.Reader, n) + return err +} + +func Close(closer io.Closer) { + err := closer.Close() + debug.AssertNoErr(err) +} + +func FlushClose(file *os.File) (err error) { + err = fflush(file) + debug.AssertNoErr(err) + err = file.Close() + debug.AssertNoErr(err) + return +} + +// NOTE: +// - file.Close() is implementation dependent as far as flushing dirty buffers; +// - journaling filesystems, such as xfs, generally provide better guarantees but, again, not 100% +// - see discussion at https://lwn.net/Articles/788938; +// - going forward, some sort of `rename_barrier()` would be a much better alternative +// - doesn't work in testing environment - currently disabled, see #1141 and comments + +const fsyncDisabled = true + +func fflush(file *os.File) (err error) { + if fsyncDisabled { + return + } + return file.Sync() +} diff --git a/cmn/http.go b/cmn/http.go index 90ae40094c..2ae38656b8 100644 --- a/cmn/http.go +++ b/cmn/http.go @@ -135,7 +135,7 @@ func ParseURL(path string, itemsPresent []string, itemsAfter int, splitAfter boo func ReadBytes(r *http.Request) (b []byte, err error) { var e error - b, e = io.ReadAll(r.Body) + b, e = cos.ReadAllN(r.Body, r.ContentLength) if e != nil { err = fmt.Errorf("failed to read %s request, err: %v", r.Method, e) if e == io.EOF { diff --git a/cmn/jsp/io.go b/cmn/jsp/io.go index 52f463a031..69671dd410 100644 --- a/cmn/jsp/io.go +++ b/cmn/jsp/io.go @@ -153,7 +153,7 @@ func Decode(reader io.ReadCloser, v any, opts Options, tag string) (checksum *co // We have already parsed `v` but there is still the possibility that `\n` remains // not read. Therefore, we read it to include it into the final checksum. var b []byte - if b, err = io.ReadAll(r); err != nil { + if b, err = cos.ReadAll(r); err != nil { return } // To be sure that this is exactly the case... diff --git a/cmn/k8s/client.go b/cmn/k8s/client.go index 3d6840027f..b8668cfc40 100644 --- a/cmn/k8s/client.go +++ b/cmn/k8s/client.go @@ -6,11 +6,11 @@ package k8s import ( "context" - "io" "os" "strings" "github.com/NVIDIA/aistore/api/env" + "github.com/NVIDIA/aistore/cmn/cos" "github.com/NVIDIA/aistore/cmn/debug" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -191,7 +191,7 @@ func (c *defaultClient) Logs(podName string) ([]byte, error) { return nil, err } defer logStream.Close() - return io.ReadAll(logStream) + return cos.ReadAll(logStream) } func (c *defaultClient) CheckMetricsAvailability() error { diff --git a/ec/manager.go b/ec/manager.go index 31db9fbba0..036b633c2f 100644 --- a/ec/manager.go +++ b/ec/manager.go @@ -161,7 +161,7 @@ func (mgr *Manager) recvRequest(hdr *transport.ObjHdr, objReader io.Reader, err // command requests should not have a body, but if it has, // the body must be drained to avoid errors if hdr.ObjAttrs.Size != 0 { - if _, err := io.ReadAll(objReader); err != nil { + if _, err := cos.ReadAll(objReader); err != nil { nlog.Errorf("failed to read request body: %v", err) return err } diff --git a/ec/metafile.go b/ec/metafile.go index 4fcdd1effc..9464cefa76 100644 --- a/ec/metafile.go +++ b/ec/metafile.go @@ -62,7 +62,7 @@ func LoadMetadata(fqn string) (*Metadata, error) { } func MetaFromReader(reader io.Reader) (*Metadata, error) { - b, err := io.ReadAll(reader) + b, err := cos.ReadAll(reader) if err != nil { return nil, err } diff --git a/ext/dsort/bcast.go b/ext/dsort/bcast.go index 6b3689c2f3..12e76e8596 100644 --- a/ext/dsort/bcast.go +++ b/ext/dsort/bcast.go @@ -5,7 +5,6 @@ package dsort import ( - "io" "net/http" "net/url" "sync" @@ -67,7 +66,7 @@ func call(reqArgs *cmn.HreqArgs) response { if err != nil { return response{err: err, statusCode: http.StatusInternalServerError} } - out, err := io.ReadAll(resp.Body) + out, err := cos.ReadAll(resp.Body) cos.Close(resp.Body) return response{res: out, err: err, statusCode: resp.StatusCode} } diff --git a/ext/dsort/dsort.go b/ext/dsort/dsort.go index 4187976942..6ed135d0ef 100644 --- a/ext/dsort/dsort.go +++ b/ext/dsort/dsort.go @@ -858,7 +858,7 @@ func (m *Manager) _do(reqArgs *cmn.HreqArgs, tsi *meta.Snode, act string) error } if resp.StatusCode != http.StatusOK { var b []byte - b, err = io.ReadAll(resp.Body) + b, err = cos.ReadAll(resp.Body) if err == nil { err = fmt.Errorf("%s: %s failed to %s: %s", core.T, m.ManagerUUID, act, strings.TrimSuffix(string(b), "\n")) } else { diff --git a/ext/dsort/handler.go b/ext/dsort/handler.go index 59e440d1c3..1ce45cb240 100644 --- a/ext/dsort/handler.go +++ b/ext/dsort/handler.go @@ -6,7 +6,6 @@ package dsort import ( "fmt" - "io" "net/http" "net/url" "regexp" @@ -442,7 +441,7 @@ func tinitHandler(w http.ResponseWriter, r *http.Request) { } var ( pars *parsedReqSpec - b, err = io.ReadAll(r.Body) + b, err = cos.ReadAll(r.Body) ) if err != nil { cmn.WriteErr(w, r, fmt.Errorf("[dsort]: failed to receive request: %w", err)) diff --git a/ext/dsort/shard/key.go b/ext/dsort/shard/key.go index 02dd65886a..184ac9cba2 100644 --- a/ext/dsort/shard/key.go +++ b/ext/dsort/shard/key.go @@ -109,7 +109,7 @@ func (ke *contentKeyExtractor) ExtractKey(ske *SingleKeyExtractor) (any, error) if ske == nil { return nil, nil } - b, err := io.ReadAll(ske.buf) + b, err := cos.ReadAll(ske.buf) ske.buf = nil if err != nil { return nil, err diff --git a/ext/etl/comm_internal_test.go b/ext/etl/comm_internal_test.go index 2f4cefdf11..df06cb8657 100644 --- a/ext/etl/comm_internal_test.go +++ b/ext/etl/comm_internal_test.go @@ -7,7 +7,6 @@ package etl import ( cryptorand "crypto/rand" "fmt" - "io" "net/http" "net/http/httptest" "os" @@ -134,7 +133,7 @@ var _ = Describe("CommunicatorTest", func() { Expect(err).NotTo(HaveOccurred()) defer resp.Body.Close() - b, err := io.ReadAll(resp.Body) + b, err := cos.ReadAll(resp.Body) Expect(err).NotTo(HaveOccurred()) Expect(len(b)).To(Equal(len(transformData))) Expect(b).To(Equal(transformData)) diff --git a/memsys/c_test.go b/memsys/c_test.go index 935a7b2f65..b7c61e4e31 100644 --- a/memsys/c_test.go +++ b/memsys/c_test.go @@ -14,6 +14,7 @@ import ( "sync" "testing" + "github.com/NVIDIA/aistore/cmn/cos" "github.com/NVIDIA/aistore/memsys" "github.com/NVIDIA/aistore/tools/tassert" "github.com/NVIDIA/aistore/tools/tlog" @@ -59,7 +60,7 @@ func TestSGLStressN(t *testing.T) { // read SGL from destination and compare with the original var bufW []byte - bufW, err = io.ReadAll(memsys.NewReader(sglW)) + bufW, err = cos.ReadAll(memsys.NewReader(sglW)) tassert.CheckFatal(t, err) for j := range objsize { if bufW[j] != bufR[j] { diff --git a/memsys/iosgl.go b/memsys/iosgl.go index b32a363e52..ac62991896 100644 --- a/memsys/iosgl.go +++ b/memsys/iosgl.go @@ -273,7 +273,7 @@ func (z *SGL) _readAt(b []byte, roffin int64) (n int, roff int64, err error) { } // ReadAll is a strictly _convenience_ method as it performs heap allocation. -// Still, it's an optimized alternative to the generic io.ReadAll which +// Still, it's an optimized alternative to the generic cos.ReadAll which // normally returns err == nil (and not io.EOF) upon successful reading until EOF. // ReadAll always returns err == nil. func (z *SGL) ReadAll() (b []byte) { diff --git a/tools/tetl/etl.go b/tools/tetl/etl.go index c8fdf529cb..eb4089e910 100644 --- a/tools/tetl/etl.go +++ b/tools/tetl/etl.go @@ -7,7 +7,6 @@ package tetl import ( "bytes" "fmt" - "io" "net/http" "os" "strings" @@ -92,7 +91,7 @@ func GetTransformYaml(etlName string) ([]byte, error) { } defer resp.Body.Close() - b, err := io.ReadAll(resp.Body) + b, err := cos.ReadAll(resp.Body) if err != nil { return nil, err } diff --git a/transport/obj_test.go b/transport/obj_test.go index 886b941cfa..1c4d5d6cae 100644 --- a/transport/obj_test.go +++ b/transport/obj_test.go @@ -103,7 +103,7 @@ func TestMain(t *testing.M) { func Example_headers() { f := func(_ http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) + body, err := cos.ReadAll(r.Body) if err != nil { panic(err) } @@ -194,7 +194,7 @@ func sendText(stream *transport.Stream, txt1, txt2 string) { func Example_obj() { receive := func(hdr *transport.ObjHdr, objReader io.Reader, err error) error { cos.Assert(err == nil) - object, err := io.ReadAll(objReader) + object, err := cos.ReadAll(objReader) if err != nil { panic(err) }