Add compressed writes to cas.go. (#240)
This follows the current tentative API being worked on in
bazelbuild/remote-apis#168. While there's technically room for it to
change, it has reached a somewhat stable point worth implementing.
rubensf authored Nov 24, 2020
1 parent 40400d5 commit d907a15
Showing 7 changed files with 211 additions and 123 deletions.
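For context, the tentative API in bazelbuild/remote-apis#168 routes compressed uploads through a dedicated resource-name template, which this commit implements for zstd (the hash and size always refer to the uncompressed blob):

{instance_name}/uploads/{uuid}/compressed-blobs/zstd/{uncompressed_hash}/{uncompressed_size}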
8 changes: 4 additions & 4 deletions go/pkg/chunker/chunker.go
@@ -22,7 +22,10 @@ var IOBufferSize = 10 * 1024 * 1024
var ErrEOF = errors.New("ErrEOF")

// Compressor for full blobs
var fullCompressor, _ = zstd.NewWriter(nil)
// It is *only* thread-safe for EncodeAll calls and should not be used for streamed compression.
// While we avoid sending 0 len blobs, we do want to create zero len compressed blobs if
// necessary.
var fullCompressor, _ = zstd.NewWriter(nil, zstd.WithZeroFrames(true))
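For reference, this relies on documented behavior of the klauspost/compress zstd package: EncodeAll is safe to call concurrently on a shared *zstd.Encoder, while streamed Write/Close is not, and WithZeroFrames makes zero-length inputs encode to a valid frame. A standalone sketch of that usage, separate from this diff:

package main

import (
	"fmt"

	"github.com/klauspost/compress/zstd"
)

// A single shared encoder: only EncodeAll may be called concurrently;
// streaming compression on a shared encoder is not thread-safe.
var enc, _ = zstd.NewWriter(nil, zstd.WithZeroFrames(true))

// compress returns blob encoded as a complete zstd frame.
func compress(blob []byte) []byte {
	// EncodeAll appends the compressed frame to the second argument.
	return enc.EncodeAll(blob, nil)
}

func main() {
	fmt.Println(len(compress([]byte("hello"))) > 0) // true
	// With WithZeroFrames(true), even an empty blob yields a valid,
	// non-empty frame, matching the zero-length-blob note above.
	fmt.Println(len(compress(nil)) > 0) // true
}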

// Chunker can be used to chunk an input into uploadable-size byte slices.
// A single Chunker is NOT thread-safe; it should be used by a single uploader thread.
@@ -43,9 +46,6 @@ type Chunker struct {
}

func New(ue *uploadinfo.Entry, compressed bool, chunkSize int) (*Chunker, error) {
if compressed {
return nil, errors.New("compression is not supported yet")
}
if chunkSize < 1 {
chunkSize = DefaultChunkSize
}
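With the unsupported-compression guard removed above, chunker.New now accepts compressed=true instead of returning an error. A hedged usage sketch based only on the signatures visible in this diff (the blob content is a placeholder):

package main

import (
	"log"

	"github.com/bazelbuild/remote-apis-sdks/go/pkg/chunker"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/uploadinfo"
)

func main() {
	ue := uploadinfo.EntryFromBlob([]byte("example blob"))
	// compressed=true requests zstd-compressed chunks; a chunkSize < 1
	// falls back to chunker.DefaultChunkSize.
	ch, err := chunker.New(ue, true, chunker.DefaultChunkSize)
	if err != nil {
		log.Fatalf("chunker.New: %v", err)
	}
	_ = ch // feed chunks to a ByteStream.Write stream from here
}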
21 changes: 21 additions & 0 deletions go/pkg/client/cas.go
@@ -27,6 +27,9 @@ import (
log "github.com/golang/glog"
)

// DefaultCompressedBytestreamThreshold is the default threshold, in bytes, for transferring blobs compressed on ByteStream.Write RPCs.
const DefaultCompressedBytestreamThreshold = 1024

const logInterval = 25

type requestMetadata struct {
Expand Down Expand Up @@ -858,6 +861,13 @@ func (c *Client) ResourceNameWrite(hash string, sizeBytes int64) string {
return fmt.Sprintf("%s/uploads/%s/blobs/%s/%d", c.InstanceName, uuid.New(), hash, sizeBytes)
}

// ResourceNameCompressedWrite generates a valid resource name for writing a compressed blob.
// TODO(rubensf): Converge compressor to proto in https://github.com/bazelbuild/remote-apis/pull/168 once
// that gets merged in.
func (c *Client) ResourceNameCompressedWrite(hash string, sizeBytes int64) string {
return fmt.Sprintf("%s/uploads/%s/compressed-blobs/zstd/%s/%d", c.InstanceName, uuid.New(), hash, sizeBytes)
}

// GetDirectoryTree returns the entire directory tree rooted at the given digest (which must target
// a Directory stored in the CAS).
func (c *Client) GetDirectoryTree(ctx context.Context, d *repb.Digest) (result []*repb.Directory, err error) {
Expand Down Expand Up @@ -1377,3 +1387,14 @@ func (c *Client) DownloadFiles(ctx context.Context, execRoot string, outputs map
}
return nil
}

func (c *Client) shouldCompress(sizeBytes int64) bool {
return int64(c.CompressedBytestreamThreshold) >= 0 && int64(c.CompressedBytestreamThreshold) <= sizeBytes
}

func (c *Client) writeRscName(dg digest.Digest) string {
if c.shouldCompress(dg.Size) {
return c.ResourceNameCompressedWrite(dg.Hash, dg.Size)
}
return c.ResourceNameWrite(dg.Hash, dg.Size)
}
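Taken together: a negative threshold disables compressed writes, zero compresses everything, and otherwise any blob whose size meets the threshold gets the compressed-blobs resource name. A standalone mirror of the unexported helper above, just to make the boundary conditions explicit (it does not call the real methods):

// shouldCompress mirrors Client.shouldCompress above.
func shouldCompress(threshold, sizeBytes int64) bool {
	return threshold >= 0 && threshold <= sizeBytes
}

// Boundary behavior (sizes are illustrative):
//   shouldCompress(-1, 5000)   == false // negative: never compress
//   shouldCompress(0, 0)       == true  // zero: always compress
//   shouldCompress(1024, 512)  == false // below threshold: plain blobs/ name
//   shouldCompress(1024, 1024) == true  // at threshold: compressed-blobs/zstd/ name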
153 changes: 76 additions & 77 deletions go/pkg/client/cas_test.go
@@ -14,7 +14,6 @@ import (
"testing"
"time"

"github.com/bazelbuild/remote-apis-sdks/go/pkg/chunker"
"github.com/bazelbuild/remote-apis-sdks/go/pkg/client"
"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
"github.com/bazelbuild/remote-apis-sdks/go/pkg/fakes"
Expand Down Expand Up @@ -287,23 +286,26 @@ func TestWrite(t *testing.T) {
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
gotDg, err := c.WriteBlob(ctx, tc.blob)
if err != nil {
t.Errorf("c.WriteBlob(ctx, blob) gave error %s, wanted nil", err)
}
if fake.Err != nil {
t.Errorf("c.WriteBlob(ctx, blob) caused the server to return error %s (possibly unseen by c)", fake.Err)
}
if !bytes.Equal(tc.blob, fake.Buf) {
t.Errorf("c.WriteBlob(ctx, blob) had diff on blobs, want %v, got %v:", tc.blob, fake.Buf)
}
dg := digest.NewFromBlob(tc.blob)
if dg != gotDg {
t.Errorf("c.WriteBlob(ctx, blob) had diff on digest returned (want %s, got %s)", dg, gotDg)
}
})
for _, cmp := range []client.CompressedBytestreamThreshold{-1, 0} {
for _, tc := range tests {
t.Run(fmt.Sprintf("%s - CompressionThresh:%d", tc.name, cmp), func(t *testing.T) {
cmp.Apply(c)
gotDg, err := c.WriteBlob(ctx, tc.blob)
if err != nil {
t.Errorf("c.WriteBlob(ctx, blob) gave error %s, wanted nil", err)
}
if fake.Err != nil {
t.Errorf("c.WriteBlob(ctx, blob) caused the server to return error %s (possibly unseen by c)", fake.Err)
}
if !bytes.Equal(tc.blob, fake.Buf) {
t.Errorf("c.WriteBlob(ctx, blob) had diff on blobs, want %v, got %v:", tc.blob, fake.Buf)
}
dg := digest.NewFromBlob(tc.blob)
if dg != gotDg {
t.Errorf("c.WriteBlob(ctx, blob) had diff on digest returned (want %s, got %s)", dg, gotDg)
}
})
}
}
}

Expand Down Expand Up @@ -712,74 +714,71 @@ func TestUpload(t *testing.T) {

for _, ub := range []client.UseBatchOps{false, true} {
for _, uo := range []client.UnifiedCASOps{false, true} {
t.Run(fmt.Sprintf("UsingBatch:%t,UnifiedCASOps:%t", ub, uo), func(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctx := context.Background()
e, cleanup := fakes.NewTestEnv(t)
defer cleanup()
fake := e.Server.CAS
c := e.Client.GrpcClient
client.CASConcurrency(defaultCASConcurrency).Apply(c)
ub.Apply(c)
uo.Apply(c)
if tc.concurrency > 0 {
tc.concurrency.Apply(c)
}

present := make(map[digest.Digest]bool)
for _, blob := range tc.present {
fake.Put(blob)
present[digest.NewFromBlob(blob)] = true
}
var input []*uploadinfo.Entry
for _, blob := range tc.input {
input = append(input, uploadinfo.EntryFromBlob(blob))
}
for _, cmp := range []client.CompressedBytestreamThreshold{-1, 0} {
t.Run(fmt.Sprintf("UsingBatch:%t,UnifiedCASOps:%t,CompressionThresh:%d", ub, uo, cmp), func(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctx := context.Background()
e, cleanup := fakes.NewTestEnv(t)
defer cleanup()
fake := e.Server.CAS
c := e.Client.GrpcClient
client.CASConcurrency(defaultCASConcurrency).Apply(c)
cmp.Apply(c)
ub.Apply(c)
uo.Apply(c)
if tc.concurrency > 0 {
tc.concurrency.Apply(c)
}

missing, err := c.UploadIfMissing(ctx, input...)
if err != nil {
t.Errorf("c.UploadIfMissing(ctx, input) gave error %v, expected nil", err)
}
present := make(map[digest.Digest]bool)
for _, blob := range tc.present {
fake.Put(blob)
present[digest.NewFromBlob(blob)] = true
}
var input []*uploadinfo.Entry
for _, blob := range tc.input {
input = append(input, uploadinfo.EntryFromBlob(blob))
}

missingSet := make(map[digest.Digest]struct{})
for _, dg := range missing {
missingSet[dg] = struct{}{}
}
for _, ue := range input {
dg := ue.Digest
ch, err := chunker.New(ue, false, int(c.ChunkMaxSize))
missing, err := c.UploadIfMissing(ctx, input...)
if err != nil {
t.Fatalf("chunker.New(ue): failed to create chunker from UploadEntry: %v", err)
t.Errorf("c.UploadIfMissing(ctx, input) gave error %v, expected nil", err)
}
blob, err := ch.FullData()
if err != nil {
t.Errorf("ch.FullData() returned an error: %v", err)

missingSet := make(map[digest.Digest]struct{})
for _, dg := range missing {
missingSet[dg] = struct{}{}
}
if present[dg] {
if fake.BlobWrites(dg) > 0 {
t.Errorf("blob %v with digest %s was uploaded even though it was already present in the CAS", blob, dg)

for i, ue := range input {
dg := ue.Digest
blob := tc.input[i]
if present[dg] {
if fake.BlobWrites(dg) > 0 {
t.Errorf("blob %v with digest %s was uploaded even though it was already present in the CAS", blob, dg)
}
if _, ok := missingSet[dg]; ok {
t.Errorf("Stats said that blob %v with digest %s was missing in the CAS", blob, dg)
}
continue
}
if _, ok := missingSet[dg]; ok {
t.Errorf("Stats said that blob %v with digest %s was missing in the CAS", blob, dg)
if gotBlob, ok := fake.Get(dg); !ok {
t.Errorf("blob %v with digest %s was not uploaded, expected it to be present in the CAS", blob, dg)
} else if !bytes.Equal(blob, gotBlob) {
t.Errorf("blob digest %s had diff on uploaded blob: want %v, got %v", dg, blob, gotBlob)
}
if _, ok := missingSet[dg]; !ok {
t.Errorf("Stats said that blob %v with digest %s was present in the CAS", blob, dg)
}
continue
}
if gotBlob, ok := fake.Get(dg); !ok {
t.Errorf("blob %v with digest %s was not uploaded, expected it to be present in the CAS", blob, dg)
} else if !bytes.Equal(blob, gotBlob) {
t.Errorf("blob digest %s had diff on uploaded blob: want %v, got %v", dg, blob, gotBlob)
}
if _, ok := missingSet[dg]; !ok {
t.Errorf("Stats said that blob %v with digest %s was present in the CAS", blob, dg)
if fake.MaxConcurrency() > defaultCASConcurrency {
t.Errorf("CAS concurrency %v was higher than max %v", fake.MaxConcurrency(), defaultCASConcurrency)
}
}
if fake.MaxConcurrency() > defaultCASConcurrency {
t.Errorf("CAS concurrency %v was higher than max %v", fake.MaxConcurrency(), defaultCASConcurrency)
}
})
}
})
})
}
})
}
}
}
}
57 changes: 35 additions & 22 deletions go/pkg/client/client.go
@@ -72,6 +72,11 @@ type Client struct {
StartupCapabilities StartupCapabilities
// ChunkMaxSize is maximum chunk size to use for CAS uploads/downloads.
ChunkMaxSize ChunkMaxSize
// CompressedBytestreamThreshold is the threshold in bytes at or above which blobs are read and
// written compressed. Use 0 to compress all writes, and a negative number to disable compression
// entirely. TODO(rubensf): Make sure this throws an error if the server doesn't support compression,
// pending https://github.com/bazelbuild/remote-apis/pull/168 being submitted.
CompressedBytestreamThreshold CompressedBytestreamThreshold
// MaxBatchDigests is maximum amount of digests to batch in batched operations.
MaxBatchDigests MaxBatchDigests
// MaxBatchSize is maximum size in bytes of a batch request for batch operations.
Expand Down Expand Up @@ -145,6 +150,13 @@ func (s ChunkMaxSize) Apply(c *Client) {
c.ChunkMaxSize = s
}

// CompressedBytestreamThreshold is the size in bytes at or above which blobs are read and
// written compressed on the ByteStream API.
type CompressedBytestreamThreshold int64

// Apply sets the client's compressed bytestream threshold to s.
func (s CompressedBytestreamThreshold) Apply(c *Client) {
c.CompressedBytestreamThreshold = s
}

// UtilizeLocality is to specify whether client downloads files utilizing disk access locality.
type UtilizeLocality bool

Expand Down Expand Up @@ -460,28 +472,29 @@ func NewClient(ctx context.Context, instanceName string, params DialParams, opts
return nil, err
}
client := &Client{
InstanceName: instanceName,
actionCache: regrpc.NewActionCacheClient(casConn),
byteStream: bsgrpc.NewByteStreamClient(casConn),
cas: regrpc.NewContentAddressableStorageClient(casConn),
execution: regrpc.NewExecutionClient(conn),
operations: opgrpc.NewOperationsClient(conn),
rpcTimeouts: DefaultRPCTimeouts,
Connection: conn,
CASConnection: casConn,
ChunkMaxSize: chunker.DefaultChunkSize,
MaxBatchDigests: DefaultMaxBatchDigests,
MaxBatchSize: DefaultMaxBatchSize,
DirMode: DefaultDirMode,
ExecutableMode: DefaultExecutableMode,
RegularMode: DefaultRegularMode,
useBatchOps: true,
StartupCapabilities: true,
casConcurrency: DefaultCASConcurrency,
casUploaders: semaphore.NewWeighted(DefaultCASConcurrency),
casDownloaders: semaphore.NewWeighted(DefaultCASConcurrency),
casUploads: make(map[digest.Digest]*uploadState),
Retrier: RetryTransient(),
InstanceName: instanceName,
actionCache: regrpc.NewActionCacheClient(casConn),
byteStream: bsgrpc.NewByteStreamClient(casConn),
cas: regrpc.NewContentAddressableStorageClient(casConn),
execution: regrpc.NewExecutionClient(conn),
operations: opgrpc.NewOperationsClient(conn),
rpcTimeouts: DefaultRPCTimeouts,
Connection: conn,
CASConnection: casConn,
CompressedBytestreamThreshold: DefaultCompressedBytestreamThreshold,
ChunkMaxSize: chunker.DefaultChunkSize,
MaxBatchDigests: DefaultMaxBatchDigests,
MaxBatchSize: DefaultMaxBatchSize,
DirMode: DefaultDirMode,
ExecutableMode: DefaultExecutableMode,
RegularMode: DefaultRegularMode,
useBatchOps: true,
StartupCapabilities: true,
casConcurrency: DefaultCASConcurrency,
casUploaders: semaphore.NewWeighted(DefaultCASConcurrency),
casDownloaders: semaphore.NewWeighted(DefaultCASConcurrency),
casUploads: make(map[digest.Digest]*uploadState),
Retrier: RetryTransient(),
}
for _, o := range opts {
o.Apply(client)
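Because the new threshold implements Apply like the other options in this file, it can be passed directly to NewClient. A hedged sketch (the instance name is a placeholder and params is an assumed, pre-built client.DialParams; ctx is an existing context.Context):

// Raise the threshold from the default of 1024 bytes so that only blobs
// of 4 MiB or more are written compressed.
c, err := client.NewClient(ctx, "example-instance", params,
	client.CompressedBytestreamThreshold(4*1024*1024))
if err != nil {
	log.Fatalf("NewClient: %v", err)
}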
1 change: 1 addition & 0 deletions go/pkg/fakes/BUILD.bazel
@@ -22,6 +22,7 @@ go_library(
"@com_github_golang_glog//:go_default_library",
"@com_github_golang_protobuf//proto:go_default_library",
"@com_github_golang_protobuf//ptypes:go_default_library_gen",
"@com_github_klauspost_compress//zstd:go_default_library",
"@com_github_pborman_uuid//:go_default_library",
"@go_googleapis//google/bytestream:bytestream_go_proto",
"@go_googleapis//google/longrunning:longrunning_go_proto",