diff --git a/pkg/content/multiwriter.go b/pkg/content/multiwriter.go
index 87df2d614..526587d3b 100644
--- a/pkg/content/multiwriter.go
+++ b/pkg/content/multiwriter.go
@@ -9,8 +9,8 @@ import (
 // MultiWriterIngester an ingester that can provide a single writer or multiple writers for a single
 // descriptor. Useful when the target of a descriptor can have multiple items within it, e.g. a layer
 // that is a tar file with multiple files, each of which should go to a different stream, some of which
-// should not be handled at all
+// should not be handled at all.
 type MultiWriterIngester interface {
 	ctrcontent.Ingester
-	Writers(ctx context.Context, opts ...ctrcontent.WriterOpt) (map[string]ctrcontent.Writer, error)
+	Writers(ctx context.Context, opts ...ctrcontent.WriterOpt) (func(string) (ctrcontent.Writer, error), error)
 }
diff --git a/pkg/content/passthrough.go b/pkg/content/passthrough.go
index b9a891d53..09b7e6ea9 100644
--- a/pkg/content/passthrough.go
+++ b/pkg/content/passthrough.go
@@ -64,7 +64,9 @@ func (pw *PassthroughWriter) Write(p []byte) (n int, err error) {
 }
 
 func (pw *PassthroughWriter) Close() error {
-	pw.pipew.Close()
+	if pw.pipew != nil {
+		pw.pipew.Close()
+	}
 	pw.writer.Close()
 	return nil
 }
@@ -82,9 +84,13 @@ func (pw *PassthroughWriter) Digest() digest.Digest {
 // Commit always closes the writer, even on error.
 // ErrAlreadyExists aborts the writer.
 func (pw *PassthroughWriter) Commit(ctx context.Context, size int64, expected digest.Digest, opts ...content.Opt) error {
-	pw.pipew.Close()
+	if pw.pipew != nil {
+		pw.pipew.Close()
+	}
 	err := <-pw.done
-	pw.reader.Close()
+	if pw.reader != nil {
+		pw.reader.Close()
+	}
 	if err != nil && err != io.EOF {
 		return err
 	}
@@ -152,10 +158,9 @@ type PassthroughMultiWriter struct {
 	done      chan error
 	startedAt time.Time
 	updatedAt time.Time
-	ref       string
 }
 
-func NewPassthroughMultiWriter(writers []content.Writer, f func(r io.Reader, w []io.Writer, done chan<- error), opts ...WriterOpt) content.Writer {
+func NewPassthroughMultiWriter(writers func(name string) (content.Writer, error), f func(r io.Reader, getwriter func(name string) io.Writer, done chan<- error), opts ...WriterOpt) content.Writer {
 	// process opts for default
 	wOpts := DefaultWriterOpts()
 	for _, opt := range opts {
@@ -164,36 +169,38 @@ func NewPassthroughMultiWriter(writers []content.Writer, f func(r io.Reader, w [
 		}
 	}
 
-	var pws []*PassthroughWriter
 	r, w := io.Pipe()
-	for _, writer := range writers {
-		pws = append(pws, &PassthroughWriter{
+
+	pmw := &PassthroughMultiWriter{
+		startedAt: time.Now(),
+		updatedAt: time.Now(),
+		done:      make(chan error, 1),
+		digester:  digest.Canonical.Digester(),
+		hash:      wOpts.InputHash,
+		pipew:     w,
+		reader:    r,
+	}
+
+	// get our output writers
+	getwriter := func(name string) io.Writer {
+		writer, err := writers(name)
+		if err != nil || writer == nil {
+			return nil
+		}
+		pw := &PassthroughWriter{
 			writer:   writer,
-			pipew:    w,
 			digester: digest.Canonical.Digester(),
 			underlyingWriter: &underlyingWriter{
 				writer:   writer,
 				digester: digest.Canonical.Digester(),
 				hash:     wOpts.OutputHash,
 			},
-			reader: r,
-			hash:   wOpts.InputHash,
 			done: make(chan error, 1),
-		})
-	}
-
-	pmw := &PassthroughMultiWriter{
-		writers:   pws,
-		startedAt: time.Now(),
-		updatedAt: time.Now(),
-		done:      make(chan error, 1),
-	}
-	// get our output writers
-	var uws []io.Writer
-	for _, uw := range pws {
-		uws = append(uws, uw.underlyingWriter)
+		}
+		pmw.writers = append(pmw.writers, pw)
+		return pw.underlyingWriter
 	}
-	go f(r, uws, pmw.done)
+	go f(r, getwriter, pmw.done)
 	return pmw
 }
 
@@ -230,7 +237,9 @@ func (pmw *PassthroughMultiWriter) Digest() digest.Digest {
 func (pmw *PassthroughMultiWriter) Commit(ctx context.Context, size int64, expected digest.Digest, opts ...content.Opt) error {
 	pmw.pipew.Close()
 	err := <-pmw.done
-	pmw.reader.Close()
+	if pmw.reader != nil {
+		pmw.reader.Close()
+	}
 	if err != nil && err != io.EOF {
 		return err
 	}
diff --git a/pkg/content/passthrough_test.go b/pkg/content/passthrough_test.go
index 36bcdd220..36471646c 100644
--- a/pkg/content/passthrough_test.go
+++ b/pkg/content/passthrough_test.go
@@ -1,9 +1,11 @@
 package content_test
 
 import (
+	"bytes"
 	"context"
 	"fmt"
 	"io"
+	"math/rand"
 	"testing"
 
 	ctrcontent "github.com/containerd/containerd/content"
@@ -115,3 +117,77 @@ func TestPassthroughWriter(t *testing.T) {
 		}
 	}
 }
+
+func TestPassthroughMultiWriter(t *testing.T) {
+	// pass through function that selects one of two outputs
+	var (
+		b1, b2       bytes.Buffer
+		name1, name2 = "I am name 01", "I am name 02" // each of these is 12 bytes
+		data1, data2 = make([]byte, 500), make([]byte, 500)
+	)
+	rand.Read(data1)
+	rand.Read(data2)
+	combined := append([]byte(name1), data1...)
+	combined = append(combined, []byte(name2)...)
+	combined = append(combined, data2...)
+	f := func(r io.Reader, getwriter func(name string) io.Writer, done chan<- error) {
+		var (
+			err error
+		)
+		// test is done rather simply, with a single 1024 byte chunk, split into 2x512 data streams, each of which is
+		// 12 bytes of name and 500 bytes of data
+		b := make([]byte, 1024)
+		_, err = r.Read(b)
+		if err != nil && err != io.EOF {
+			t.Fatalf("data read error: %v", err)
+		}
+
+		// get the names and data for each
+		n1, n2 := string(b[0:12]), string(b[512+0:512+12])
+		w1, w2 := getwriter(n1), getwriter(n2)
+		if _, err := w1.Write(b[12:512]); err != nil {
+			t.Fatalf("w1 write error: %v", err)
+		}
+		if _, err := w2.Write(b[512+12 : 1024]); err != nil {
+			t.Fatalf("w2 write error: %v", err)
+		}
+		done <- err
+	}
+
+	var (
+		opts = []content.WriterOpt{content.WithInputHash(testContentHash), content.WithOutputHash(modifiedContentHash)}
+		hash = testContentHash
+	)
+	ctx := context.Background()
+	writers := func(name string) (ctrcontent.Writer, error) {
+		switch name {
+		case name1:
+			return content.NewIoContentWriter(&b1), nil
+		case name2:
+			return content.NewIoContentWriter(&b2), nil
+		}
+		return nil, fmt.Errorf("unknown name %s", name)
+	}
+	writer := content.NewPassthroughMultiWriter(writers, f, opts...)
+	n, err := writer.Write(combined)
+	if err != nil {
+		t.Fatalf("unexpected error on Write: %v", err)
+	}
+	if n != len(combined) {
+		t.Fatalf("wrote %d bytes instead of %d", n, len(combined))
+	}
+	if err := writer.Commit(ctx, testDescriptor.Size, hash); err != nil {
+		t.Errorf("unexpected error on Commit: %v", err)
+	}
+	if digest := writer.Digest(); digest != hash {
+		t.Errorf("mismatched digest: actual %v, expected %v", digest, hash)
+	}
+
+	// make sure the data is what we expected
+	if !bytes.Equal(data1, b1.Bytes()) {
+		t.Errorf("b1 data1 did not match")
+	}
+	if !bytes.Equal(data2, b2.Bytes()) {
+		t.Errorf("b2 data2 did not match")
+	}
+}
diff --git a/pkg/content/untar.go b/pkg/content/untar.go
index 729a318af..805a4677d 100644
--- a/pkg/content/untar.go
+++ b/pkg/content/untar.go
@@ -72,10 +72,9 @@ func NewUntarWriter(writer content.Writer, opts ...WriterOpt) content.Writer {
 }
 
 // NewUntarWriterByName wrap multiple writers with an untar, so that the stream is untarred and passed
-// to the appropriate writer, based on the filename. If a filename is not found, it will not pass it
-// to any writer. The filename "" will handle any stream that does not have a specific filename; use
-// it for the default of a single file in a tar stream.
-func NewUntarWriterByName(writers map[string]content.Writer, opts ...WriterOpt) content.Writer {
+// to the appropriate writer, based on the filename. If a filename is not found, it is up to the provided
+// writers func to determine how to process it.
+func NewUntarWriterByName(writers func(string) (content.Writer, error), opts ...WriterOpt) content.Writer {
 	// process opts for default
 	wOpts := DefaultWriterOpts()
 	for _, opt := range opts {
@@ -84,15 +83,8 @@ func NewUntarWriterByName(writers map[string]content.Writer, opts ...WriterOpt)
 		}
 	}
 
-	// construct an array of content.Writer
-	nameToIndex := map[string]int{}
-	var writerSlice []content.Writer
-	for name, writer := range writers {
-		writerSlice = append(writerSlice, writer)
-		nameToIndex[name] = len(writerSlice) - 1
-	}
 	// need a PassthroughMultiWriter here
-	return NewPassthroughMultiWriter(writerSlice, func(r io.Reader, ws []io.Writer, done chan<- error) {
+	return NewPassthroughMultiWriter(writers, func(r io.Reader, getwriter func(name string) io.Writer, done chan<- error) {
 		tr := tar.NewReader(r)
 		var err error
 		for {
@@ -109,13 +101,11 @@ func NewUntarWriterByName(writers map[string]content.Writer, opts ...WriterOpt)
 			}
 			// get the filename
 			filename := header.Name
-			index, ok := nameToIndex[filename]
-			if !ok {
-				index, ok = nameToIndex[""]
-				if !ok {
-					// we did not find this file or the wildcard, so do not process this file
-					continue
-				}
+
+			// get the writer for this filename
+			w := getwriter(filename)
+			if w == nil {
+				continue
 			}
 
 			// write out the untarred data
@@ -133,8 +123,8 @@ func NewUntarWriterByName(writers map[string]content.Writer, opts ...WriterOpt)
 				if n > len(b) {
 					l = len(b)
 				}
-				if _, err2 := ws[index].Write(b[:l]); err2 != nil {
-					err = fmt.Errorf("UntarWriter error writing to underlying writer at index %d for name '%s': %v", index, filename, err2)
+				if _, err2 := w.Write(b[:l]); err2 != nil {
+					err = fmt.Errorf("UntarWriter error writing to underlying writer for name '%s': %v", filename, err2)
 					break
 				}
 				if err == io.EOF {
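For orientation, here is a minimal, illustrative sketch of how a caller might use the new function-based resolver introduced by this change, following the same pattern as `TestPassthroughMultiWriter` above. It is not part of the diff: the local import path, the file names inside the tar stream, and the zero-valued `Commit` arguments are assumptions made purely for illustration.

```go
// Illustrative sketch only. The module path of this pkg/content package and the
// file names are assumptions; the containerd import path mirrors passthrough_test.go.
package main

import (
	"archive/tar"
	"bytes"
	"context"
	"fmt"

	ctrcontent "github.com/containerd/containerd/content" // as imported in passthrough_test.go
	"github.com/deislabs/oras/pkg/content"                // assumed module path for this pkg/content
)

func main() {
	var cfg, layer bytes.Buffer

	// New-style resolver: map a filename inside the tar stream to a writer,
	// or return an error for names this caller does not handle.
	writers := func(name string) (ctrcontent.Writer, error) {
		switch name {
		case "config.json":
			return content.NewIoContentWriter(&cfg), nil
		case "layer.tar":
			return content.NewIoContentWriter(&layer), nil
		}
		return nil, fmt.Errorf("unknown name %s", name)
	}
	w := content.NewUntarWriterByName(writers)

	// Build a tiny tar stream containing the two files (error checks elided).
	var tarball bytes.Buffer
	tw := tar.NewWriter(&tarball)
	for _, f := range []struct {
		name string
		data []byte
	}{
		{"config.json", []byte(`{"arch":"amd64"}`)},
		{"layer.tar", []byte("layer-bytes")},
	} {
		_ = tw.WriteHeader(&tar.Header{Name: f.name, Mode: 0644, Size: int64(len(f.data))})
		_, _ = tw.Write(f.data)
	}
	tw.Close()

	if _, err := w.Write(tarball.Bytes()); err != nil {
		panic(err)
	}
	// Zero size and empty expected digest are placeholders; the tests above
	// pass testDescriptor.Size and an expected hash instead.
	if err := w.Commit(context.Background(), 0, ""); err != nil {
		panic(err)
	}
	fmt.Println(cfg.String(), layer.String())
}
```

Compared with the old map-based API, the resolver is consulted only when a name actually appears in the stream, so callers can decide lazily how to handle, or reject, names they do not recognize.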