diff --git a/sdk/storage/azdatalake/file/chunkwriting.go b/sdk/storage/azdatalake/file/chunkwriting.go new file mode 100644 index 000000000000..289b042d1646 --- /dev/null +++ b/sdk/storage/azdatalake/file/chunkwriting.go @@ -0,0 +1,193 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package file + +import ( + "bytes" + "context" + "errors" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" + "io" + "sync" +) + +// chunkWriter provides methods to upload chunks that represent a file to a server. +// This allows us to provide a local implementation that fakes the server for hermetic testing. +type chunkWriter interface { + AppendData(context.Context, int64, io.ReadSeekCloser, *AppendDataOptions) (AppendDataResponse, error) + FlushData(context.Context, int64, *FlushDataOptions) (FlushDataResponse, error) +} + +// bufferManager provides an abstraction for the management of buffers. +// this is mostly for testing purposes, but does allow for different implementations without changing the algorithm. +type bufferManager[T ~[]byte] interface { + // Acquire returns the channel that contains the pool of buffers. + Acquire() <-chan T + + // Release releases the buffer back to the pool for reuse/cleanup. + Release(T) + + // Grow grows the number of buffers, up to the predefined max. + // It returns the total number of buffers or an error. + // No error is returned if the number of buffers has reached max. + // This is called only from the reading goroutine. + Grow() (int, error) + + // Free cleans up all buffers. + Free() +} + +// copyFromReader copies a source io.Reader to file storage using concurrent uploads. +func copyFromReader[T ~[]byte](ctx context.Context, src io.Reader, dst chunkWriter, options UploadStreamOptions, getBufferManager func(maxBuffers int, bufferSize int64) bufferManager[T]) error { + options.setDefaults() + actualSize := int64(0) + wg := sync.WaitGroup{} // Used to know when all outgoing chunks have finished processing + errCh := make(chan error, 1) // contains the first error encountered during processing + var err error + + buffers := getBufferManager(int(options.Concurrency), options.ChunkSize) + defer buffers.Free() + + // this controls the lifetime of the uploading goroutines. + // if an error is encountered, cancel() is called which will terminate all uploads. + // NOTE: the ordering is important here. cancel MUST execute before + // cleaning up the buffers so that any uploading goroutines exit first, + // releasing their buffers back to the pool for cleanup. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // This goroutine grabs a buffer, reads from the stream into the buffer, + // then creates a goroutine to upload/stage the chunk. + for chunkNum := uint32(0); true; chunkNum++ { + var buffer T + select { + case buffer = <-buffers.Acquire(): + // got a buffer + default: + // no buffer available; allocate a new buffer if possible + if _, err := buffers.Grow(); err != nil { + return err + } + + // either grab the newly allocated buffer or wait for one to become available + buffer = <-buffers.Acquire() + } + + var n int + n, err = io.ReadFull(src, buffer) + + if n > 0 { + // some data was read, upload it + wg.Add(1) // We're posting a buffer to be sent + + // NOTE: we must pass chunkNum as an arg to our goroutine else + // it's captured by reference and can change underneath us! 
+			go func(chunkNum uint32) {
+				// Upload the outgoing chunk, matching the number of bytes read
+				offset := int64(chunkNum) * options.ChunkSize
+				appendDataOpts := options.getAppendDataOptions()
+				_, err := dst.AppendData(ctx, offset, streaming.NopCloser(bytes.NewReader(buffer[:n])), appendDataOpts)
+				if err != nil {
+					select {
+					case errCh <- err:
+						// error was set
+					default:
+						// some other error is already set
+					}
+					cancel()
+				}
+				buffers.Release(buffer) // The goroutine reading from the stream can reuse this buffer now
+
+				// signal that the chunk has been staged.
+				// we MUST do this after attempting to write to errCh
+				// to avoid it racing with the reading goroutine.
+				wg.Done()
+			}(chunkNum)
+
+			// accumulate the running total here, on the reading goroutine, rather
+			// than inside the upload goroutine; updating it concurrently from
+			// multiple goroutines would be a data race.
+			actualSize += int64(n)
+		} else {
+			// nothing was read so the buffer is empty, send it back for reuse/clean-up.
+			buffers.Release(buffer)
+		}
+
+		if err != nil { // The reader is done, no more outgoing buffers
+			if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
+				// these are expected errors, we don't surface those
+				err = nil
+			} else {
+				// some other error happened, terminate any outstanding uploads
+				cancel()
+			}
+			break
+		}
+	}
+
+	wg.Wait() // Wait for all outgoing chunks to complete
+
+	if err != nil {
+		// there was an error reading from src, favor this error over any error during staging
+		return err
+	}
+
+	select {
+	case err = <-errCh:
+		// there was an error during staging
+		return err
+	default:
+		// no error was encountered
+	}
+
+	// All chunks uploaded; flush to commit them and return any error
+	flushOpts := options.getFlushDataOptions()
+	_, err = dst.FlushData(ctx, actualSize, flushOpts)
+	return err
+}
+
+// mmbPool implements the bufferManager interface.
+// It uses anonymous memory mapped files for buffers.
+// Don't use this type directly; use newMMBPool() instead.
+type mmbPool struct { + buffers chan mmb + count int + max int + size int64 +} + +func newMMBPool(maxBuffers int, bufferSize int64) bufferManager[mmb] { + return &mmbPool{ + buffers: make(chan mmb, maxBuffers), + max: maxBuffers, + size: bufferSize, + } +} + +func (pool *mmbPool) Acquire() <-chan mmb { + return pool.buffers +} + +func (pool *mmbPool) Grow() (int, error) { + if pool.count < pool.max { + buffer, err := newMMB(pool.size) + if err != nil { + return 0, err + } + pool.buffers <- buffer + pool.count++ + } + return pool.count, nil +} + +func (pool *mmbPool) Release(buffer mmb) { + pool.buffers <- buffer +} + +func (pool *mmbPool) Free() { + for i := 0; i < pool.count; i++ { + buffer := <-pool.buffers + buffer.delete() + } + pool.count = 0 +} diff --git a/sdk/storage/azdatalake/file/client.go b/sdk/storage/azdatalake/file/client.go index 5909aa475aea..0fae799fb965 100644 --- a/sdk/storage/azdatalake/file/client.go +++ b/sdk/storage/azdatalake/file/client.go @@ -7,11 +7,16 @@ package file import ( + "bytes" "context" + "errors" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/datalakeerror" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/base" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/exported" @@ -19,9 +24,12 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/path" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/sas" + "io" "net/http" "net/url" + "os" "strings" + "sync" "time" ) @@ -254,26 +262,6 @@ func (f *Client) SetExpiry(ctx context.Context, expiryType SetExpiryType, o *Set return resp, err } -//// Upload uploads data to a file. -//func (f *Client) Upload(ctx context.Context) { -// -//} -// -//// Append appends data to a file. -//func (f *Client) Append(ctx context.Context) { -// -//} -// -//// Flush flushes previous uploaded data to a file. -//func (f *Client) Flush(ctx context.Context) { -// -//} -// -//// Download downloads data from a file. -//func (f *Client) Download(ctx context.Context) { -// -//} - // SetAccessControl sets the owner, owning group, and permissions for a file or directory (dfs1). 
 func (f *Client) SetAccessControl(ctx context.Context, options *SetAccessControlOptions) (SetAccessControlResponse, error) {
 	opts, lac, mac, err := path.FormatSetAccessControlOptions(options)
@@ -360,3 +348,159 @@ func (f *Client) GetSASURL(permissions sas.FilePermissions, expiry time.Time, o
 
 	return endpoint, nil
 }
+
+// AppendData appends data to the file at the given offset. Call FlushData to
+// commit the appended data.
+func (f *Client) AppendData(ctx context.Context, offset int64, body io.ReadSeekCloser, options *AppendDataOptions) (AppendDataResponse, error) {
+	appendDataOptions, leaseAccessConditions, httpsHeaders, cpkInfo, err := options.format(offset, body)
+	if err != nil {
+		return AppendDataResponse{}, err
+	}
+
+	resp, err := f.generatedFileClientWithDFS().AppendData(ctx, body, appendDataOptions, httpsHeaders, leaseAccessConditions, cpkInfo)
+	return resp, exported.ConvertToDFSError(err)
+}
+
+// FlushData commits previously appended data to the file; offset is the total
+// length, in bytes, of the data to commit.
+func (f *Client) FlushData(ctx context.Context, offset int64, options *FlushDataOptions) (FlushDataResponse, error) {
+	flushDataOpts, modifiedAccessConditions, leaseAccessConditions, httpHeaderOpts, cpkInfoOpts, err := options.format(offset)
+	if err != nil {
+		return FlushDataResponse{}, err
+	}
+
+	resp, err := f.generatedFileClientWithDFS().FlushData(ctx, flushDataOpts, httpHeaderOpts, leaseAccessConditions, modifiedAccessConditions, cpkInfoOpts)
+	return resp, exported.ConvertToDFSError(err)
+}
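+
+// A minimal usage sketch for the two methods above (client construction and
+// option wiring are assumed to happen elsewhere; the helper name is
+// illustrative, not part of the API):
+//
+//	func appendThenFlush(ctx context.Context, f *Client, data []byte) error {
+//		// append the payload at offset 0...
+//		body := streaming.NopCloser(bytes.NewReader(data))
+//		if _, err := f.AppendData(ctx, 0, body, nil); err != nil {
+//			return err
+//		}
+//		// ...then commit it; the flush offset is the total byte count appended
+//		_, err := f.FlushData(ctx, int64(len(data)), nil)
+//		return err
+//	}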
+
+// Concurrent Upload Functions -----------------------------------------------------------------------------------------
+
+// uploadFromReader uploads data from an io.ReaderAt in chunks to an Azure file.
+func (f *Client) uploadFromReader(ctx context.Context, reader io.ReaderAt, actualSize int64, o *uploadFromReaderOptions) error {
+	if actualSize > MaxFileSize {
+		return errors.New("buffer is too large to upload to a file")
+	}
+	if o.ChunkSize == 0 {
+		o.ChunkSize = MaxUpdateRangeBytes
+	}
+
+	if log.Should(exported.EventUpload) {
+		urlParts, err := azdatalake.ParseURL(f.DFSURL())
+		if err == nil {
+			log.Writef(exported.EventUpload, "file name %s actual size %v chunk-size %v chunk-count %v",
+				urlParts.PathName, actualSize, o.ChunkSize, ((actualSize-1)/o.ChunkSize)+1)
+		}
+	}
+
+	progress := int64(0)
+	progressLock := &sync.Mutex{}
+
+	err := shared.DoBatchTransfer(ctx, &shared.BatchTransferOptions{
+		OperationName: "uploadFromReader",
+		TransferSize:  actualSize,
+		ChunkSize:     o.ChunkSize,
+		Concurrency:   o.Concurrency,
+		Operation: func(ctx context.Context, offset int64, chunkSize int64) error {
+			// This function is called once per chunk. It is passed the chunk's
+			// offset within the source and the chunk's size in bytes.
+			if chunkSize < o.ChunkSize {
+				// this is the last chunk. Its actual size might be less than
+				// the calculated size due to rounding up of the payload size
+				// to fit in a whole number of chunks.
+				chunkSize = actualSize - offset
+			}
+			var body io.ReadSeeker = io.NewSectionReader(reader, offset, chunkSize)
+			if o.Progress != nil {
+				chunkProgress := int64(0)
+				body = streaming.NewRequestProgress(streaming.NopCloser(body),
+					func(bytesTransferred int64) {
+						diff := bytesTransferred - chunkProgress
+						chunkProgress = bytesTransferred
+						progressLock.Lock() // 1 goroutine at a time gets progress report
+						progress += diff
+						o.Progress(progress)
+						progressLock.Unlock()
+					})
+			}
+
+			uploadRangeOptions := o.getAppendDataOptions()
+			_, err := f.AppendData(ctx, offset, streaming.NopCloser(body), uploadRangeOptions)
+			return exported.ConvertToDFSError(err)
+		},
+	})
+
+	if err != nil {
+		return exported.ConvertToDFSError(err)
+	}
+	// All appends were successful, call to flush
+	flushOpts := o.getFlushDataOptions()
+	_, err = f.FlushData(ctx, actualSize, flushOpts)
+	return exported.ConvertToDFSError(err)
+}
+
+// UploadBuffer uploads a buffer in chunks to an Azure file.
+func (f *Client) UploadBuffer(ctx context.Context, buffer []byte, options *UploadBufferOptions) error {
+	uploadOptions := uploadFromReaderOptions{}
+	if options != nil {
+		uploadOptions = *options
+	}
+	return exported.ConvertToDFSError(f.uploadFromReader(ctx, bytes.NewReader(buffer), int64(len(buffer)), &uploadOptions))
+}
+
+// UploadFile uploads a local file in chunks to an Azure file.
+func (f *Client) UploadFile(ctx context.Context, file *os.File, options *UploadFileOptions) error {
+	stat, err := file.Stat()
+	if err != nil {
+		return err
+	}
+	uploadOptions := uploadFromReaderOptions{}
+	if options != nil {
+		uploadOptions = *options
+	}
+	return exported.ConvertToDFSError(f.uploadFromReader(ctx, file, stat.Size(), &uploadOptions))
+}
+
+// UploadStream copies the data held in the io.Reader to the Azure file in chunks.
+// A Context deadline or cancellation will cause this to error.
+func (f *Client) UploadStream(ctx context.Context, body io.Reader, options *UploadStreamOptions) error {
+	if options == nil {
+		options = &UploadStreamOptions{}
+	}
+
+	err := copyFromReader(ctx, body, f, *options, newMMBPool)
+	return exported.ConvertToDFSError(err)
+}
+
+// DownloadStream reads a range of bytes from a file. The response also includes the file's properties and metadata.
+// For more information, see https://docs.microsoft.com/rest/api/storageservices/get-blob.
+func (f *Client) DownloadStream(ctx context.Context, o *DownloadStreamOptions) (DownloadStreamResponse, error) {
+	if o == nil {
+		o = &DownloadStreamOptions{}
+	}
+	opts := o.format()
+	resp, err := f.blobClient().DownloadStream(ctx, opts)
+	newResp := FormatDownloadStreamResponse(&resp)
+	fullResp := DownloadStreamResponse{
+		client:           f,
+		DownloadResponse: newResp,
+		getInfo:          httpGetterInfo{Range: o.Range, ETag: newResp.ETag},
+		cpkInfo:          o.CPKInfo,
+		cpkScope:         o.CPKScopeInfo,
+	}
+
+	return fullResp, exported.ConvertToDFSError(err)
+}
+
+// DownloadBuffer downloads an Azure file to a buffer in parallel.
+func (f *Client) DownloadBuffer(ctx context.Context, buffer []byte, o *DownloadBufferOptions) (int64, error) {
+	opts := o.format()
+	val, err := f.blobClient().DownloadBuffer(ctx, shared.NewBytesWriter(buffer), opts)
+	return val, exported.ConvertToDFSError(err)
+}
+
+// DownloadFile downloads an Azure file to a local file.
+// The local file will be truncated if its size doesn't match the size of the downloaded data.
+func (f *Client) DownloadFile(ctx context.Context, file *os.File, o *DownloadFileOptions) (int64, error) { + opts := o.format() + val, err := f.blobClient().DownloadFile(ctx, file, opts) + return val, exported.ConvertToDFSError(err) +} + +// TODO: add undelete diff --git a/sdk/storage/azdatalake/file/client_test.go b/sdk/storage/azdatalake/file/client_test.go index f9b2ccdacf71..09e50c55615c 100644 --- a/sdk/storage/azdatalake/file/client_test.go +++ b/sdk/storage/azdatalake/file/client_test.go @@ -7,16 +7,25 @@ package file_test import ( + "bytes" "context" + "crypto/md5" + "encoding/binary" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/datalakeerror" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/file" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/testcommon" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/sas" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "hash/crc64" + "io" + "math/rand" "net/http" + "os" "testing" "time" ) @@ -2302,3 +2311,786 @@ func (s *RecordedTestSuite) TestRenameFileIfETagMatchFalse() { _require.NotNil(err) testcommon.ValidateErrorCode(_require, err, datalakeerror.SourceConditionNotMet) } + +func (s *RecordedTestSuite) TestFileUploadDownloadStream() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 100 * 1024 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadStream(context.Background(), streaming.NopCloser(bytes.NewReader(content)), &file.UploadStreamOptions{ + Concurrency: 5, + ChunkSize: 4 * 1024 * 1024, + }) + _require.NoError(err) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) + +} + +func (s *RecordedTestSuite) TestFileUploadDownloadSmallStream() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = 
fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 10 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadStream(context.Background(), streaming.NopCloser(bytes.NewReader(content)), &file.UploadStreamOptions{ + Concurrency: 5, + ChunkSize: 2 * 1024, + }) + _require.NoError(err) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) +} + +func (s *RecordedTestSuite) TestFileUploadTinyStream() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 4 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadStream(context.Background(), streaming.NopCloser(bytes.NewReader(content)), &file.UploadStreamOptions{ + Concurrency: 5, + ChunkSize: 2 * 1024, + }) + _require.NoError(err) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) +} + +func (s *RecordedTestSuite) TestFileUploadFile() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 100 * 1024 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) 
+ _require.NotNil(resp) + + // create local file + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + err = os.WriteFile("testFile", content, 0644) + _require.NoError(err) + + defer func() { + err = os.Remove("testFile") + _require.NoError(err) + }() + + fh, err := os.Open("testFile") + _require.NoError(err) + + defer func(fh *os.File) { + err := fh.Close() + _require.NoError(err) + }(fh) + + hash := md5.New() + _, err = io.Copy(hash, fh) + _require.NoError(err) + contentMD5 := hash.Sum(nil) + + err = fClient.UploadFile(context.Background(), fh, &file.UploadFileOptions{ + Concurrency: 5, + ChunkSize: 4 * 1024 * 1024, + }) + _require.NoError(err) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) +} + +func (s *RecordedTestSuite) TestSmallFileUploadFile() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 10 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + // create local file + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + err = os.WriteFile("testFile", content, 0644) + _require.NoError(err) + + defer func() { + err = os.Remove("testFile") + _require.NoError(err) + }() + + fh, err := os.Open("testFile") + _require.NoError(err) + + defer func(fh *os.File) { + err := fh.Close() + _require.NoError(err) + }(fh) + + hash := md5.New() + _, err = io.Copy(hash, fh) + _require.NoError(err) + contentMD5 := hash.Sum(nil) + + err = fClient.UploadFile(context.Background(), fh, &file.UploadFileOptions{ + Concurrency: 5, + ChunkSize: 2 * 1024, + }) + _require.NoError(err) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) +} + +func (s *RecordedTestSuite) TestTinyFileUploadFile() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 10 + 
fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + // create local file + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + err = os.WriteFile("testFile", content, 0644) + _require.NoError(err) + + defer func() { + err = os.Remove("testFile") + _require.NoError(err) + }() + + fh, err := os.Open("testFile") + _require.NoError(err) + + defer func(fh *os.File) { + err := fh.Close() + _require.NoError(err) + }(fh) + + hash := md5.New() + _, err = io.Copy(hash, fh) + _require.NoError(err) + contentMD5 := hash.Sum(nil) + + err = fClient.UploadFile(context.Background(), fh, &file.UploadFileOptions{ + ChunkSize: 2, + }) + _require.NoError(err) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) +} + +func (s *RecordedTestSuite) TestFileUploadBuffer() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 100 * 1024 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadBuffer(context.Background(), content, &file.UploadBufferOptions{ + Concurrency: 5, + ChunkSize: 4 * 1024 * 1024, + }) + _require.NoError(err) + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) +} + +func (s *RecordedTestSuite) TestFileUploadSmallBuffer() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 10 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := 
testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadBuffer(context.Background(), content, &file.UploadBufferOptions{ + Concurrency: 5, + ChunkSize: 2 * 1024, + }) + _require.NoError(err) + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) +} + +func (s *RecordedTestSuite) TestFileAppendAndFlushData() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + srcFileName := "src" + testcommon.GenerateFileName(testName) + + srcFClient, err := testcommon.GetFileClient(filesystemName, srcFileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := srcFClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + contentSize := 1024 * 8 // 8KB + rsc, _ := testcommon.GenerateData(contentSize) + + _, err = srcFClient.AppendData(context.Background(), 0, rsc, nil) + _require.NoError(err) + + _, err = srcFClient.FlushData(context.Background(), int64(contentSize), nil) + _require.NoError(err) + + gResp2, err := srcFClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, int64(contentSize)) +} + +func (s *RecordedTestSuite) TestFileAppendAndFlushDataWithValidation() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + srcFileName := "src" + testcommon.GenerateFileName(testName) + + srcFClient, err := testcommon.GetFileClient(filesystemName, srcFileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := srcFClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + contentSize := 1024 * 8 // 8KB + content := make([]byte, contentSize) + body := bytes.NewReader(content) + rsc := streaming.NopCloser(body) + contentCRC64 := crc64.Checksum(content, shared.CRC64Table) + + opts := &file.AppendDataOptions{ + TransactionalValidation: file.TransferValidationTypeComputeCRC64(), + } + putResp, err := srcFClient.AppendData(context.Background(), 0, rsc, opts) + _require.Nil(err) + // _require.Equal(putResp.RawResponse.StatusCode, 201) + _require.NotNil(putResp.ContentCRC64) + 
_require.EqualValues(binary.LittleEndian.Uint64(putResp.ContentCRC64), contentCRC64) + + _, err = srcFClient.FlushData(context.Background(), int64(contentSize), nil) + _require.NoError(err) + + gResp2, err := srcFClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, int64(contentSize)) +} + +func (s *RecordedTestSuite) TestFileDownloadFile() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 100 * 1024 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadBuffer(context.Background(), content, &file.UploadBufferOptions{ + Concurrency: 5, + ChunkSize: 4 * 1024 * 1024, + }) + _require.NoError(err) + + destFileName := "BigFile-downloaded.bin" + destFile, err := os.Create(destFileName) + _require.NoError(err) + defer func(name string) { + err = os.Remove(name) + _require.NoError(err) + }(destFileName) + defer func(destFile *os.File) { + err = destFile.Close() + _require.NoError(err) + }(destFile) + + cnt, err := fClient.DownloadFile(context.Background(), destFile, &file.DownloadFileOptions{ + ChunkSize: 10 * 1024 * 1024, + Concurrency: 5, + }) + _require.NoError(err) + _require.Equal(cnt, fileSize) + + hash := md5.New() + _, err = io.Copy(hash, destFile) + _require.NoError(err) + downloadedContentMD5 := hash.Sum(nil) + + _require.EqualValues(downloadedContentMD5, contentMD5) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) +} + +func (s *RecordedTestSuite) TestFileUploadDownloadSmallFile() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 10 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + // create local file + _, content := testcommon.GenerateData(int(fileSize)) + srcFileName := "testFileUpload" + err = os.WriteFile(srcFileName, content, 0644) + _require.NoError(err) + defer func() { + err = os.Remove(srcFileName) + _require.NoError(err) + }() + fh, err := os.Open(srcFileName) + _require.NoError(err) + defer func(fh *os.File) { + err := fh.Close() + _require.NoError(err) + }(fh) + + srcHash := md5.New() + _, err = 
io.Copy(srcHash, fh) + _require.NoError(err) + contentMD5 := srcHash.Sum(nil) + + err = fClient.UploadFile(context.Background(), fh, &file.UploadFileOptions{ + Concurrency: 5, + ChunkSize: 2 * 1024, + }) + _require.NoError(err) + + destFileName := "SmallFile-downloaded.bin" + destFile, err := os.Create(destFileName) + _require.NoError(err) + defer func(name string) { + err = os.Remove(name) + _require.NoError(err) + }(destFileName) + defer func(destFile *os.File) { + err = destFile.Close() + _require.NoError(err) + }(destFile) + + cnt, err := fClient.DownloadFile(context.Background(), destFile, &file.DownloadFileOptions{ + ChunkSize: 2 * 1024, + Concurrency: 5, + }) + _require.NoError(err) + _require.Equal(cnt, fileSize) + + destHash := md5.New() + _, err = io.Copy(destHash, destFile) + _require.NoError(err) + downloadedContentMD5 := destHash.Sum(nil) + + _require.EqualValues(downloadedContentMD5, contentMD5) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) +} + +func (s *RecordedTestSuite) TestFileUploadDownloadWithProgress() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 10 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + _, content := testcommon.GenerateData(int(fileSize)) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + bytesUploaded := int64(0) + err = fClient.UploadBuffer(context.Background(), content, &file.UploadBufferOptions{ + Concurrency: 5, + ChunkSize: 2 * 1024, + Progress: func(bytesTransferred int64) { + _require.GreaterOrEqual(bytesTransferred, bytesUploaded) + bytesUploaded = bytesTransferred + }, + }) + _require.NoError(err) + _require.Equal(bytesUploaded, fileSize) + + destBuffer := make([]byte, fileSize) + bytesDownloaded := int64(0) + cnt, err := fClient.DownloadBuffer(context.Background(), destBuffer, &file.DownloadBufferOptions{ + ChunkSize: 2 * 1024, + Concurrency: 5, + Progress: func(bytesTransferred int64) { + _require.GreaterOrEqual(bytesTransferred, bytesDownloaded) + bytesDownloaded = bytesTransferred + }, + }) + _require.NoError(err) + _require.Equal(cnt, fileSize) + _require.Equal(bytesDownloaded, fileSize) + + downloadedMD5Value := md5.Sum(destBuffer) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) +} + +func (s *RecordedTestSuite) TestFileDownloadBuffer() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = 
fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 100 * 1024 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadBuffer(context.Background(), content, &file.UploadBufferOptions{ + Concurrency: 5, + ChunkSize: 4 * 1024 * 1024, + }) + _require.NoError(err) + + destBuffer := make([]byte, fileSize) + cnt, err := fClient.DownloadBuffer(context.Background(), destBuffer, &file.DownloadBufferOptions{ + ChunkSize: 10 * 1024 * 1024, + Concurrency: 5, + }) + _require.NoError(err) + _require.Equal(cnt, fileSize) + + downloadedMD5Value := md5.Sum(destBuffer) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) +} + +// TODO tests all uploads/downloads with other opts diff --git a/sdk/storage/azdatalake/file/constants.go b/sdk/storage/azdatalake/file/constants.go index 2345c88d547b..7dd13f5de226 100644 --- a/sdk/storage/azdatalake/file/constants.go +++ b/sdk/storage/azdatalake/file/constants.go @@ -7,6 +7,7 @@ package file import ( + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/path" ) @@ -36,3 +37,14 @@ const ( CopyStatusTypeAborted CopyStatusType = path.CopyStatusTypeAborted CopyStatusTypeFailed CopyStatusType = path.CopyStatusTypeFailed ) + +// TransferValidationType abstracts the various mechanisms used to verify a transfer. +type TransferValidationType = exported.TransferValidationType + +// TransferValidationTypeCRC64 is a TransferValidationType used to provide a precomputed crc64. +type TransferValidationTypeCRC64 = exported.TransferValidationTypeCRC64 + +// TransferValidationTypeComputeCRC64 is a TransferValidationType that indicates a CRC64 should be computed during transfer. +func TransferValidationTypeComputeCRC64() TransferValidationType { + return exported.TransferValidationTypeComputeCRC64() +} diff --git a/sdk/storage/azdatalake/file/mmf_unix.go b/sdk/storage/azdatalake/file/mmf_unix.go new file mode 100644 index 000000000000..4c8ed223dbae --- /dev/null +++ b/sdk/storage/azdatalake/file/mmf_unix.go @@ -0,0 +1,38 @@ +//go:build go1.18 && (linux || darwin || dragonfly || freebsd || openbsd || netbsd || solaris || aix) +// +build go1.18 +// +build linux darwin dragonfly freebsd openbsd netbsd solaris aix + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. 
+ +package file + +import ( + "fmt" + "os" + "syscall" +) + +// mmb is a memory mapped buffer +type mmb []byte + +// newMMB creates a new memory mapped buffer with the specified size +func newMMB(size int64) (mmb, error) { + prot, flags := syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_ANON|syscall.MAP_PRIVATE + addr, err := syscall.Mmap(-1, 0, int(size), prot, flags) + if err != nil { + return nil, os.NewSyscallError("Mmap", err) + } + return mmb(addr), nil +} + +// delete cleans up the memory mapped buffer +func (m *mmb) delete() { + err := syscall.Munmap(*m) + *m = nil + if err != nil { + // if we get here, there is likely memory corruption. + // please open an issue https://github.com/Azure/azure-sdk-for-go/issues + panic(fmt.Sprintf("Munmap error: %v", err)) + } +} diff --git a/sdk/storage/azdatalake/file/mmf_windows.go b/sdk/storage/azdatalake/file/mmf_windows.go new file mode 100644 index 000000000000..b59e6b415776 --- /dev/null +++ b/sdk/storage/azdatalake/file/mmf_windows.go @@ -0,0 +1,56 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package file + +import ( + "fmt" + "os" + "reflect" + "syscall" + "unsafe" +) + +// mmb is a memory mapped buffer +type mmb []byte + +// newMMB creates a new memory mapped buffer with the specified size +func newMMB(size int64) (mmb, error) { + const InvalidHandleValue = ^uintptr(0) // -1 + + prot, access := uint32(syscall.PAGE_READWRITE), uint32(syscall.FILE_MAP_WRITE) + hMMF, err := syscall.CreateFileMapping(syscall.Handle(InvalidHandleValue), nil, prot, uint32(size>>32), uint32(size&0xffffffff), nil) + if err != nil { + return nil, os.NewSyscallError("CreateFileMapping", err) + } + defer func() { + _ = syscall.CloseHandle(hMMF) + }() + + addr, err := syscall.MapViewOfFile(hMMF, access, 0, 0, uintptr(size)) + if err != nil { + return nil, os.NewSyscallError("MapViewOfFile", err) + } + + m := mmb{} + h := (*reflect.SliceHeader)(unsafe.Pointer(&m)) + h.Data = addr + h.Len = int(size) + h.Cap = h.Len + return m, nil +} + +// delete cleans up the memory mapped buffer +func (m *mmb) delete() { + addr := uintptr(unsafe.Pointer(&(([]byte)(*m)[0]))) + *m = mmb{} + err := syscall.UnmapViewOfFile(addr) + if err != nil { + // if we get here, there is likely memory corruption. + // please open an issue https://github.com/Azure/azure-sdk-for-go/issues + panic(fmt.Sprintf("UnmapViewOfFile error: %v", err)) + } +} diff --git a/sdk/storage/azdatalake/file/models.go b/sdk/storage/azdatalake/file/models.go index a4f8b994ff1d..1f628860bded 100644 --- a/sdk/storage/azdatalake/file/models.go +++ b/sdk/storage/azdatalake/file/models.go @@ -7,15 +7,33 @@ package file import ( + "errors" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/generated" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/path" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/shared" + "io" "net/http" "strconv" "time" ) +const ( + _1MiB = 1024 * 1024 + CountToEnd = 0 + + // MaxUpdateRangeBytes indicates the maximum number of bytes that can be updated in a call to Client.UploadRange. 
+	MaxUpdateRangeBytes = 4 * 1024 * 1024 // 4MiB
+
+	// MaxFileSize indicates the maximum size of the file allowed.
+	MaxFileSize = 4 * 1024 * 1024 * 1024 * 1024 // 4 TiB
+
+	// DefaultDownloadChunkSize is the default chunk size used for downloads.
+	DefaultDownloadChunkSize = int64(4 * 1024 * 1024) // 4MiB
+)
+
 // CreateOptions contains the optional parameters when calling the Create operation. dfs endpoint.
 type CreateOptions struct {
 	// AccessConditions contains parameters for accessing the file.
@@ -167,6 +185,401 @@ func (o *RemoveAccessControlOptions) format(ACL string) (*generated.PathClientSe
 	}, mode
 }
 
+type HTTPRange = exported.HTTPRange
+
+// uploadFromReaderOptions identifies options used by the UploadBuffer and UploadFile functions.
+type uploadFromReaderOptions struct {
+	// ChunkSize specifies the chunk size to use in bytes; the default (and maximum size) is MaxUpdateRangeBytes.
+	ChunkSize int64
+	// Progress is a function that is invoked periodically as bytes are sent to the FileClient.
+	// Note that the progress reporting is not always increasing; it can go down when retrying a request.
+	Progress func(bytesTransferred int64)
+	// Concurrency indicates the maximum number of chunks to upload in parallel (default is 5).
+	Concurrency uint16
+	// AccessConditions contains optional parameters to access leased entity.
+	AccessConditions *AccessConditions
+	// HTTPHeaders contains the optional path HTTP headers to set when the file is created.
+	HTTPHeaders *HTTPHeaders
+	// CPKInfo contains optional parameters to perform encryption using customer-provided key.
+	CPKInfo *CPKInfo
+}
+
+// UploadStreamOptions provides a set of configurations for the Client.UploadStream operation.
+type UploadStreamOptions struct {
+	// ChunkSize specifies the chunk size to use in bytes; the minimum (and default) is _1MiB.
+	ChunkSize int64
+	// Concurrency indicates the maximum number of chunks to upload in parallel (default is 1).
+	Concurrency uint16
+	// AccessConditions contains optional parameters to access leased entity.
+	AccessConditions *AccessConditions
+	// HTTPHeaders contains the optional path HTTP headers to set when the file is created.
+	HTTPHeaders *HTTPHeaders
+	// CPKInfo contains optional parameters to perform encryption using customer-provided key.
+	CPKInfo *CPKInfo
+}
+
+// UploadBufferOptions provides a set of configurations for the Client.UploadBuffer operation.
+type UploadBufferOptions = uploadFromReaderOptions
+
+// UploadFileOptions provides a set of configurations for the Client.UploadFile operation.
+type UploadFileOptions = uploadFromReaderOptions
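+
+// For illustration, a typical way to populate these options (the values are
+// arbitrary; f, ctx and data are assumed to exist):
+//
+//	opts := &UploadBufferOptions{
+//		ChunkSize:   4 * 1024 * 1024, // at most MaxUpdateRangeBytes
+//		Concurrency: 5,
+//		Progress:    func(bytesTransferred int64) { /* report progress */ },
+//	}
+//	err := f.UploadBuffer(ctx, data, opts)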
+
+// FlushDataOptions contains the optional parameters for the Client.FlushData method.
+type FlushDataOptions struct {
+	// AccessConditions contains optional parameters to access leased entity.
+	AccessConditions *AccessConditions
+	// CPKInfo contains optional parameters to perform encryption using customer-provided key.
+	CPKInfo *CPKInfo
+	// HTTPHeaders contains the optional path HTTP headers to set when the file is created.
+	HTTPHeaders *HTTPHeaders
+	// Close, if true, indicates this is the final flush and the file stream is closed (default false).
+	Close *bool
+	// RetainUncommittedData, if true, retains uncommitted data after the flush completes (default false).
+	RetainUncommittedData *bool
+}
+
+func (o *FlushDataOptions) format(offset int64) (*generated.PathClientFlushDataOptions, *generated.ModifiedAccessConditions, *generated.LeaseAccessConditions, *generated.PathHTTPHeaders, *generated.CPKInfo, error) {
+	defaultRetainUncommitted := false
+	defaultClose := false
+	contentLength := int64(0)
+
+	var httpHeaderOpts *generated.PathHTTPHeaders
+	var leaseAccessConditions *generated.LeaseAccessConditions
+	var modifiedAccessConditions *generated.ModifiedAccessConditions
+	var cpkInfoOpts *generated.CPKInfo
+	flushDataOpts := &generated.PathClientFlushDataOptions{ContentLength: &contentLength, Position: &offset}
+
+	if o == nil {
+		flushDataOpts.RetainUncommittedData = &defaultRetainUncommitted
+		flushDataOpts.Close = &defaultClose
+		return flushDataOpts, nil, nil, nil, nil, nil
+	}
+
+	if o.RetainUncommittedData == nil {
+		flushDataOpts.RetainUncommittedData = &defaultRetainUncommitted
+	} else {
+		flushDataOpts.RetainUncommittedData = o.RetainUncommittedData
+	}
+	if o.Close == nil {
+		flushDataOpts.Close = &defaultClose
+	} else {
+		flushDataOpts.Close = o.Close
+	}
+	leaseAccessConditions, modifiedAccessConditions = exported.FormatPathAccessConditions(o.AccessConditions)
+	if o.HTTPHeaders != nil {
+		// assign to the outer variable; declaring a new one with := here would
+		// shadow it and the headers would never be returned
+		httpHeaderOpts = &generated.PathHTTPHeaders{}
+		httpHeaderOpts.ContentMD5 = o.HTTPHeaders.ContentMD5
+		httpHeaderOpts.ContentType = o.HTTPHeaders.ContentType
+		httpHeaderOpts.CacheControl = o.HTTPHeaders.CacheControl
+		httpHeaderOpts.ContentDisposition = o.HTTPHeaders.ContentDisposition
+		httpHeaderOpts.ContentEncoding = o.HTTPHeaders.ContentEncoding
+	}
+	if o.CPKInfo != nil {
+		cpkInfoOpts = &generated.CPKInfo{}
+		cpkInfoOpts.EncryptionKey = o.CPKInfo.EncryptionKey
+		cpkInfoOpts.EncryptionKeySHA256 = o.CPKInfo.EncryptionKeySHA256
+		cpkInfoOpts.EncryptionAlgorithm = o.CPKInfo.EncryptionAlgorithm
+	}
+	return flushDataOpts, modifiedAccessConditions, leaseAccessConditions, httpHeaderOpts, cpkInfoOpts, nil
+}
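+
+// For example, a final flush that commits totalSize bytes and marks the file
+// stream as closed (f, ctx and totalSize are assumed to exist):
+//
+//	opts := &FlushDataOptions{Close: to.Ptr(true)}
+//	_, err := f.FlushData(ctx, totalSize, opts)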
+
+// AppendDataOptions contains the optional parameters for the Client.AppendData method.
+type AppendDataOptions struct {
+	// TransactionalValidation specifies the transfer validation type to use.
+	// The default is nil (no transfer validation).
+	TransactionalValidation TransferValidationType
+	// LeaseAccessConditions contains optional parameters to access leased entity.
+	LeaseAccessConditions *LeaseAccessConditions
+	// HTTPHeaders contains the optional path HTTP headers to set when the file is created.
+	HTTPHeaders *HTTPHeaders
+	// CPKInfo contains optional parameters to perform encryption using customer-provided key.
+	CPKInfo *CPKInfo
+}
+
+func (o *AppendDataOptions) format(offset int64, body io.ReadSeekCloser) (*generated.PathClientAppendDataOptions, *generated.LeaseAccessConditions, *generated.PathHTTPHeaders, *generated.CPKInfo, error) {
+	if offset < 0 || body == nil {
+		return nil, nil, nil, nil, errors.New("invalid argument: offset must be >= 0 and body must not be nil")
+	}
+
+	count, err := shared.ValidateSeekableStreamAt0AndGetCount(body)
+	if err != nil {
+		return nil, nil, nil, nil, err
+	}
+
+	if count == 0 {
+		return nil, nil, nil, nil, errors.New("invalid argument: body must contain readable data whose size is > 0")
+	}
+
+	appendDataOptions := &generated.PathClientAppendDataOptions{}
+	httpRange := exported.FormatHTTPRange(HTTPRange{
+		Offset: offset,
+		Count:  count,
+	})
+	if httpRange != nil {
+		appendDataOptions.Position = &offset
+		appendDataOptions.ContentLength = &count
+	}
+
+	var leaseAccessConditions *LeaseAccessConditions
+	var httpHeaderOpts *generated.PathHTTPHeaders
+	var cpkInfoOpts *generated.CPKInfo
+
+	if o != nil {
+		leaseAccessConditions = o.LeaseAccessConditions
+		if o.HTTPHeaders != nil {
+			// assign to the outer variable; := would shadow it and the headers would be dropped
+			httpHeaderOpts = &generated.PathHTTPHeaders{}
+			httpHeaderOpts.ContentMD5 = o.HTTPHeaders.ContentMD5
+			httpHeaderOpts.ContentType = o.HTTPHeaders.ContentType
+			httpHeaderOpts.CacheControl = o.HTTPHeaders.CacheControl
+			httpHeaderOpts.ContentDisposition = o.HTTPHeaders.ContentDisposition
+			httpHeaderOpts.ContentEncoding = o.HTTPHeaders.ContentEncoding
+		}
+		if o.CPKInfo != nil {
+			cpkInfoOpts = &generated.CPKInfo{}
+			cpkInfoOpts.EncryptionKey = o.CPKInfo.EncryptionKey
+			cpkInfoOpts.EncryptionKeySHA256 = o.CPKInfo.EncryptionKeySHA256
+			cpkInfoOpts.EncryptionAlgorithm = o.CPKInfo.EncryptionAlgorithm
+		}
+	}
+	if o != nil && o.TransactionalValidation != nil {
+		_, err = o.TransactionalValidation.Apply(body, appendDataOptions)
+		if err != nil {
+			return nil, nil, nil, nil, err
+		}
+	}
+
+	return appendDataOptions, leaseAccessConditions, httpHeaderOpts, cpkInfoOpts, nil
+}
+
+func (u *UploadStreamOptions) setDefaults() {
+	if u.Concurrency == 0 {
+		u.Concurrency = 1
+	}
+
+	if u.ChunkSize < _1MiB {
+		u.ChunkSize = _1MiB
+	}
+}
+
+func (u *uploadFromReaderOptions) getAppendDataOptions() *AppendDataOptions {
+	if u == nil {
+		return nil
+	}
+	leaseAccessConditions, _ := exported.FormatPathAccessConditions(u.AccessConditions)
+	return &AppendDataOptions{
+		LeaseAccessConditions: leaseAccessConditions,
+		HTTPHeaders:           u.HTTPHeaders,
+		CPKInfo:               u.CPKInfo,
+	}
+}
+
+func (u *uploadFromReaderOptions) getFlushDataOptions() *FlushDataOptions {
+	if u == nil {
+		return nil
+	}
+	return &FlushDataOptions{
+		AccessConditions: u.AccessConditions,
+		HTTPHeaders:      u.HTTPHeaders,
+		CPKInfo:          u.CPKInfo,
+	}
+}
+
+func (u *UploadStreamOptions) getAppendDataOptions() *AppendDataOptions {
+	if u == nil {
+		return nil
+	}
+	leaseAccessConditions, _ := exported.FormatPathAccessConditions(u.AccessConditions)
+	return &AppendDataOptions{
+		LeaseAccessConditions: leaseAccessConditions,
+		HTTPHeaders:           u.HTTPHeaders,
+		CPKInfo:               u.CPKInfo,
+	}
+}
+
+func (u *UploadStreamOptions) getFlushDataOptions() *FlushDataOptions {
+	if u == nil {
+		return nil
+	}
+	return &FlushDataOptions{
+		AccessConditions: u.AccessConditions,
+		HTTPHeaders:      u.HTTPHeaders,
+		CPKInfo:          u.CPKInfo,
+	}
+}
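+
+// UploadStream usage sketch (f and ctx are assumed to exist; src is any
+// io.Reader; the values are arbitrary):
+//
+//	err := f.UploadStream(ctx, src, &UploadStreamOptions{
+//		ChunkSize:   2 * _1MiB,
+//		Concurrency: 3,
+//	})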
+
+// DownloadStreamOptions contains the optional parameters for the Client.DownloadStream method.
+type DownloadStreamOptions struct {
+	// When set to true and specified together with the Range, the service returns the MD5 hash for the range, as long as the
+	// range is less than or equal to 4 MB in size.
+	RangeGetContentMD5 *bool
+
+	// Range specifies a range of bytes. The default value is all bytes.
+	Range *HTTPRange
+
+	// AccessConditions contains optional parameters to access leased entity.
+	AccessConditions *AccessConditions
+	// CPKInfo contains optional parameters to perform encryption using customer-provided key.
+	CPKInfo *CPKInfo
+	// CPKScopeInfo contains optional parameters to use a customer-provided encryption scope.
+	CPKScopeInfo *CPKScopeInfo
+}
+
+func (o *DownloadStreamOptions) format() *blob.DownloadStreamOptions {
+	if o == nil {
+		return nil
+	}
+
+	downloadStreamOptions := &blob.DownloadStreamOptions{}
+	if o.Range != nil {
+		downloadStreamOptions.Range = blob.HTTPRange{
+			Offset: o.Range.Offset,
+			Count:  o.Range.Count,
+		}
+	}
+	if o.CPKInfo != nil {
+		downloadStreamOptions.CPKInfo = &blob.CPKInfo{
+			EncryptionKey:       o.CPKInfo.EncryptionKey,
+			EncryptionKeySHA256: o.CPKInfo.EncryptionKeySHA256,
+			EncryptionAlgorithm: (*blob.EncryptionAlgorithmType)(o.CPKInfo.EncryptionAlgorithm),
+		}
+	}
+
+	downloadStreamOptions.RangeGetContentMD5 = o.RangeGetContentMD5
+	downloadStreamOptions.AccessConditions = exported.FormatBlobAccessConditions(o.AccessConditions)
+	downloadStreamOptions.CPKScopeInfo = (*blob.CPKScopeInfo)(o.CPKScopeInfo)
+	return downloadStreamOptions
+}
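+
+// Sketch: streaming the first KiB of a file (f and ctx are assumed to exist):
+//
+//	resp, err := f.DownloadStream(ctx, &DownloadStreamOptions{
+//		Range: &HTTPRange{Offset: 0, Count: 1024},
+//	})
+//	if err != nil {
+//		return err
+//	}
+//	defer resp.Body.Close()
+//	data, err := io.ReadAll(resp.Body)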
+
+// DownloadBufferOptions contains the optional parameters for the Client.DownloadBuffer method.
+type DownloadBufferOptions struct {
+	// Range specifies a range of bytes. The default value is all bytes.
+	Range *HTTPRange
+
+	// ChunkSize specifies the chunk size to use for each parallel download; the default size is DefaultDownloadChunkSize.
+	ChunkSize int64
+
+	// Progress is a function that is invoked periodically as bytes are received.
+	Progress func(bytesTransferred int64)
+
+	// AccessConditions indicates the access conditions used when making HTTP GET requests against the file.
+	AccessConditions *AccessConditions
+
+	// CPKInfo contains a group of parameters for client provided encryption key.
+	CPKInfo *CPKInfo
+
+	// CPKScopeInfo contains a group of parameters for client provided encryption scope.
+	CPKScopeInfo *CPKScopeInfo
+
+	// Concurrency indicates the maximum number of chunks to download in parallel (0=default).
+	Concurrency uint16
+
+	// RetryReaderOptionsPerChunk is used when downloading each chunk.
+	RetryReaderOptionsPerChunk *RetryReaderOptions
+}
+
+func (o *DownloadBufferOptions) format() *blob.DownloadBufferOptions {
+	if o == nil {
+		return nil
+	}
+
+	downloadBufferOptions := &blob.DownloadBufferOptions{}
+	if o.Range != nil {
+		downloadBufferOptions.Range = blob.HTTPRange{
+			Offset: o.Range.Offset,
+			Count:  o.Range.Count,
+		}
+	}
+	if o.CPKInfo != nil {
+		downloadBufferOptions.CPKInfo = &blob.CPKInfo{
+			EncryptionKey:       o.CPKInfo.EncryptionKey,
+			EncryptionKeySHA256: o.CPKInfo.EncryptionKeySHA256,
+			EncryptionAlgorithm: (*blob.EncryptionAlgorithmType)(o.CPKInfo.EncryptionAlgorithm),
+		}
+	}
+
+	downloadBufferOptions.AccessConditions = exported.FormatBlobAccessConditions(o.AccessConditions)
+	downloadBufferOptions.CPKScopeInfo = (*blob.CPKScopeInfo)(o.CPKScopeInfo)
+	downloadBufferOptions.BlockSize = o.ChunkSize
+	downloadBufferOptions.Progress = o.Progress
+	downloadBufferOptions.Concurrency = o.Concurrency
+	if o.RetryReaderOptionsPerChunk != nil {
+		// only wrap OnFailedRead when the caller provided one; invoking a nil
+		// function value would panic at download time
+		if o.RetryReaderOptionsPerChunk.OnFailedRead != nil {
+			newFunc := func(failureCount int32, lastError error, rnge blob.HTTPRange, willRetry bool) {
+				newRange := HTTPRange{
+					Offset: rnge.Offset,
+					Count:  rnge.Count,
+				}
+				o.RetryReaderOptionsPerChunk.OnFailedRead(failureCount, lastError, newRange, willRetry)
+			}
+			downloadBufferOptions.RetryReaderOptionsPerBlock.OnFailedRead = newFunc
+		}
+		downloadBufferOptions.RetryReaderOptionsPerBlock.EarlyCloseAsError = o.RetryReaderOptionsPerChunk.EarlyCloseAsError
+		downloadBufferOptions.RetryReaderOptionsPerBlock.MaxRetries = o.RetryReaderOptionsPerChunk.MaxRetries
+	}
+
+	return downloadBufferOptions
+}
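+
+// Sketch: downloading a whole file into memory (f, ctx and fileSize are
+// assumed to exist; the buffer must be large enough for the requested range):
+//
+//	buf := make([]byte, fileSize)
+//	n, err := f.DownloadBuffer(ctx, buf, &DownloadBufferOptions{
+//		ChunkSize:   DefaultDownloadChunkSize,
+//		Concurrency: 5,
+//	})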
+ RetryReaderOptionsPerChunk *RetryReaderOptions +} + +func (o *DownloadFileOptions) format() *blob.DownloadFileOptions { + if o == nil { + return nil + } + + downloadFileOptions := &blob.DownloadFileOptions{} + if o.Range != nil { + downloadFileOptions.Range = blob.HTTPRange{ + Offset: o.Range.Offset, + Count: o.Range.Count, + } + } + if o.CPKInfo != nil { + downloadFileOptions.CPKInfo = &blob.CPKInfo{ + EncryptionKey: o.CPKInfo.EncryptionKey, + EncryptionKeySHA256: o.CPKInfo.EncryptionKeySHA256, + EncryptionAlgorithm: (*blob.EncryptionAlgorithmType)(o.CPKInfo.EncryptionAlgorithm), + } + } + + downloadFileOptions.AccessConditions = exported.FormatBlobAccessConditions(o.AccessConditions) + downloadFileOptions.CPKScopeInfo = (*blob.CPKScopeInfo)(o.CPKScopeInfo) + downloadFileOptions.BlockSize = o.ChunkSize + downloadFileOptions.Progress = o.Progress + downloadFileOptions.Concurrency = o.Concurrency + if o.RetryReaderOptionsPerChunk != nil { + // wrap the callback only when one was provided; invoking a nil func would panic + if o.RetryReaderOptionsPerChunk.OnFailedRead != nil { + newFunc := func(failureCount int32, lastError error, rnge blob.HTTPRange, willRetry bool) { + newRange := HTTPRange{ + Offset: rnge.Offset, + Count: rnge.Count, + } + o.RetryReaderOptionsPerChunk.OnFailedRead(failureCount, lastError, newRange, willRetry) + } + downloadFileOptions.RetryReaderOptionsPerBlock.OnFailedRead = newFunc + } + downloadFileOptions.RetryReaderOptionsPerBlock.EarlyCloseAsError = o.RetryReaderOptionsPerChunk.EarlyCloseAsError + downloadFileOptions.RetryReaderOptionsPerBlock.MaxRetries = o.RetryReaderOptionsPerChunk.MaxRetries + } + + return downloadFileOptions +} + // CreationExpiryType defines values for Create() ExpiryType type CreationExpiryType interface { Format() (generated.ExpiryOptions, *string) diff --git a/sdk/storage/azdatalake/file/responses.go b/sdk/storage/azdatalake/file/responses.go index 3116edab7355..a518ee58f376 100644 --- a/sdk/storage/azdatalake/file/responses.go +++ b/sdk/storage/azdatalake/file/responses.go @@ -7,8 +7,14 @@ package file import ( + "context" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/generated" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/path" + "io" + "time" ) // SetExpiryResponse contains the response fields for the SetExpiry operation. @@ -26,12 +32,229 @@ type UpdateAccessControlResponse = generated.PathClientSetAccessControlRecursive // RemoveAccessControlResponse contains the response fields for the RemoveAccessControlRecursive operation. type RemoveAccessControlResponse = generated.PathClientSetAccessControlRecursiveResponse +// AppendDataResponse contains the response from method Client.AppendData. +type AppendDataResponse = generated.PathClientAppendDataResponse + +// FlushDataResponse contains the response from method Client.FlushData. +type FlushDataResponse = generated.PathClientFlushDataResponse + // RenameResponse contains the response fields for the Create operation. type RenameResponse struct { Response generated.PathClientCreateResponse NewFileClient *Client } +// DownloadStreamResponse contains the response from the DownloadStream method. +// To read from the stream, read from the Body field, or call the NewRetryReader method. +type DownloadStreamResponse struct { + DownloadResponse + client *Client + getInfo httpGetterInfo + cpkInfo *CPKInfo + cpkScope *CPKScopeInfo +} + +// NewRetryReader constructs a new RetryReader stream for reading data.
If a connection fails while +// reading, it will make additional requests to reestablish a connection and continue reading. +// Pass nil for options to accept the default options. +// Callers of this method should not access the DownloadStreamResponse.Body field. +func (r *DownloadStreamResponse) NewRetryReader(ctx context.Context, options *RetryReaderOptions) *RetryReader { + if options == nil { + options = &RetryReaderOptions{} + } + + return newRetryReader(ctx, r.Body, r.getInfo, func(ctx context.Context, getInfo httpGetterInfo) (io.ReadCloser, error) { + accessConditions := &AccessConditions{ + ModifiedAccessConditions: &ModifiedAccessConditions{IfMatch: getInfo.ETag}, + } + options := DownloadStreamOptions{ + Range: getInfo.Range, + AccessConditions: accessConditions, + CPKInfo: r.cpkInfo, + CPKScopeInfo: r.cpkScope, + } + resp, err := r.client.DownloadStream(ctx, &options) + if err != nil { + return nil, err + } + return resp.Body, err + }, *options) +} + +// DownloadResponse contains the response fields for the DownloadStream operation. +type DownloadResponse struct { + // AcceptRanges contains the information returned from the Accept-Ranges header response. + AcceptRanges *string + + // Body contains the streaming response. + Body io.ReadCloser + + // CacheControl contains the information returned from the Cache-Control header response. + CacheControl *string + + // ClientRequestID contains the information returned from the x-ms-client-request-id header response. + ClientRequestID *string + + // ContentCRC64 contains the information returned from the x-ms-content-crc64 header response. + ContentCRC64 []byte + + // ContentDisposition contains the information returned from the Content-Disposition header response. + ContentDisposition *string + + // ContentEncoding contains the information returned from the Content-Encoding header response. + ContentEncoding *string + + // ContentLanguage contains the information returned from the Content-Language header response. + ContentLanguage *string + + // ContentLength contains the information returned from the Content-Length header response. + ContentLength *int64 + + // ContentMD5 contains the information returned from the Content-MD5 header response. + ContentMD5 []byte + + // ContentRange contains the information returned from the Content-Range header response. + ContentRange *string + + // ContentType contains the information returned from the Content-Type header response. + ContentType *string + + // CopyCompletionTime contains the information returned from the x-ms-copy-completion-time header response. + CopyCompletionTime *time.Time + + // CopyID contains the information returned from the x-ms-copy-id header response. + CopyID *string + + // CopyProgress contains the information returned from the x-ms-copy-progress header response. + CopyProgress *string + + // CopySource contains the information returned from the x-ms-copy-source header response. + CopySource *string + + // CopyStatus contains the information returned from the x-ms-copy-status header response. + CopyStatus *CopyStatusType + + // CopyStatusDescription contains the information returned from the x-ms-copy-status-description header response. + CopyStatusDescription *string + + // Date contains the information returned from the Date header response. + Date *time.Time + + // ETag contains the information returned from the ETag header response. + ETag *azcore.ETag + + // EncryptionKeySHA256 contains the information returned from the x-ms-encryption-key-sha256 header response.
+ EncryptionKeySHA256 *string + + // EncryptionScope contains the information returned from the x-ms-encryption-scope header response. + EncryptionScope *string + + // ErrorCode contains the information returned from the x-ms-error-code header response. + ErrorCode *string + + // ImmutabilityPolicyExpiresOn contains the information returned from the x-ms-immutability-policy-until-date header response. + ImmutabilityPolicyExpiresOn *time.Time + + // ImmutabilityPolicyMode contains the information returned from the x-ms-immutability-policy-mode header response. + ImmutabilityPolicyMode *ImmutabilityPolicyMode + + // IsCurrentVersion contains the information returned from the x-ms-is-current-version header response. + IsCurrentVersion *bool + + // IsSealed contains the information returned from the x-ms-blob-sealed header response. + IsSealed *bool + + // IsServerEncrypted contains the information returned from the x-ms-server-encrypted header response. + IsServerEncrypted *bool + + // LastAccessed contains the information returned from the x-ms-last-access-time header response. + LastAccessed *time.Time + + // LastModified contains the information returned from the Last-Modified header response. + LastModified *time.Time + + // LeaseDuration contains the information returned from the x-ms-lease-duration header response. + LeaseDuration *azdatalake.DurationType + + // LeaseState contains the information returned from the x-ms-lease-state header response. + LeaseState *azdatalake.StateType + + // LeaseStatus contains the information returned from the x-ms-lease-status header response. + LeaseStatus *azdatalake.StatusType + + // LegalHold contains the information returned from the x-ms-legal-hold header response. + LegalHold *bool + + // Metadata contains the information returned from the x-ms-meta header response. + Metadata map[string]*string + + // ObjectReplicationPolicyID contains the information returned from the x-ms-or-policy-id header response. + ObjectReplicationPolicyID *string + + // ObjectReplicationRules contains the information returned from the x-ms-or header response. + ObjectReplicationRules map[string]*string + + // RequestID contains the information returned from the x-ms-request-id header response. + RequestID *string + + // TagCount contains the information returned from the x-ms-tag-count header response. + TagCount *int64 + + // Version contains the information returned from the x-ms-version header response. + Version *string + + // VersionID contains the information returned from the x-ms-version-id header response. 
+ VersionID *string +} + +// FormatDownloadStreamResponse formats a blob.DownloadStreamResponse into a datalake DownloadResponse. +func FormatDownloadStreamResponse(r *blob.DownloadStreamResponse) DownloadResponse { + newResp := DownloadResponse{} + if r != nil { + newResp.AcceptRanges = r.AcceptRanges + newResp.Body = r.Body + newResp.ContentCRC64 = r.ContentCRC64 + newResp.ContentRange = r.ContentRange + newResp.CacheControl = r.CacheControl + newResp.ErrorCode = r.ErrorCode + newResp.ClientRequestID = r.ClientRequestID + newResp.ContentDisposition = r.ContentDisposition + newResp.ContentEncoding = r.ContentEncoding + newResp.ContentLanguage = r.ContentLanguage + newResp.ContentLength = r.ContentLength + newResp.ContentMD5 = r.ContentMD5 + newResp.ContentType = r.ContentType + newResp.CopyCompletionTime = r.CopyCompletionTime + newResp.CopyID = r.CopyID + newResp.CopyProgress = r.CopyProgress + newResp.CopySource = r.CopySource + newResp.CopyStatus = r.CopyStatus + newResp.CopyStatusDescription = r.CopyStatusDescription + newResp.Date = r.Date + newResp.ETag = r.ETag + newResp.EncryptionKeySHA256 = r.EncryptionKeySHA256 + newResp.EncryptionScope = r.EncryptionScope + newResp.ImmutabilityPolicyExpiresOn = r.ImmutabilityPolicyExpiresOn + newResp.ImmutabilityPolicyMode = r.ImmutabilityPolicyMode + newResp.IsCurrentVersion = r.IsCurrentVersion + newResp.IsSealed = r.IsSealed + newResp.IsServerEncrypted = r.IsServerEncrypted + newResp.LastAccessed = r.LastAccessed + newResp.LastModified = r.LastModified + newResp.LeaseDuration = r.LeaseDuration + newResp.LeaseState = r.LeaseState + newResp.LeaseStatus = r.LeaseStatus + newResp.LegalHold = r.LegalHold + newResp.Metadata = r.Metadata + newResp.ObjectReplicationPolicyID = r.ObjectReplicationPolicyID + newResp.ObjectReplicationRules = r.DownloadResponse.ObjectReplicationRules + newResp.RequestID = r.RequestID + newResp.TagCount = r.TagCount + newResp.Version = r.Version + newResp.VersionID = r.VersionID + } + return newResp +} + // ========================================== path imports =========================================================== // SetAccessControlResponse contains the response fields for the SetAccessControl operation. diff --git a/sdk/storage/azdatalake/file/retry_reader.go b/sdk/storage/azdatalake/file/retry_reader.go new file mode 100644 index 000000000000..66e3f35edf0d --- /dev/null +++ b/sdk/storage/azdatalake/file/retry_reader.go @@ -0,0 +1,191 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package file + +import ( + "context" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "io" + "net" + "strings" + "sync" +) + +// httpGetter is a function type that refers to a method that performs an HTTP GET operation. +type httpGetter func(ctx context.Context, i httpGetterInfo) (io.ReadCloser, error) + +// httpGetterInfo is passed to an httpGetter function, carrying the parameters +// that should be used to make an HTTP GET request. +type httpGetterInfo struct { + Range *HTTPRange + + // ETag specifies the resource's etag that should be used when creating + // the HTTP GET request's If-Match header + ETag *azcore.ETag +} + +// RetryReaderOptions configures the retry reader's behavior. +// Zero-value fields will have their specified default values applied during use. +// This allows for modification of a subset of fields. +type RetryReaderOptions struct { + // MaxRetries specifies the maximum number of attempts a failed read will be retried + // before producing an error.
+ // The default value is three. + MaxRetries int32 + + // OnFailedRead, when non-nil, is called after any failure to read. Expected usage is diagnostic logging. + OnFailedRead func(failureCount int32, lastError error, rnge HTTPRange, willRetry bool) + + // EarlyCloseAsError can be set to true to prevent retries after "read on closed response body". By default, + // RetryReader has the following special behaviour: closing the response body before it is all read is treated as a + // retryable error. This is to allow callers to force a retry by closing the body from another goroutine (e.g. if the + // read is too slow, the caller may want to force a retry in the hope that the retry will be quicker). If + // EarlyCloseAsError is true, then RetryReader's special behaviour is suppressed, and "read on closed body" is instead + // treated as a fatal (non-retryable) error. + // Note that setting EarlyCloseAsError only guarantees that closing will produce a fatal error if the Close happens + // from the same "thread" (goroutine) as Read. Concurrent Close calls from other goroutines may instead produce network errors + // which will be retried. + // The default value is false. + EarlyCloseAsError bool + + doInjectError bool + doInjectErrorRound int32 + injectedError error +} + +// RetryReader attempts to read from a response; if a retryable network error is +// returned during reading, it retries, per the retry reader options, by executing +// a user-defined action to get a new response, and continues the overall reading process +// by reading from the new response. +// RetryReader implements the io.ReadCloser interface. +type RetryReader struct { + ctx context.Context + info httpGetterInfo + retryReaderOptions RetryReaderOptions + getter httpGetter + countWasBounded bool + + // we support Closing during Reads (from other goroutines), so we protect the shared state, which is response + responseMu *sync.Mutex + response io.ReadCloser +} + +// newRetryReader creates a retry reader. +func newRetryReader(ctx context.Context, initialResponse io.ReadCloser, info httpGetterInfo, getter httpGetter, o RetryReaderOptions) *RetryReader { + if o.MaxRetries < 1 { + o.MaxRetries = 3 + } + return &RetryReader{ + ctx: ctx, + getter: getter, + info: info, + countWasBounded: info.Range.Count != CountToEnd, + response: initialResponse, + responseMu: &sync.Mutex{}, + retryReaderOptions: o, + } +} + +// setResponse replaces the current response stream while holding the lock. +func (s *RetryReader) setResponse(r io.ReadCloser) { + s.responseMu.Lock() + defer s.responseMu.Unlock() + s.response = r +} + +// Read reads from the current response, fetching a new response and retrying whenever a retryable error occurs. +func (s *RetryReader) Read(p []byte) (n int, err error) { + for try := int32(0); ; try++ { + // fmt.Println(try) // Uncomment for debugging. + if s.countWasBounded && s.info.Range.Count == CountToEnd { + // User specified an original count and the remaining bytes are 0, return 0, EOF + return 0, io.EOF + } + + s.responseMu.Lock() + resp := s.response + s.responseMu.Unlock() + if resp == nil { // We don't have a response stream to read from, try to get one. + newResponse, err := s.getter(s.ctx, s.info) + if err != nil { + return 0, err + } + // Successful GET; this is the network stream we'll read from. + s.setResponse(newResponse) + resp = newResponse + } + n, err := resp.Read(p) // Read from the stream (this will return a non-nil err if the body is closed, from another goroutine, while the read is running) + + // Injection mechanism for testing.
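+ // doInjectError, doInjectErrorRound and injectedError are test-only fields: on the + // matching attempt, the real read result is replaced with the injected error (or a + // temporary DNS error) so retry paths can be exercised without a network.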
+ if s.retryReaderOptions.doInjectError && try == s.retryReaderOptions.doInjectErrorRound { + if s.retryReaderOptions.injectedError != nil { + err = s.retryReaderOptions.injectedError + } else { + err = &net.DNSError{IsTemporary: true} + } + } + + // We successfully read data or hit EOF. + if err == nil || err == io.EOF { + s.info.Range.Offset += int64(n) // Increment the start offset in case we need to make a new HTTP request in the future + if s.info.Range.Count != CountToEnd { + s.info.Range.Count -= int64(n) // Decrement the count in case we need to make a new HTTP request in the future + } + return n, err // Return the result to the caller + } + _ = s.Close() + + s.setResponse(nil) // Our stream is no longer good + + // Check the retry count and error code, and decide whether to retry. + retriesExhausted := try >= s.retryReaderOptions.MaxRetries + _, isNetError := err.(net.Error) + isUnexpectedEOF := err == io.ErrUnexpectedEOF + willRetry := (isNetError || isUnexpectedEOF || s.wasRetryableEarlyClose(err)) && !retriesExhausted + + // Notify, for logging purposes, of any failures + if s.retryReaderOptions.OnFailedRead != nil { + failureCount := try + 1 // because try is zero-based + s.retryReaderOptions.OnFailedRead(failureCount, err, *s.info.Range, willRetry) + } + + if willRetry { + continue + // Loop around and try to get and read from a new stream. + } + return n, err // Not retryable, or retries exhausted, so just return + } +} + +// By default, we allow an early Close, from another concurrent goroutine, to be used to force a retry. +// Is it safe to close early from another goroutine? Early close ultimately ends up calling +// net.Conn.Close, and that is documented as "Any blocked Read or Write operations will be unblocked and return errors", +// which is exactly the behaviour we want. +// NOTE: if the caller has forced an early Close from a separate goroutine (separate from the Read) +// then there are two different types of error that may happen - either the one we check for here, +// or a net.Error (due to closure of the connection). Which one happens depends on timing. We only need this routine +// to check for one, since the other is a net.Error, which our main Read retry loop is already handling.
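+// As a hedged sketch (the watchdog signal and reader variable are hypothetical), a caller +// could force a retry of a stalled read like so: +// +//	go func() { +//		<-stalled          // hypothetical "read is too slow" signal +//		_ = reader.Close() // surfaces to the Read loop as a retryable early close +//	}()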
+func (s *RetryReader) wasRetryableEarlyClose(err error) bool { + if s.retryReaderOptions.EarlyCloseAsError { + return false // user wants all early closes to be errors, and so not retryable + } + // unfortunately, http.errReadOnClosedResBody is private, so the best we can do here is to check for its text + return strings.HasSuffix(err.Error(), ReadOnClosedBodyMessage) +} + +// ReadOnClosedBodyMessage is the error text produced when reading from a response body that has already been closed. +const ReadOnClosedBodyMessage = "read on closed response body" + +// Close closes the current response body, if any. +func (s *RetryReader) Close() error { + s.responseMu.Lock() + defer s.responseMu.Unlock() + if s.response != nil { + return s.response.Close() + } + return nil +} diff --git a/sdk/storage/azdatalake/filesystem/client.go b/sdk/storage/azdatalake/filesystem/client.go index c0a8146de2ad..c0c7ee5b8086 100644 --- a/sdk/storage/azdatalake/filesystem/client.go +++ b/sdk/storage/azdatalake/filesystem/client.go @@ -162,8 +162,8 @@ func (fs *Client) containerClient() *container.Client { return containerClient } -func (f *Client) identityCredential() *azcore.TokenCredential { - return base.IdentityCredentialComposite((*base.CompositeClient[generated.FileSystemClient, generated.FileSystemClient, container.Client])(f)) +func (fs *Client) identityCredential() *azcore.TokenCredential { + return base.IdentityCredentialComposite((*base.CompositeClient[generated.FileSystemClient, generated.FileSystemClient, container.Client])(fs)) } func (fs *Client) sharedKey() *exported.SharedKeyCredential { diff --git a/sdk/storage/azdatalake/internal/exported/shared_key_credential.go b/sdk/storage/azdatalake/internal/exported/shared_key_credential.go index 63539ea0b10a..d54cdc3a0b76 100644 --- a/sdk/storage/azdatalake/internal/exported/shared_key_credential.go +++ b/sdk/storage/azdatalake/internal/exported/shared_key_credential.go @@ -182,7 +182,7 @@ func (c *SharedKeyCredential) buildCanonicalizedResource(u *url.URL) (string, er // Join the sorted key values separated by ',' // Then prepend "keyName:"; then add this string to the buffer - cr.WriteString("\n" + paramName + ":" + strings.Join(paramValues, ",")) + cr.WriteString("\n" + strings.ToLower(paramName) + ":" + strings.Join(paramValues, ",")) } } return cr.String(), nil diff --git a/sdk/storage/azdatalake/internal/exported/transfer_validation_option.go b/sdk/storage/azdatalake/internal/exported/transfer_validation_option.go new file mode 100644 index 000000000000..85430ebd1c7e --- /dev/null +++ b/sdk/storage/azdatalake/internal/exported/transfer_validation_option.go @@ -0,0 +1,56 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package exported + +import ( + "bytes" + "encoding/binary" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/generated" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/shared" + "hash/crc64" + "io" +) + +// TransferValidationType abstracts the various mechanisms used to verify a transfer. +type TransferValidationType interface { + Apply(io.ReadSeekCloser, generated.TransactionalContentSetter) (io.ReadSeekCloser, error) + notPubliclyImplementable() +} + +// TransferValidationTypeCRC64 is a TransferValidationType used to provide a precomputed CRC64.
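+// For example (an illustrative sketch; data is hypothetical and the checksum must use the service's table, here shared.CRC64Table): +// +//	sum := crc64.Checksum(data, shared.CRC64Table) +//	v := exported.TransferValidationTypeCRC64(sum)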
+type TransferValidationTypeCRC64 uint64 + +func (c TransferValidationTypeCRC64) Apply(rsc io.ReadSeekCloser, cfg generated.TransactionalContentSetter) (io.ReadSeekCloser, error) { + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, uint64(c)) + cfg.SetCRC64(buf) + return rsc, nil +} + +func (TransferValidationTypeCRC64) notPubliclyImplementable() {} + +// TransferValidationTypeComputeCRC64 is a TransferValidationType that indicates a CRC64 should be computed during transfer. +func TransferValidationTypeComputeCRC64() TransferValidationType { + return transferValidationTypeFn(func(rsc io.ReadSeekCloser, cfg generated.TransactionalContentSetter) (io.ReadSeekCloser, error) { + buf, err := io.ReadAll(rsc) + if err != nil { + return nil, err + } + + crc := crc64.Checksum(buf, shared.CRC64Table) + return TransferValidationTypeCRC64(crc).Apply(streaming.NopCloser(bytes.NewReader(buf)), cfg) + }) +} + +type transferValidationTypeFn func(io.ReadSeekCloser, generated.TransactionalContentSetter) (io.ReadSeekCloser, error) + +func (t transferValidationTypeFn) Apply(rsc io.ReadSeekCloser, cfg generated.TransactionalContentSetter) (io.ReadSeekCloser, error) { + return t(rsc, cfg) +} + +func (transferValidationTypeFn) notPubliclyImplementable() {} diff --git a/sdk/storage/azdatalake/internal/generated/models.go b/sdk/storage/azdatalake/internal/generated/models.go new file mode 100644 index 000000000000..b3f86d5973cb --- /dev/null +++ b/sdk/storage/azdatalake/internal/generated/models.go @@ -0,0 +1,15 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package generated + +type TransactionalContentSetter interface { + SetCRC64([]byte) +} + +func (a *PathClientAppendDataOptions) SetCRC64(v []byte) { + a.TransactionalContentCRC64 = v +} diff --git a/sdk/storage/azdatalake/internal/path/constants.go b/sdk/storage/azdatalake/internal/path/constants.go index 7dd11049e38e..ce070f694d23 100644 --- a/sdk/storage/azdatalake/internal/path/constants.go +++ b/sdk/storage/azdatalake/internal/path/constants.go @@ -6,14 +6,16 @@ package path -import "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" +import ( + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/generated" +) -// EncryptionAlgorithmType defines values for EncryptionAlgorithmType. 
-type EncryptionAlgorithmType = blob.EncryptionAlgorithmType +type EncryptionAlgorithmType = generated.EncryptionAlgorithmType const ( - EncryptionAlgorithmTypeNone EncryptionAlgorithmType = blob.EncryptionAlgorithmTypeNone - EncryptionAlgorithmTypeAES256 EncryptionAlgorithmType = blob.EncryptionAlgorithmTypeAES256 + EncryptionAlgorithmTypeNone EncryptionAlgorithmType = generated.EncryptionAlgorithmTypeNone + EncryptionAlgorithmTypeAES256 EncryptionAlgorithmType = generated.EncryptionAlgorithmTypeAES256 ) type ImmutabilityPolicyMode = blob.ImmutabilityPolicyMode diff --git a/sdk/storage/azdatalake/internal/path/models.go b/sdk/storage/azdatalake/internal/path/models.go index 893bc9d40d19..f476fcae518f 100644 --- a/sdk/storage/azdatalake/internal/path/models.go +++ b/sdk/storage/azdatalake/internal/path/models.go @@ -30,7 +30,7 @@ func FormatGetPropertiesOptions(o *GetPropertiesOptions) *blob.GetPropertiesOpti AccessConditions: accessConditions, CPKInfo: &blob.CPKInfo{ EncryptionKey: o.CPKInfo.EncryptionKey, - EncryptionAlgorithm: o.CPKInfo.EncryptionAlgorithm, + EncryptionAlgorithm: (*blob.EncryptionAlgorithmType)(o.CPKInfo.EncryptionAlgorithm), EncryptionKeySHA256: o.CPKInfo.EncryptionKeySHA256, }, } @@ -159,19 +159,18 @@ type HTTPHeaders struct { ContentType *string } -// -//func (o HTTPHeaders) formatBlobHTTPHeaders() blob.HTTPHeaders { -// -// opts := blob.HTTPHeaders{ -// BlobCacheControl: o.CacheControl, -// BlobContentDisposition: o.ContentDisposition, -// BlobContentEncoding: o.ContentEncoding, -// BlobContentLanguage: o.ContentLanguage, -// BlobContentMD5: o.ContentMD5, -// BlobContentType: o.ContentType, -// } -// return opts -//} +func FormatBlobHTTPHeaders(o *HTTPHeaders) *blob.HTTPHeaders { + + opts := &blob.HTTPHeaders{ + BlobCacheControl: o.CacheControl, + BlobContentDisposition: o.ContentDisposition, + BlobContentEncoding: o.ContentEncoding, + BlobContentLanguage: o.ContentLanguage, + BlobContentMD5: o.ContentMD5, + BlobContentType: o.ContentType, + } + return opts +} func FormatPathHTTPHeaders(o *HTTPHeaders) *generated.PathHTTPHeaders { // TODO: will be used for file related ops, like append @@ -209,7 +208,7 @@ func FormatSetMetadataOptions(o *SetMetadataOptions) (*blob.SetMetadataOptions, if o.CPKInfo != nil { opts.CPKInfo = &blob.CPKInfo{ EncryptionKey: o.CPKInfo.EncryptionKey, - EncryptionAlgorithm: o.CPKInfo.EncryptionAlgorithm, + EncryptionAlgorithm: (*blob.EncryptionAlgorithmType)(o.CPKInfo.EncryptionAlgorithm), EncryptionKeySHA256: o.CPKInfo.EncryptionKeySHA256, } } diff --git a/sdk/storage/azdatalake/internal/shared/batch_transfer.go b/sdk/storage/azdatalake/internal/shared/batch_transfer.go new file mode 100644 index 000000000000..ec5541bfbb13 --- /dev/null +++ b/sdk/storage/azdatalake/internal/shared/batch_transfer.go @@ -0,0 +1,77 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package shared + +import ( + "context" + "errors" +) + +// BatchTransferOptions identifies options used by doBatchTransfer. +type BatchTransferOptions struct { + TransferSize int64 + ChunkSize int64 + Concurrency uint16 + Operation func(ctx context.Context, offset int64, chunkSize int64) error + OperationName string +} + +// DoBatchTransfer helps to execute operations in a batch manner. 
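+// A hedged usage sketch (downloadChunk is a hypothetical per-chunk operation): +// +//	err := shared.DoBatchTransfer(ctx, &shared.BatchTransferOptions{ +//		TransferSize:  totalSize, +//		ChunkSize:     4 * 1024 * 1024, +//		Concurrency:   5, +//		OperationName: "download", +//		Operation: func(ctx context.Context, offset, chunkSize int64) error { +//			return downloadChunk(ctx, offset, chunkSize) +//		}, +//	})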
+// It can be used by callers to customize batched work (for scenarios that the SDK does not provide). +func DoBatchTransfer(ctx context.Context, o *BatchTransferOptions) error { + if o.ChunkSize == 0 { + return errors.New("ChunkSize cannot be 0") + } + + if o.Concurrency == 0 { + o.Concurrency = 5 // default concurrency + } + + // Prepare and do parallel operations. + numChunks := uint16(((o.TransferSize - 1) / o.ChunkSize) + 1) + operationChannel := make(chan func() error, o.Concurrency) // Feeds the 'concurrency' worker goroutines + operationResponseChannel := make(chan error, numChunks) // Holds each response + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Create the goroutines that process each operation (in parallel). + for g := uint16(0); g < o.Concurrency; g++ { + go func() { + for f := range operationChannel { + err := f() + operationResponseChannel <- err + } + }() + } + + // Add each chunk's operation to the channel. + for chunkNum := uint16(0); chunkNum < numChunks; chunkNum++ { + curChunkSize := o.ChunkSize + + if chunkNum == numChunks-1 { // Last chunk + curChunkSize = o.TransferSize - (int64(chunkNum) * o.ChunkSize) // Remove size of all transferred chunks from total + } + offset := int64(chunkNum) * o.ChunkSize + operationChannel <- func() error { + return o.Operation(ctx, offset, curChunkSize) + } + } + close(operationChannel) + + // Wait for the operations to complete. + var firstErr error + for chunkNum := uint16(0); chunkNum < numChunks; chunkNum++ { + responseError := <-operationResponseChannel + // record the first error (the original error which should cause the other chunks to fail with canceled context) + if responseError != nil && firstErr == nil { + cancel() // As soon as any operation fails, cancel all remaining operation calls + firstErr = responseError + } + } + return firstErr +} diff --git a/sdk/storage/azdatalake/internal/testcommon/common.go b/sdk/storage/azdatalake/internal/testcommon/common.go index 1314309c5ac2..36af75cc92a7 100644 --- a/sdk/storage/azdatalake/internal/testcommon/common.go +++ b/sdk/storage/azdatalake/internal/testcommon/common.go @@ -1,11 +1,14 @@ package testcommon import ( + "bytes" "errors" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/datalakeerror" "github.com/stretchr/testify/require" + "io" "os" "strings" "testing" @@ -80,3 +83,20 @@ func ValidateErrorCode(_require *require.Assertions, err error, code datalakeerr func GetRelativeTimeFromAnchor(anchorTime *time.Time, amount time.Duration) time.Time { return anchorTime.Add(amount * time.Second) } + +const random64BString string = "2SDgZj6RkKYzJpu04sweQek4uWHO8ndPnYlZ0tnFS61hjnFZ5IkvIGGY44eKABov" + +func GenerateData(sizeInBytes int) (io.ReadSeekCloser, []byte) { + data := make([]byte, sizeInBytes) + _len := len(random64BString) + if sizeInBytes > _len { + count := sizeInBytes / _len + if sizeInBytes%_len != 0 { + count = count + 1 + } + copy(data[:], strings.Repeat(random64BString, count)) + } else { + copy(data[:], random64BString) + } + return streaming.NopCloser(bytes.NewReader(data)), data +} diff --git a/sdk/storage/azdatalake/sas/service.go b/sdk/storage/azdatalake/sas/service.go index 86a292028276..92ccaa8101a3 100644 --- a/sdk/storage/azdatalake/sas/service.go +++ b/sdk/storage/azdatalake/sas/service.go @@ -55,7 +55,7 
@@ func getDirectoryDepth(path string) string { // SignWithSharedKey uses an account's SharedKeyCredential to sign this signature values to produce the proper SAS query parameters. func (v DatalakeSignatureValues) SignWithSharedKey(sharedKeyCredential *SharedKeyCredential) (QueryParameters, error) { - if v.ExpiryTime.IsZero() || v.Permissions == "" { + if v.Identifier == "" && v.ExpiryTime.IsZero() || v.Permissions == "" { return QueryParameters{}, errors.New("service SAS is missing at least one of these: ExpiryTime or Permissions") } @@ -118,7 +118,6 @@ func (v DatalakeSignatureValues) SignWithSharedKey(sharedKeyCredential *SharedKe // Container/Blob-specific SAS parameters resource: resource, - identifier: v.Identifier, cacheControl: v.CacheControl, contentDisposition: v.ContentDisposition, contentEncoding: v.ContentEncoding, @@ -129,7 +128,8 @@ func (v DatalakeSignatureValues) SignWithSharedKey(sharedKeyCredential *SharedKe unauthorizedObjectID: v.UnauthorizedObjectID, correlationID: v.CorrelationID, // Calculated SAS signature - signature: signature, + signature: signature, + identifier: signedIdentifier, } return p, nil