[storage][azblob]DownloadFile: Download large files serially (#21259)
* Move Buffer manager to shared

* Serialize downloading to a file

* Move DownloadFile to new func

* Add special handling for small files

* Fix build

* Fix build

* Lint error

* Fix tab spaces

* Fix lint again :(

* Address comments

* Update comment

* Doc comment for default concurrency

* Fix formatting

* Fix formatting

* Fix file read in performUploadAndDownloadFileTest() method
Seek to zero in performUploadAndDownloadFileTest before reading.

* Fix testcase TestBasicDoBatchTransfer

* Fix testcase
nakulkar-msft authored Aug 17, 2023
1 parent 9d6efae commit 1dc804a
Showing 10 changed files with 288 additions and 100 deletions.
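In short, DownloadFile now fetches blocks concurrently but writes them to the destination file in order, instead of writing at arbitrary offsets through an io.WriterAt. A minimal sketch of how a caller might exercise the new path — the account, container, blob, and local file names are placeholders, and the import paths assume the canonical github.com module path:

package main

import (
	"context"
	"log"
	"os"

	"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
	"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
)

func main() {
	// Placeholder blob URL for illustration.
	const blobURL = "https://myaccount.blob.core.windows.net/mycontainer/large.bin"

	cred, err := azidentity.NewDefaultAzureCredential(nil)
	if err != nil {
		log.Fatal(err)
	}

	client, err := blob.NewClient(blobURL, cred, nil)
	if err != nil {
		log.Fatal(err)
	}

	f, err := os.Create("large.bin")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// With this commit, blocks are downloaded concurrently but written
	// to the file serially, in block order.
	_, err = client.DownloadFile(context.TODO(), f, &blob.DownloadFileOptions{
		BlockSize:   8 * 1024 * 1024, // 8 MiB per block
		Concurrency: 5,               // up to 5 blocks in flight
	})
	if err != nil {
		log.Fatal(err)
	}
}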
175 changes: 169 additions & 6 deletions sdk/storage/azblob/blob/client.go
@@ -8,16 +8,16 @@ package blob

import (
"context"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
"io"
"os"
"sync"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/bloberror"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/base"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/exported"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/generated"
@@ -324,8 +324,8 @@ func (b *Client) GetSASURL(permissions sas.BlobPermissions, expiry time.Time, o

// Concurrent Download Functions -----------------------------------------------------------------------------------------

// download downloads an Azure blob to a WriterAt in parallel.
func (b *Client) download(ctx context.Context, writer io.WriterAt, o downloadOptions) (int64, error) {
// downloadBuffer downloads an Azure blob to a WriterAt in parallel.
func (b *Client) downloadBuffer(ctx context.Context, writer io.WriterAt, o downloadOptions) (int64, error) {
if o.BlockSize == 0 {
o.BlockSize = DefaultDownloadBlockSize
}
@@ -353,6 +353,7 @@ func (b *Client) download(ctx context.Context, writer io.WriterAt, o downloadOpt
OperationName: "downloadBlobToWriterAt",
TransferSize: count,
ChunkSize: o.BlockSize,
NumChunks: uint16(((count - 1) / o.BlockSize) + 1),
Concurrency: o.Concurrency,
Operation: func(ctx context.Context, chunkStart int64, count int64) error {
downloadBlobOptions := o.getDownloadBlobOptions(HTTPRange{
@@ -391,6 +392,168 @@ func (b *Client) download(ctx context.Context, writer io.WriterAt, o downloadOpt
return count, nil
}

// downloadFile downloads an Azure blob to a Writer. The blocks are downloaded in parallel,
// but written to the file serially.
func (b *Client) downloadFile(ctx context.Context, writer io.Writer, o downloadOptions) (int64, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
if o.BlockSize == 0 {
o.BlockSize = DefaultDownloadBlockSize
}

if o.Concurrency == 0 {
o.Concurrency = DefaultConcurrency
}

count := o.Range.Count
if count == CountToEnd { //Calculate size if not specified
gr, err := b.GetProperties(ctx, o.getBlobPropertiesOptions())
if err != nil {
return 0, err
}
count = *gr.ContentLength - o.Range.Offset
}

if count <= 0 {
// The file is empty; there is nothing to download.
return 0, nil
}

progress := int64(0)
progressLock := &sync.Mutex{}

// helper routine to get body
getBodyForRange := func(ctx context.Context, chunkStart, size int64) (io.ReadCloser, error) {
downloadBlobOptions := o.getDownloadBlobOptions(HTTPRange{
Offset: chunkStart + o.Range.Offset,
Count: size,
}, nil)
dr, err := b.DownloadStream(ctx, downloadBlobOptions)
if err != nil {
return nil, err
}

var body io.ReadCloser = dr.NewRetryReader(ctx, &o.RetryReaderOptionsPerBlock)
if o.Progress != nil {
rangeProgress := int64(0)
body = streaming.NewResponseProgress(
body,
func(bytesTransferred int64) {
diff := bytesTransferred - rangeProgress
rangeProgress = bytesTransferred
progressLock.Lock()
progress += diff
o.Progress(progress)
progressLock.Unlock()
})
}

return body, nil
}

// If the file fits in a single buffer, download it here.
if count <= o.BlockSize {
body, err := getBodyForRange(ctx, int64(0), count)
if err != nil {
return 0, err
}
defer body.Close()

return io.Copy(writer, body)
}

buffers := shared.NewMMBPool(int(o.Concurrency), o.BlockSize)
defer buffers.Free()
aquireBuffer := func() ([]byte, error) {
select {
case b := <-buffers.Acquire():
// got a buffer
return b, nil
default:
// no buffer available; allocate a new buffer if possible
if _, err := buffers.Grow(); err != nil {
return nil, err
}

// either grab the newly allocated buffer or wait for one to become available
return <-buffers.Acquire(), nil
}
}

numChunks := uint16((count-1)/o.BlockSize) + 1
blocks := make([]chan []byte, numChunks)
for b := range blocks {
blocks[b] = make(chan []byte)
}

/*
 * We have created as many channels as the number of chunks we have.
 * Each downloaded block is sent to the channel matching its sequence
 * number, i.e. the 0th block is sent to the 0th channel, the 1st block
 * to the 1st channel, and so on. The blocks are then read and written
 * to the file serially by the goroutine below. Note that the blocks are
 * still downloaded from the network in parallel; they are only serialized
 * when written to the file here.
 */
writerError := make(chan error)
go func(ch chan error) {
for _, block := range blocks {
select {
case <-ctx.Done():
return
case block := <-block:
_, err := writer.Write(block)
buffers.Release(block)
if err != nil {
ch <- err
return
}
}
}
ch <- nil
}(writerError)

// Prepare and do parallel download.
err := shared.DoBatchTransfer(ctx, &shared.BatchTransferOptions{
OperationName: "downloadBlobToWriterAt",
TransferSize: count,
ChunkSize: o.BlockSize,
NumChunks: numChunks,
Concurrency: o.Concurrency,
Operation: func(ctx context.Context, chunkStart int64, count int64) error {
buff, err := aquireBuffer()
if err != nil {
return err
}

body, err := getBodyForRange(ctx, chunkStart, count)
if err != nil {
buffers.Release(buff)
return err
}

_, err = io.ReadFull(body, buff[:count])
body.Close()
if err != nil {
return err
}

blockIndex := (chunkStart / o.BlockSize)
blocks[blockIndex] <- buff
return nil
},
})

if err != nil {
return 0, err
}
// error from writer thread.
if err = <-writerError; err != nil {
return 0, err
}
return count, nil
}

// DownloadStream reads a range of bytes from a blob. The response also includes the blob's properties and metadata.
// For more information, see https://docs.microsoft.com/rest/api/storageservices/get-blob.
func (b *Client) DownloadStream(ctx context.Context, o *DownloadStreamOptions) (DownloadStreamResponse, error) {
@@ -419,7 +582,7 @@ func (b *Client) DownloadBuffer(ctx context.Context, buffer []byte, o *DownloadB
if o == nil {
o = &DownloadBufferOptions{}
}
return b.download(ctx, shared.NewBytesWriter(buffer), (downloadOptions)(*o))
return b.downloadBuffer(ctx, shared.NewBytesWriter(buffer), (downloadOptions)(*o))
}

// DownloadFile downloads an Azure blob to a local file.
@@ -458,7 +621,7 @@ func (b *Client) DownloadFile(ctx context.Context, file *os.File, o *DownloadFil
}

if size > 0 {
return b.download(ctx, file, *do)
return b.downloadFile(ctx, file, *do)
} else { // if the blob's size is 0, there is no need in downloading it
return 0, nil
}
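The comment block inside downloadFile above describes the ordering scheme: one channel per chunk, concurrent producers, and a single consumer that drains the channels in index order. A stripped-down, self-contained illustration of that pattern (chunk count and contents are invented for the example):

package main

import (
	"fmt"
	"sync"
)

func main() {
	const numChunks = 4

	// One channel per chunk: workers send each finished chunk to the
	// channel matching its index, and a single writer drains them in order.
	chunks := make([]chan []byte, numChunks)
	for i := range chunks {
		chunks[i] = make(chan []byte)
	}

	// Simulated concurrent "downloads" that may complete out of order.
	var wg sync.WaitGroup
	for i := 0; i < numChunks; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			chunks[i] <- []byte(fmt.Sprintf("chunk-%d ", i))
		}(i)
	}

	// Serial consumer: reads chunk 0, then chunk 1, and so on, so the
	// output order matches the chunk order regardless of completion order.
	done := make(chan struct{})
	go func() {
		defer close(done)
		for _, ch := range chunks {
			fmt.Print(string(<-ch))
		}
		fmt.Println()
	}()

	wg.Wait()
	<-done
}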
4 changes: 4 additions & 0 deletions sdk/storage/azblob/blob/constants.go
@@ -9,6 +9,7 @@ package blob
import (
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/exported"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/generated"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/shared"
)

const (
@@ -18,6 +19,9 @@ const (

// DefaultDownloadBlockSize is default block size
DefaultDownloadBlockSize = int64(4 * 1024 * 1024) // 4MB

// DefaultConcurrency is the default number of blocks downloaded or uploaded in parallel
DefaultConcurrency = shared.DefaultConcurrency
)

// BlobType defines values for BlobType
68 changes: 2 additions & 66 deletions sdk/storage/azblob/blockblob/chunkwriting.go
@@ -18,6 +18,7 @@ import (

"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
"github.com/Azure/azure-sdk-for-go/sdk/internal/uuid"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/shared"
)

// blockWriter provides methods to upload blocks that represent a file to a server and commit them.
@@ -28,27 +29,8 @@ type blockWriter interface {
CommitBlockList(context.Context, []string, *CommitBlockListOptions) (CommitBlockListResponse, error)
}

// bufferManager provides an abstraction for the management of buffers.
// this is mostly for testing purposes, but does allow for different implementations without changing the algorithm.
type bufferManager[T ~[]byte] interface {
// Acquire returns the channel that contains the pool of buffers.
Acquire() <-chan T

// Release releases the buffer back to the pool for reuse/cleanup.
Release(T)

// Grow grows the number of buffers, up to the predefined max.
// It returns the total number of buffers or an error.
// No error is returned if the number of buffers has reached max.
// This is called only from the reading goroutine.
Grow() (int, error)

// Free cleans up all buffers.
Free()
}

// copyFromReader copies a source io.Reader to blob storage using concurrent uploads.
func copyFromReader[T ~[]byte](ctx context.Context, src io.Reader, dst blockWriter, options UploadStreamOptions, getBufferManager func(maxBuffers int, bufferSize int64) bufferManager[T]) (CommitBlockListResponse, error) {
func copyFromReader[T ~[]byte](ctx context.Context, src io.Reader, dst blockWriter, options UploadStreamOptions, getBufferManager func(maxBuffers int, bufferSize int64) shared.BufferManager[T]) (CommitBlockListResponse, error) {
options.setDefaults()

wg := sync.WaitGroup{} // Used to know when all outgoing blocks have finished processing
@@ -265,49 +247,3 @@ func (ubi uuidBlockID) WithBlockNumber(blockNumber uint32) uuidBlockID {
func (ubi uuidBlockID) ToBase64() string {
return blockID(ubi).ToBase64()
}

// mmbPool implements the bufferManager interface.
// it uses anonymous memory mapped files for buffers.
// don't use this type directly, use newMMBPool() instead.
type mmbPool struct {
buffers chan mmb
count int
max int
size int64
}

func newMMBPool(maxBuffers int, bufferSize int64) bufferManager[mmb] {
return &mmbPool{
buffers: make(chan mmb, maxBuffers),
max: maxBuffers,
size: bufferSize,
}
}

func (pool *mmbPool) Acquire() <-chan mmb {
return pool.buffers
}

func (pool *mmbPool) Grow() (int, error) {
if pool.count < pool.max {
buffer, err := newMMB(pool.size)
if err != nil {
return 0, err
}
pool.buffers <- buffer
pool.count++
}
return pool.count, nil
}

func (pool *mmbPool) Release(buffer mmb) {
pool.buffers <- buffer
}

func (pool *mmbPool) Free() {
for i := 0; i < pool.count; i++ {
buffer := <-pool.buffers
buffer.delete()
}
pool.count = 0
}
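The mmbPool removed above (relocated to the internal shared package as NewMMBPool) implements the Acquire/Grow/Release/Free contract that both copyFromReader and the new downloadFile rely on. A toy, heap-backed sketch of that same contract — not the SDK's mmap-backed implementation — showing the acquire-or-grow usage pattern:

package main

import "fmt"

// slicePool is a toy, heap-backed stand-in for the pool moved to the shared
// package; it follows the Acquire/Grow/Release/Free contract shown above.
type slicePool struct {
	buffers chan []byte
	count   int
	max     int
	size    int64
}

func newSlicePool(maxBuffers int, bufferSize int64) *slicePool {
	return &slicePool{buffers: make(chan []byte, maxBuffers), max: maxBuffers, size: bufferSize}
}

func (p *slicePool) Acquire() <-chan []byte { return p.buffers }

// Grow allocates another buffer until the cap is reached; it reports the total.
func (p *slicePool) Grow() (int, error) {
	if p.count < p.max {
		p.buffers <- make([]byte, p.size)
		p.count++
	}
	return p.count, nil
}

func (p *slicePool) Release(b []byte) { p.buffers <- b }

func (p *slicePool) Free() {
	for i := 0; i < p.count; i++ {
		<-p.buffers // drop the buffer; the GC reclaims it
	}
	p.count = 0
}

func main() {
	pool := newSlicePool(2, 8)
	defer pool.Free()

	// Acquire-or-grow: the same select/default pattern downloadFile and
	// copyFromReader use to avoid blocking before the pool is at capacity.
	var buf []byte
	select {
	case buf = <-pool.Acquire():
	default:
		if _, err := pool.Grow(); err != nil {
			panic(err)
		}
		buf = <-pool.Acquire()
	}

	copy(buf, "payload!")
	fmt.Println(string(buf))
	pool.Release(buf)
}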
15 changes: 8 additions & 7 deletions sdk/storage/azblob/blockblob/chunkwriting_test.go
@@ -16,6 +16,7 @@ import (
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/internal/shared"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -115,19 +116,19 @@ func calcMD5(data []byte) string {

// used to track proper acquisition and closing of buffers
type bufMgrTracker struct {
inner bufferManager[mmb]
inner shared.BufferManager[shared.Mmb]

Count int // total count of allocated buffers
Freed bool // buffers were freed
}

func newBufMgrTracker(maxBuffers int, bufferSize int64) *bufMgrTracker {
return &bufMgrTracker{
inner: newMMBPool(maxBuffers, bufferSize),
inner: shared.NewMMBPool(maxBuffers, bufferSize),
}
}

func (pool *bufMgrTracker) Acquire() <-chan mmb {
func (pool *bufMgrTracker) Acquire() <-chan shared.Mmb {
return pool.inner.Acquire()
}

Expand All @@ -140,7 +141,7 @@ func (pool *bufMgrTracker) Grow() (int, error) {
return n, nil
}

func (pool *bufMgrTracker) Release(buffer mmb) {
func (pool *bufMgrTracker) Release(buffer shared.Mmb) {
pool.inner.Release(buffer)
}

@@ -161,7 +162,7 @@ func TestSlowDestCopyFrom(t *testing.T) {

errs := make(chan error, 1)
go func() {
_, err := copyFromReader(context.Background(), bytes.NewReader(bigSrc), fakeBB, UploadStreamOptions{}, func(maxBuffers int, bufferSize int64) bufferManager[mmb] {
_, err := copyFromReader(context.Background(), bytes.NewReader(bigSrc), fakeBB, UploadStreamOptions{}, func(maxBuffers int, bufferSize int64) shared.BufferManager[shared.Mmb] {
tracker = newBufMgrTracker(maxBuffers, bufferSize)
return tracker
})
@@ -270,7 +271,7 @@ func TestCopyFromReader(t *testing.T) {

var tracker *bufMgrTracker

_, err := copyFromReader(test.ctx, bytes.NewReader(from), fakeBB, test.o, func(maxBuffers int, bufferSize int64) bufferManager[mmb] {
_, err := copyFromReader(test.ctx, bytes.NewReader(from), fakeBB, test.o, func(maxBuffers int, bufferSize int64) shared.BufferManager[shared.Mmb] {
tracker = newBufMgrTracker(maxBuffers, bufferSize)
return tracker
})
@@ -322,7 +323,7 @@ func TestCopyFromReaderReadError(t *testing.T) {
reader: bytes.NewReader(make([]byte, 5*_1MiB)),
failOn: 2,
}
_, err := copyFromReader(context.Background(), &rf, fakeBB, UploadStreamOptions{}, func(maxBuffers int, bufferSize int64) bufferManager[mmb] {
_, err := copyFromReader(context.Background(), &rf, fakeBB, UploadStreamOptions{}, func(maxBuffers int, bufferSize int64) shared.BufferManager[shared.Mmb] {
tracker = newBufMgrTracker(maxBuffers, bufferSize)
return tracker
})