From 9366f9da4cfcfde3ce852e8092cfd5b96440fc1d Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 22 Nov 2023 13:07:20 -0700 Subject: [PATCH 001/195] Generate large CSV --- .gitignore | 1 + drivers/csv/csv_test.go | 78 +++++++++++++++++++++++++++++++++++++++++ drivers/csv/ingest.go | 3 ++ 3 files changed, 82 insertions(+) diff --git a/.gitignore b/.gitignore index 70dec3217..64206612d 100644 --- a/.gitignore +++ b/.gitignore @@ -54,3 +54,4 @@ goreleaser-test.sh /cli/test.db /*.db /.CHANGELOG.delta.md +/drivers/csv/testdata/payment-large.csv diff --git a/drivers/csv/csv_test.go b/drivers/csv/csv_test.go index 6b0e73055..f429ab47d 100644 --- a/drivers/csv/csv_test.go +++ b/drivers/csv/csv_test.go @@ -2,7 +2,12 @@ package csv_test import ( "context" + "golang.org/x/text/language" + "golang.org/x/text/message" + "math/rand" + "os" "path/filepath" + "strconv" "testing" "time" @@ -11,6 +16,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + stdcsv "encoding/csv" + "github.com/neilotoole/sq/cli/testrun" "github.com/neilotoole/sq/drivers/csv" "github.com/neilotoole/sq/libsq/core/kind" @@ -331,3 +338,74 @@ func TestDatetime(t *testing.T) { }) } } + +// TestIngestLargeCSV generates a large CSV file. +// At count = 5000000, the generated file is ~500MB. 
+func TestGenerateLargeCSV(t *testing.T) { + t.Skip() + const count = 5000000 // Generates ~500MB file + start := time.Now() + header := []string{ + "payment_id", + "customer_id", + "name", + "staff_id", + "rental_id", + "amount", + "payment_date", + "last_update", + } + + f, err := os.OpenFile( + "testdata/payment-large.csv", + os.O_CREATE|os.O_WRONLY|os.O_TRUNC, + 0600, + ) + require.NoError(t, err) + t.Cleanup(func() { _ = f.Close() }) + + w := stdcsv.NewWriter(f) + require.NoError(t, w.Write(header)) + + rec := make([]string, len(header)) + amount := decimal.New(50000, -2) + paymentUTC := time.Now().UTC() + lastUpdateUTC := time.Now().UTC() + p := message.NewPrinter(language.English) + for i := 0; i < count; i++ { + if i%100000 == 0 { + // Flush occasionally + w.Flush() + } + + rec[0] = strconv.Itoa(i + 1) // payment id, always unique + rec[1] = strconv.Itoa(rand.Intn(100)) // customer_id, one of 100 customers + rec[2] = "Alice " + rec[1] // name + rec[3] = strconv.Itoa(rand.Intn(10)) // staff_id + rec[4] = strconv.Itoa(i + 3) // rental_id, always unique + f64 := amount.InexactFloat64() + rec[5] = p.Sprintf("%.2f", f64) // amount + amount = amount.Add(decimal.New(33, -2)) + rec[6] = timez.TimestampUTC(paymentUTC) // payment_date + paymentUTC = paymentUTC.Add(time.Minute) + rec[7] = timez.TimestampUTC(lastUpdateUTC) // last_update + lastUpdateUTC = lastUpdateUTC.Add(time.Minute + time.Second) + err = w.Write(rec) + require.NoError(t, err) + } + + w.Flush() + require.NoError(t, w.Error()) + require.NoError(t, f.Close()) + + fi, err := os.Stat(f.Name()) + require.NoError(t, err) + + t.Logf( + "Wrote %s records in %s, total size %s, to: %s", + p.Sprintf("%d", count), + time.Since(start).Round(time.Millisecond), + stringz.ByteSized(fi.Size(), 1, ""), + f.Name(), + ) +} diff --git a/drivers/csv/ingest.go b/drivers/csv/ingest.go index 3bd27815b..fc2883e0f 100644 --- a/drivers/csv/ingest.go +++ b/drivers/csv/ingest.go @@ -5,6 +5,7 @@ import ( "encoding/csv" "errors" 
"io" + "time" "unicode/utf8" "github.com/neilotoole/sq/libsq" @@ -51,6 +52,7 @@ Possible values are: comma, space, pipe, tab, colon, semi, period.`, // ingestCSV loads the src CSV data into scratchDB. func ingestCSV(ctx context.Context, src *source.Source, openFn source.FileOpenFunc, scratchPool driver.Pool) error { log := lg.FromContext(ctx) + startUTC := time.Now().UTC() var err error var r io.ReadCloser @@ -140,6 +142,7 @@ func ingestCSV(ctx context.Context, src *source.Source, openFn source.FileOpenFu log.Debug("Inserted rows", lga.Count, inserted, + lga.Elapsed, time.Since(startUTC).Round(time.Millisecond), lga.Target, source.Target(scratchPool.Source(), tblDef.Name), ) return nil From 2d1aad83c33aace758106aee4764eed92c236a68 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 22 Nov 2023 20:05:06 -0700 Subject: [PATCH 002/195] wip: broken --- cli/options.go | 1 + cli/run.go | 2 +- drivers/csv/csv.go | 54 ++++- drivers/csv/csv_test.go | 9 +- drivers/json/json.go | 2 +- drivers/sqlite3/sqlite3.go | 31 ++- drivers/userdriver/userdriver.go | 2 +- drivers/userdriver/xmlud/xmlimport_test.go | 9 +- drivers/xlsx/xlsx.go | 2 +- libsq/core/ioz/checksum.go | 99 +++++++++ libsq/core/ioz/ioz.go | 6 + libsq/core/ioz/ioz_test.go | 36 ++++ libsq/core/lg/lga/lga.go | 1 + libsq/core/options/options.go | 20 ++ libsq/driver/driver.go | 230 +++++++++++++++++++-- libsq/driver/ingest.go | 11 + libsq/driver/scratch.go | 11 - libsq/pipeline.go | 6 +- libsq/source/files.go | 76 +++++-- libsq/source/handle.go | 6 + libsq/source/source.go | 20 ++ testh/testh.go | 2 +- 22 files changed, 546 insertions(+), 90 deletions(-) create mode 100644 libsq/core/ioz/checksum.go delete mode 100644 libsq/driver/scratch.go diff --git a/cli/options.go b/cli/options.go index 780061968..9ab769955 100644 --- a/cli/options.go +++ b/cli/options.go @@ -166,6 +166,7 @@ func RegisterDefaultOpts(reg *options.Registry) { driver.OptTuningRecChanSize, OptTuningFlushThreshold, driver.OptIngestHeader, + 
driver.OptIngestCache, driver.OptIngestColRename, driver.OptIngestSampleSize, csv.OptDelim, diff --git a/cli/run.go b/cli/run.go index 6e37dd0d7..6f5c1c5c3 100644 --- a/cli/run.go +++ b/cli/run.go @@ -157,7 +157,7 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { ru.DriverRegistry = driver.NewRegistry(log) dr := ru.DriverRegistry - ru.Pools = driver.NewPools(log, dr, scratchSrcFunc) + ru.Pools = driver.NewPools(log, dr, ru.Files, scratchSrcFunc) ru.Cleanup.AddC(ru.Pools) dr.AddProvider(sqlite3.Type, &sqlite3.Provider{Log: log}) diff --git a/drivers/csv/csv.go b/drivers/csv/csv.go index bb191811a..60a3d2f00 100644 --- a/drivers/csv/csv.go +++ b/drivers/csv/csv.go @@ -10,6 +10,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" @@ -66,25 +67,62 @@ func (d *driveri) DriverMetadata() driver.Metadata { // Open implements driver.PoolOpener. 
func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, error) { - lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) + log := lg.FromContext(ctx) + log.Debug(lgm.OpenSrc, lga.Src, src) - pool := &pool{ + p := &pool{ log: d.log, src: src, files: d.files, } - var err error - pool.impl, err = d.scratcher.OpenScratch(ctx, src.Handle) - if err != nil { - return nil, err + allowCache := driver.OptIngestCache.Get(options.FromContext(ctx)) + // impl, err := d.scratcher.OpenScratchFor(ctx, src, allowCache) + + ingestFn := func(ctx context.Context, destPool driver.Pool) error { + return ingestCSV(ctx, src, d.files.OpenFunc(src), destPool) } - if err = ingestCSV(ctx, src, d.files.OpenFunc(src), pool.impl); err != nil { + backingPool, err := d.scratcher.OpenIngest(ctx, src, ingestFn, allowCache) + if err != nil { return nil, err } - return pool, nil + p.impl = backingPool + + // + //var err error + //if { + // // Caching is enabled, let's see if we can find a cached copy. + // var foundCached bool + // p.impl, foundCached, err = d.scratcher.OpenCachedFor(ctx, src) + // if err != nil { + // return nil, err + // } + // if foundCached { + // log.Debug("Cache HIT: found cached copy of source", + // lga.Src, src, "cached", p.impl.Source(), + // ) + // return p, nil + // } + // + // log.Debug("Cache MISS: no cache for source", lga.Src, src) + //} + // + //if p.impl == nil { + // p.impl, err = d.scratcher.OpenScratchFor(ctx, src) + // if err != nil { + // return nil, err + // } + //} + // + //if err = ingestCSV(ctx, src, d.files.OpenFunc(src), p.impl); err != nil { + // return nil, err + //} + + // FIXME: We really should be writing the checksum after ingestCSV happens + + return p, nil } // Truncate implements driver.Driver. 
diff --git a/drivers/csv/csv_test.go b/drivers/csv/csv_test.go index f429ab47d..41fdff298 100644 --- a/drivers/csv/csv_test.go +++ b/drivers/csv/csv_test.go @@ -2,8 +2,7 @@ package csv_test import ( "context" - "golang.org/x/text/language" - "golang.org/x/text/message" + stdcsv "encoding/csv" "math/rand" "os" "path/filepath" @@ -15,8 +14,8 @@ import ( "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - stdcsv "encoding/csv" + "golang.org/x/text/language" + "golang.org/x/text/message" "github.com/neilotoole/sq/cli/testrun" "github.com/neilotoole/sq/drivers/csv" @@ -359,7 +358,7 @@ func TestGenerateLargeCSV(t *testing.T) { f, err := os.OpenFile( "testdata/payment-large.csv", os.O_CREATE|os.O_WRONLY|os.O_TRUNC, - 0600, + 0o600, ) require.NoError(t, err) t.Cleanup(func() { _ = f.Close() }) diff --git a/drivers/json/json.go b/drivers/json/json.go index af77c92c9..8e04c5db1 100644 --- a/drivers/json/json.go +++ b/drivers/json/json.go @@ -102,7 +102,7 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er return nil, err } - p.impl, err = d.scratcher.OpenScratch(ctx, src.Handle) + p.impl, err = d.scratcher.OpenScratchFor(ctx, src) if err != nil { lg.WarnIfCloseError(d.log, lgm.CloseFileReader, r) lg.WarnIfFuncError(d.log, lgm.CloseDB, p.clnup.Run) diff --git a/drivers/sqlite3/sqlite3.go b/drivers/sqlite3/sqlite3.go index 2517d7896..22ea8c7b9 100644 --- a/drivers/sqlite3/sqlite3.go +++ b/drivers/sqlite3/sqlite3.go @@ -1020,29 +1020,26 @@ func (p *pool) Close() error { return err } -// NewScratchSource returns a new scratch src. Effectively this -// function creates a new sqlite db file in the temp dir, and -// src points at this file. The returned clnup func will delete -// the file. 
-func NewScratchSource(ctx context.Context, name string) (src *source.Source, clnup func() error, err error) { +var _ driver.ScratchSrcFunc = NewScratchSource + +// NewScratchSource returns a new scratch src. The supplied fpath +// must be the absolute path to the location to create the SQLite DB file, +// typically in the user cache dir. +// The returned clnup func will delete the file. +func NewScratchSource(ctx context.Context, fpath string) (src *source.Source, clnup func() error, err error) { log := lg.FromContext(ctx) - name = stringz.SanitizeAlphaNumeric(name, '_') - dir, file, err := source.TempDirFile(name + ".sqlite") - if err != nil { - return nil, nil, err - } - log.Debug("Created sqlite3 scratchdb data file", lga.Path, file) + log.Debug("Created sqlite3 scratchdb data file", lga.Path, fpath) src = &source.Source{ Type: Type, Handle: source.ScratchHandle, - Location: Prefix + file, + Location: Prefix + fpath, } fn := func() error { - log.Debug("Deleting sqlite3 scratchdb file", lga.Src, src, lga.Path, file) - rmErr := errz.Err(os.RemoveAll(dir)) + log.Debug("Deleting sqlite3 scratchdb file", lga.Src, src, lga.Path, fpath) + rmErr := errz.Err(os.Remove(fpath)) if rmErr != nil { log.Warn("Delete sqlite3 scratchdb file", lga.Err, rmErr) } @@ -1052,9 +1049,11 @@ func NewScratchSource(ctx context.Context, name string) (src *source.Source, cln return src, fn, nil } -// PathFromLocation returns the absolute file path -// from the source location, which should have the "sqlite3://" prefix. +// PathFromLocation returns the absolute file path from the source location, +// which should have the "sqlite3://" prefix. func PathFromLocation(src *source.Source) (string, error) { + // FIXME: Does this actually work with query params in the path? + // Probably not? Maybe refactor use dburl.Parse or such. 
if src.Type != Type { return "", errz.Errorf("driver {%s} does not support {%s}", Type, src.Type) } diff --git a/drivers/userdriver/userdriver.go b/drivers/userdriver/userdriver.go index 7338488b6..c9bff94d6 100644 --- a/drivers/userdriver/userdriver.go +++ b/drivers/userdriver/userdriver.go @@ -93,7 +93,7 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er defer lg.WarnIfCloseError(d.log, lgm.CloseFileReader, r) - scratchDB, err := d.scratcher.OpenScratch(ctx, src.Handle) + scratchDB, err := d.scratcher.OpenScratchFor(ctx, src) if err != nil { return nil, err } diff --git a/drivers/userdriver/xmlud/xmlimport_test.go b/drivers/userdriver/xmlud/xmlimport_test.go index d35687b33..e2bc680ca 100644 --- a/drivers/userdriver/xmlud/xmlimport_test.go +++ b/drivers/userdriver/xmlud/xmlimport_test.go @@ -4,6 +4,9 @@ import ( "bytes" "testing" + "github.com/neilotoole/sq/libsq/source" + "github.com/neilotoole/sq/libsq/source/drivertype" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -31,7 +34,8 @@ func TestImport_Ppl(t *testing.T) { require.Equal(t, driverPpl, udDef.Name) require.Equal(t, xmlud.Genre, udDef.Genre) - scratchDB, err := th.Pools().OpenScratch(th.Context, "ppl") + src := &source.Source{Handle: "@ppl_" + stringz.Uniq8(), Type: drivertype.None} + scratchDB, err := th.Pools().OpenScratchFor(th.Context, src) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, scratchDB.Close()) @@ -76,7 +80,8 @@ func TestImport_RSS(t *testing.T) { require.Equal(t, driverRSS, udDef.Name) require.Equal(t, xmlud.Genre, udDef.Genre) - scratchDB, err := th.Pools().OpenScratch(th.Context, "rss") + src := &source.Source{Handle: "@rss_" + stringz.Uniq8(), Type: drivertype.None} + scratchDB, err := th.Pools().OpenScratchFor(th.Context, src) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, scratchDB.Close()) diff --git a/drivers/xlsx/xlsx.go b/drivers/xlsx/xlsx.go index 71d290b67..d32d14179 100644 
--- a/drivers/xlsx/xlsx.go +++ b/drivers/xlsx/xlsx.go @@ -61,7 +61,7 @@ func (d *Driver) DriverMetadata() driver.Metadata { func (d *Driver) Open(ctx context.Context, src *source.Source) (driver.Pool, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) - scratchPool, err := d.scratcher.OpenScratch(ctx, src.Handle) + scratchPool, err := d.scratcher.OpenScratchFor(ctx, src) if err != nil { return nil, err } diff --git a/libsq/core/ioz/checksum.go b/libsq/core/ioz/checksum.go new file mode 100644 index 000000000..35ae17d95 --- /dev/null +++ b/libsq/core/ioz/checksum.go @@ -0,0 +1,99 @@ +package ioz + +import ( + "bufio" + "bytes" + "crypto/sha256" + "fmt" + "io" + "os" + "strconv" + "strings" + + "github.com/neilotoole/sq/libsq/core/errz" +) + +// FileChecksum returns a checksum of the file at path. +// The checksum is based on the file's name, size, mode, and +// modification time. File contents are not read. +func FileChecksum(path string) (string, error) { + fi, err := os.Stat(path) + if err != nil { + return "", errz.Wrap(err, "calculate file checksum") + } + + buf := bytes.Buffer{} + buf.WriteString(fi.Name()) + buf.WriteString(strconv.FormatInt(fi.ModTime().UnixNano(), 10)) + buf.WriteString(strconv.FormatInt(fi.Size(), 10)) + buf.WriteString(strconv.FormatUint(uint64(fi.Mode()), 10)) + buf.WriteString(strconv.FormatBool(fi.IsDir())) + + sum := sha256.Sum256(buf.Bytes()) + return fmt.Sprintf("%x", sum), nil +} + +// WriteChecksum appends a checksum line to w, including +// a newline. The format is: +// +// +// da1f14c16c09bebbc452108d9ab193541f2e96515aefcb7745fee5197c343106 file.txt +// +// Use FileChecksum to calculate a checksum, and ReadChecksums +// to read this format. +func WriteChecksum(w io.Writer, sum, name string) error { + _, err := fmt.Fprintf(w, "%s %s\n", sum, name) + return errz.Err(err) +} + +// WriteChecksumFile writes a single {checksum,name} to path, overwriting +// the previous contents. +// +// See: WriteChecksum. 
+func WriteChecksumFile(path, sum, name string) error { + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600) + if err != nil { + return errz.Wrap(err, "write checksum file") + } + defer func() { _ = f.Close() }() + return WriteChecksum(f, sum, name) +} + +// ReadChecksumsFile reads a checksum file from path. +// +// See ReadChecksums for details. +func ReadChecksumsFile(path string) (map[string]string, error) { + f, err := os.Open(path) + if err != nil { + return nil, errz.Err(err) + } + + defer func() { _ = f.Close() }() + + return ReadChecksums(f) +} + +// ReadChecksums reads checksums lines from r, returning a map +// of checksums keyed by name. Empty lines, and lines beginning +// with "#" (comments) are ignored. This function is the +// inverse of WriteChecksum. +func ReadChecksums(r io.Reader) (map[string]string, error) { + sums := map[string]string{} + + sc := bufio.NewScanner(r) + for sc.Scan() { + line := strings.TrimSpace(sc.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + parts := strings.SplitN(line, " ", 2) + if len(parts) != 2 { + return nil, errz.Errorf("invalid checksum line: %q", line) + } + + sums[parts[1]] = parts[0] + } + + return sums, errz.Wrap(sc.Err(), "read checksums") +} diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index ee00d3a6c..28952a317 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -173,3 +173,9 @@ func IsPathToRegularFile(path string) bool { return fi.Mode().IsRegular() } + +// FileAccessible returns true if path is a file that can be read. 
+func FileAccessible(path string) bool { + _, err := os.Stat(path) + return err == nil +} diff --git a/libsq/core/ioz/ioz_test.go b/libsq/core/ioz/ioz_test.go index 087578885..1aa1b84e2 100644 --- a/libsq/core/ioz/ioz_test.go +++ b/libsq/core/ioz/ioz_test.go @@ -1,8 +1,12 @@ package ioz_test import ( + "bytes" + "io" + "os" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/neilotoole/sq/libsq/core/ioz" @@ -17,3 +21,35 @@ func TestMarshalYAML(t *testing.T) { require.NoError(t, err) require.NotNil(t, b) } + +func TestChecksums(t *testing.T) { + f, err := os.CreateTemp(t.TempDir(), "sq-test-*") + require.NoError(t, err) + _, err = io.WriteString(f, "huzzah") + require.NoError(t, err) + assert.NoError(t, f.Close()) + + buf := &bytes.Buffer{} + + gotSum1, err := ioz.FileChecksum(f.Name()) + require.NoError(t, err) + t.Logf("gotSum1: %s %s", gotSum1, f.Name()) + require.NoError(t, ioz.WriteChecksum(buf, gotSum1, f.Name())) + + gotSums, err := ioz.ReadChecksums(bytes.NewReader(buf.Bytes())) + require.NoError(t, err) + require.Len(t, gotSums, 1) + require.Equal(t, gotSum1, gotSums[f.Name()]) + + // Make some changes to the file and verify that the checksums differ. 
+ f, err = os.OpenFile(f.Name(), os.O_APPEND|os.O_WRONLY, 0o600) + require.NoError(t, err) + _, err = io.WriteString(f, "more huzzah") + require.NoError(t, err) + assert.NoError(t, f.Close()) + gotSum2, err := ioz.FileChecksum(f.Name()) + require.NoError(t, err) + t.Logf("gotSum2: %s %s", gotSum2, f.Name()) + require.NoError(t, ioz.WriteChecksum(buf, gotSum1, f.Name())) + require.NotEqual(t, gotSum1, gotSum2) +} diff --git a/libsq/core/lg/lga/lga.go b/libsq/core/lg/lga/lga.go index 808e65e4f..e916450ae 100644 --- a/libsq/core/lg/lga/lga.go +++ b/libsq/core/lg/lga/lga.go @@ -15,6 +15,7 @@ const ( Cleanup = "cleanup" DB = "db" DBType = "db_type" + Dest = "dest" Driver = "driver" DefaultTo = "default_to" Elapsed = "elapsed" diff --git a/libsq/core/options/options.go b/libsq/core/options/options.go index 282933742..5dec9ad2d 100644 --- a/libsq/core/options/options.go +++ b/libsq/core/options/options.go @@ -13,7 +13,9 @@ package options import ( + "bytes" "context" + "crypto/sha256" "fmt" "log/slog" "slices" @@ -177,6 +179,24 @@ func (o Options) Clone() Options { return o2 } +// Hash returns a SHA256 hash of o. If o is nil or empty, +// an empty string is returned. +func (o Options) Hash() string { + if len(o) == 0 { + return "" + } + + buf := bytes.Buffer{} + for k, v := range o { + buf.WriteString(k) + if v != nil { + buf.WriteString(fmt.Sprintf("%v", v)) + } + } + sum := sha256.Sum256(buf.Bytes()) + return fmt.Sprintf("%x", sum) +} + // Keys returns the sorted set of keys in o. 
func (o Options) Keys() []string { keys := lo.Keys(o) diff --git a/libsq/driver/driver.go b/libsq/driver/driver.go index 70af6ebd6..8cc8e5883 100644 --- a/libsq/driver/driver.go +++ b/libsq/driver/driver.go @@ -6,6 +6,7 @@ import ( "database/sql" "fmt" "log/slog" + "path/filepath" "strings" "sync" "time" @@ -13,6 +14,7 @@ import ( "github.com/neilotoole/sq/libsq/ast/render" "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/kind" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" @@ -178,8 +180,17 @@ type JoinPoolOpener interface { // typically a short-lived database used as a target for loading // non-SQL data (such as CSV). type ScratchPoolOpener interface { - // OpenScratch returns a pool for scratch use. - OpenScratch(ctx context.Context, name string) (Pool, error) + // OpenScratchFor returns a pool for scratch use. + OpenScratchFor(ctx context.Context, src *source.Source) (Pool, error) + + // OpenCachedFor returns any already cached ingested pool for src. + // If no such cache, or if it's expired, false is returned. + OpenCachedFor(ctx context.Context, src *source.Source) (Pool, bool, error) + + // OpenIngest opens a pool for src by executing ingestFn. If allowCache + // is false, ingest always occurs; if true, the cache is consulted first. + OpenIngest(ctx context.Context, src *source.Source, + ingestFn func(ctx context.Context, destPool Pool) error, allowCache bool) (Pool, error) } // Driver is the core interface that must be implemented for each type @@ -408,29 +419,40 @@ type Metadata struct { } var ( - _ PoolOpener = (*Pools)(nil) - _ JoinPoolOpener = (*Pools)(nil) + _ PoolOpener = (*Pools)(nil) + _ JoinPoolOpener = (*Pools)(nil) + _ ScratchPoolOpener = (*Pools)(nil) ) +// ScratchSrcFunc is a function that returns a scratch source. +// The caller is responsible for invoking cleanFn. 
+type ScratchSrcFunc func(ctx context.Context, name string) (src *source.Source, cleanFn func() error, err error) + // Pools provides a mechanism for getting Pool instances. // Note that at this time instances returned by Open are cached // and then closed by Close. This may be a bad approach. +// +// FIXME: Why not rename driver.Pools to driver.Sources? type Pools struct { log *slog.Logger drvrs Provider mu sync.Mutex scratchSrcFn ScratchSrcFunc + files *source.Files pools map[string]Pool clnup *cleanup.Cleanup } // NewPools returns a Pools instances. -func NewPools(log *slog.Logger, drvrs Provider, scratchSrcFn ScratchSrcFunc) *Pools { +func NewPools(log *slog.Logger, drvrs Provider, + files *source.Files, scratchSrcFn ScratchSrcFunc, +) *Pools { return &Pools{ log: log, drvrs: drvrs, mu: sync.Mutex{}, scratchSrcFn: scratchSrcFn, + files: files, pools: map[string]Pool{}, clnup: cleanup.New(), } @@ -479,44 +501,212 @@ func (d *Pools) Open(ctx context.Context, src *source.Source) (Pool, error) { return pool, nil } -// OpenScratch returns a scratch database instance. It is not +// OpenScratchFor returns a scratch database instance. It is not // necessary for the caller to close the returned Pool as // its Close method will be invoked by d.Close. // -// OpenScratch implements ScratchPoolOpener. -func (d *Pools) OpenScratch(ctx context.Context, name string) (Pool, error) { - const msgCloseScratch = "close scratch db" +// OpenScratchFor implements ScratchPoolOpener. +// +// REVISIT: do we really need to pass a source here? Just a string should do. +// +// FIXME: the problem is with passing src? +// +// FIXME: Add cacheAllowed bool? 
+func (d *Pools) OpenScratchFor(ctx context.Context, src *source.Source) (Pool, error) { + const msgCloseScratch = "Close scratch db" - scratchSrc, cleanFn, err := d.scratchSrcFn(ctx, name) + _, _, srcCacheFilepath, err := d.getCachePaths(src) if err != nil { - // if err is non-nil, cleanup is guaranteed to be nil return nil, err } - d.log.Debug("Opening scratch src", lga.Src, scratchSrc) - drvr, err := d.drvrs.DriverFor(scratchSrc.Type) + scratchSrc, cleanFn, err := d.scratchSrcFn(ctx, srcCacheFilepath) if err != nil { - lg.WarnIfFuncError(d.log, msgCloseScratch, cleanFn) + // if err is non-nil, cleanup is guaranteed to be nil return nil, err } + d.log.Debug("Opening scratch src", lga.Src, scratchSrc) - sqlDrvr, ok := drvr.(SQLDriver) - if !ok { + backingDrvr, err := d.drvrs.DriverFor(scratchSrc.Type) + if err != nil { lg.WarnIfFuncError(d.log, msgCloseScratch, cleanFn) - return nil, errz.Errorf("driver for scratch source %s is not a SQLDriver but is %T", scratchSrc.Handle, drvr) + return nil, err } var backingPool Pool - backingPool, err = sqlDrvr.Open(ctx, scratchSrc) + backingPool, err = backingDrvr.Open(ctx, scratchSrc) if err != nil { lg.WarnIfFuncError(d.log, msgCloseScratch, cleanFn) return nil, err } - d.clnup.AddE(cleanFn) + allowCache := OptIngestCache.Get(options.FromContext(ctx)) + if !allowCache { + // If the ingest cache is disabled, we add the cleanup func + // so the scratch DB is deleted when the session ends. + d.clnup.AddE(cleanFn) + } + return backingPool, nil } +// OpenIngest implements driver.ScratchPoolOpener. 
+func (d *Pools) OpenIngest(ctx context.Context, src *source.Source, + ingestFn func(ctx context.Context, destPool Pool) error, allowCache bool, +) (Pool, error) { + log := lg.FromContext(ctx) + + ingestFilePath, err := d.files.Filepath(src) + if err != nil { + return nil, err + } + + var checksumsPath string + var cacheDir string + if cacheDir, _, checksumsPath, err = d.getCachePaths(src); err != nil { + return nil, err + } + + log.Debug("Using cache dir", lga.Path, cacheDir) + + if allowCache { + var ( + impl Pool + foundCached bool + ) + if impl, foundCached, err = d.OpenCachedFor(ctx, src); err != nil { + return nil, err + } + if foundCached { + log.Debug("Ingest cache HIT: found cached copy of source", + lga.Src, src, "cached", impl.Source(), + ) + return impl, nil + } + + log.Debug("Ingest cache MISS: no cache for source", lga.Src, src) + } + + impl, err := d.OpenScratchFor(ctx, src) + if err != nil { + return nil, err + } + + start := time.Now() + err = ingestFn(ctx, impl) + elapsed := time.Since(start) + + if err != nil { + log.Error("Ingest failed", + lga.Src, src, lga.Dest, impl.Source(), + lga.Elapsed, elapsed, lga.Err, err, + ) + return nil, err + } + + log.Debug("Ingest completed", lga.Src, src, lga.Dest, impl.Source(), lga.Elapsed, elapsed) + + if allowCache { + // Write the checksums file. + var sum string + if sum, err = ioz.FileChecksum(ingestFilePath); err != nil { + log.Warn("Failed to compute checksum for source file; caching not in effect", + lga.Src, src, lga.Dest, impl.Source(), lga.Path, ingestFilePath, lga.Err, err) + return impl, nil + } + + if err = ioz.WriteChecksumFile(checksumsPath, sum, ingestFilePath); err != nil { + log.Warn("Failed to write checksum; file caching not in effect", + lga.Src, src, lga.Dest, impl.Source(), lga.Path, ingestFilePath, lga.Err, err) + } + } + + return impl, nil +} + +// getCachePaths returns the paths to the cache files for src. +// There is no guarantee that these files exist, or are accessible. 
+// It's just the paths. +func (d *Pools) getCachePaths(src *source.Source) (dir, cacheDB, checksums string, err error) { //nolint:unparam + if dir, err = source.CacheDirFor(src); err != nil { + return "", "", "", err + } + + checksums = filepath.Join(dir, "checksums.txt") + cacheDB = filepath.Join(dir, "cached.db") + return dir, cacheDB, checksums, nil +} + +// OpenCachedFor implements ScratchPoolOpener. +func (d *Pools) OpenCachedFor(ctx context.Context, src *source.Source) (Pool, bool, error) { + _, cacheDBPath, checksumsPath, err := d.getCachePaths(src) + if err != nil { + return nil, false, err + } + + if !ioz.FileAccessible(checksumsPath) { + return nil, false, nil + } + + mChecksums, err := ioz.ReadChecksumsFile(checksumsPath) + if err != nil { + return nil, false, err + } + + drvr, err := d.drvrs.DriverFor(src.Type) + if err != nil { + return nil, false, err + } + + if drvr.DriverMetadata().IsSQL { + return nil, false, errz.Errorf("open file cache for source %s: driver {%s} is SQL, not document", + src.Handle, src.Type) + } + + srcFilepath, err := d.files.Filepath(src) + if err != nil { + return nil, false, err + } + + cachedChecksum, ok := mChecksums[srcFilepath] + if !ok { + return nil, false, nil + } + + srcChecksum, err := ioz.FileChecksum(srcFilepath) + if err != nil { + return nil, false, err + } + + if srcChecksum != cachedChecksum { + return nil, false, nil + } + + // The checksums match, so we can use the cached DB, + // if it exists. 
+ if !ioz.FileAccessible(cacheDBPath) { + return nil, false, nil + } + + backingType, err := d.files.DriverType(ctx, cacheDBPath) + if err != nil { + return nil, false, err + } + + backingSrc := &source.Source{ + Handle: src.Handle + "_cached", + Location: cacheDBPath, + Type: backingType, + } + + backingPool, err := d.Open(ctx, backingSrc) + if err != nil { + return nil, false, errz.Wrapf(err, "open cached DB for source %s", src.Handle) + } + + return backingPool, true, nil +} + // OpenJoin opens an appropriate database for use as // a work DB for joining across sources. // @@ -535,7 +725,7 @@ func (d *Pools) OpenJoin(ctx context.Context, srcs ...*source.Source) (Pool, err } d.log.Debug("OpenJoin", "sources", strings.Join(names, ",")) - return d.OpenScratch(ctx, "joindb__"+strings.Join(names, "_")) + return d.OpenScratchFor(ctx, srcs[0]) } // Close closes d, invoking Close on any instances opened via d.Open. diff --git a/libsq/driver/ingest.go b/libsq/driver/ingest.go index 3c4cda280..765d841b9 100644 --- a/libsq/driver/ingest.go +++ b/libsq/driver/ingest.go @@ -23,6 +23,17 @@ to detect the header.`, options.TagSource, ) +// OptIngestCache specifies whether ingested data is cached or not. +var OptIngestCache = options.NewBool( + "ingest.bool", + "", + 0, + true, + "Ingest data is cached", + `Specifies whether ingested data is cached or not.`, + options.TagSource, +) + // OptIngestSampleSize specifies the number of samples that a detector // should take to determine type. var OptIngestSampleSize = options.NewInt( diff --git a/libsq/driver/scratch.go b/libsq/driver/scratch.go deleted file mode 100644 index 213cd1d44..000000000 --- a/libsq/driver/scratch.go +++ /dev/null @@ -1,11 +0,0 @@ -package driver - -import ( - "context" - - "github.com/neilotoole/sq/libsq/source" -) - -// ScratchSrcFunc is a function that returns a scratch source. -// The caller is responsible for invoking cleanFn. 
-type ScratchSrcFunc func(ctx context.Context, name string) (src *source.Source, cleanFn func() error, err error) diff --git a/libsq/pipeline.go b/libsq/pipeline.go index 9ed34dc24..9a724c6a4 100644 --- a/libsq/pipeline.go +++ b/libsq/pipeline.go @@ -18,9 +18,11 @@ import ( "github.com/neilotoole/sq/libsq/core/record" "github.com/neilotoole/sq/libsq/core/sqlmodel" "github.com/neilotoole/sq/libsq/core/sqlz" + "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/core/tablefq" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" + "github.com/neilotoole/sq/libsq/source/drivertype" ) // pipeline is used to execute a SLQ query, @@ -184,7 +186,9 @@ func (p *pipeline) prepareNoTable(ctx context.Context, qm *queryModel) error { if handle == "" { if src = p.qc.Collection.Active(); src == nil { log.Debug("No active source, will use scratchdb.") - p.targetPool, err = p.qc.ScratchPoolOpener.OpenScratch(ctx, "scratch") + // REVISIT: ScratchPoolOpener needs a source, so we just make one up. + ephemeralSrc := &source.Source{Type: drivertype.None, Handle: "@scratch" + stringz.Uniq8()} + p.targetPool, err = p.qc.ScratchPoolOpener.OpenScratchFor(ctx, ephemeralSrc) if err != nil { return err } diff --git a/libsq/source/files.go b/libsq/source/files.go index 3c5a66cee..aba4e21ed 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -21,6 +21,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/source/drivertype" "github.com/neilotoole/sq/libsq/source/fetcher" ) @@ -163,6 +164,23 @@ func (fs *Files) addFile(f *os.File, key string) (fscache.ReadAtCloser, error) { return r, nil } +// Filepath returns the file path of src.Location. 
+func (fs *Files) Filepath(src *Source) (string, error) { + fs.mu.Lock() + defer fs.mu.Unlock() + + // cache miss + f, err := fs.openLocation(src.Location) + if err != nil { + return "", err + } + + if err = f.Close(); err != nil { + return "", errz.Err(err) + } + return f.Name(), nil +} + // Open returns a new io.ReadCloser for src.Location. // If src.Handle is StdinHandle, AddStdin must first have // been invoked. The caller must close the reader. @@ -513,31 +531,45 @@ func httpURL(s string) (u *url.URL, ok bool) { return u, true } -// TempDirFile creates a new temporary file in a new temp dir, -// opens the file for reading and writing, and then closes it. -// It's probably unnecessary to go through the ceremony of -// opening and closing the file, but maybe it's better to fail early. -// It is the caller's responsibility to remove the file and/or dir -// if desired. -func TempDirFile(filename string) (dir, file string, err error) { - dir, err = os.MkdirTemp("", "sq_") - if err != nil { - return "", "", errz.Err(err) +// CacheDirFor gets the cache dir for handle, creating it if necessary. +// If handle is empty or invalid, a random value is generated. +func CacheDirFor(src *Source) (dir string, err error) { + handle := src.Handle + switch handle { + case "": + handle = "@cache_" + stringz.UniqN(32) + case StdinHandle: + // stdin is different input every time, so we need a unique + // cache dir. 
+ handle += "_" + stringz.UniqN(32) + default: + if err = ValidHandle(handle); err != nil { + return "", errz.Wrapf(err, "open cache dir: invalid handle: %s", handle) + } } - file = filepath.Join(dir, filename) - var f *os.File - if f, err = os.OpenFile(file, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0o600); err != nil { - // Silently delete the temp dir - _ = os.RemoveAll(dir) - return "", "", errz.Err(err) + dir = CacheDirPath() + sanitized := Handle2SafePath(handle) + hash := src.Hash() + dir = filepath.Join(dir, "sources", sanitized, hash) + if err = os.MkdirAll(dir, 0o750); err != nil { + return "", errz.Wrapf(err, "open cache dir: %s", dir) } - if err = f.Close(); err != nil { - // Silently delete the temp dir - _ = os.RemoveAll(dir) - return "", "", errz.Wrap(err, "close temp file") - } + return dir, nil +} - return dir, file, nil +// CacheDirPath returns the sq cache dir. This is generally +// in USER_CACHE_DIR/sq/cache, but could also be in TEMP_DIR/sq/cache +// or similar. It is not guaranteed that the returned dir exists +// or is accessible. +func CacheDirPath() (dir string) { + var err error + if dir, err = os.UserCacheDir(); err != nil { + // Some systems may not have a user cache dir, so we fall back + // to the system temp dir. + dir = os.TempDir() + } + dir = filepath.Join(dir, "sq", "cache") + return dir } diff --git a/libsq/source/handle.go b/libsq/source/handle.go index d4a02d058..21cef7aa3 100644 --- a/libsq/source/handle.go +++ b/libsq/source/handle.go @@ -254,3 +254,9 @@ func Contains[S *Source | ~string](srcs []*Source, s S) bool { return false } + +// Handle2SafePath returns a string derived from handle that +// is safe to use as a file path. 
+func Handle2SafePath(handle string) string { + return strings.ReplaceAll(strings.TrimPrefix(handle, "@"), "/", "__") +} diff --git a/libsq/source/source.go b/libsq/source/source.go index 1f4723be2..adc6a522b 100644 --- a/libsq/source/source.go +++ b/libsq/source/source.go @@ -2,6 +2,8 @@ package source import ( + "bytes" + "crypto/sha256" "fmt" "log/slog" "net/url" @@ -91,6 +93,24 @@ type Source struct { Options options.Options `yaml:"options,omitempty" json:"options,omitempty"` } +// Hash returns an SHA256 hash of all fields of s. The Source.Options +// field is ignored. If s is nil, the empty string is returned. +func (s *Source) Hash() string { + if s == nil { + return "" + } + + buf := bytes.Buffer{} + buf.WriteString(s.Handle) + buf.WriteString(string(s.Type)) + buf.WriteString(s.Location) + buf.WriteString(s.Catalog) + buf.WriteString(s.Schema) + buf.WriteString(s.Options.Hash()) + sum := sha256.Sum256(buf.Bytes()) + return fmt.Sprintf("%x", sum) +} + // LogValue implements slog.LogValuer. 
func (s *Source) LogValue() slog.Value { if s == nil { diff --git a/testh/testh.go b/testh/testh.go index 63c288876..9e18d9c33 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -169,7 +169,7 @@ func (h *Helper) init() { h.files.AddDriverDetectors(source.DetectMagicNumber) - h.pools = driver.NewPools(log, h.registry, sqlite3.NewScratchSource) + h.pools = driver.NewPools(log, h.registry, h.files, sqlite3.NewScratchSource) h.Cleanup.AddC(h.pools) h.registry.AddProvider(sqlite3.Type, &sqlite3.Provider{Log: log}) From 61def8f3b9891c90514c6f47210d1c4ada10110c Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 26 Nov 2023 10:24:57 -0700 Subject: [PATCH 003/195] wip: broken --- drivers/csv/csv.go | 37 +-- drivers/csv/detect_type.go | 5 +- drivers/sqlite3/sqlite3.go | 7 + go.mod | 1 + go.sum | 2 + libsq/core/ioz/checksum.go | 24 +- libsq/core/options/options.go | 9 +- libsq/driver/driver.go | 347 --------------------------- libsq/driver/sources.go | 431 ++++++++++++++++++++++++++++++++++ libsq/source/files.go | 72 +++++- 10 files changed, 531 insertions(+), 404 deletions(-) create mode 100644 libsq/driver/sources.go diff --git a/drivers/csv/csv.go b/drivers/csv/csv.go index 60a3d2f00..e0eac4dac 100644 --- a/drivers/csv/csv.go +++ b/drivers/csv/csv.go @@ -77,7 +77,6 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er } allowCache := driver.OptIngestCache.Get(options.FromContext(ctx)) - // impl, err := d.scratcher.OpenScratchFor(ctx, src, allowCache) ingestFn := func(ctx context.Context, destPool driver.Pool) error { return ingestCSV(ctx, src, d.files.OpenFunc(src), destPool) @@ -89,39 +88,6 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er } p.impl = backingPool - - // - //var err error - //if { - // // Caching is enabled, let's see if we can find a cached copy. 
- // var foundCached bool - // p.impl, foundCached, err = d.scratcher.OpenCachedFor(ctx, src) - // if err != nil { - // return nil, err - // } - // if foundCached { - // log.Debug("Cache HIT: found cached copy of source", - // lga.Src, src, "cached", p.impl.Source(), - // ) - // return p, nil - // } - // - // log.Debug("Cache MISS: no cache for source", lga.Src, src) - //} - // - //if p.impl == nil { - // p.impl, err = d.scratcher.OpenScratchFor(ctx, src) - // if err != nil { - // return nil, err - // } - //} - // - //if err = ingestCSV(ctx, src, d.files.OpenFunc(src), p.impl); err != nil { - // return nil, err - //} - - // FIXME: We really should be writing the checksum after ingestCSV happens - return p, nil } @@ -141,6 +107,9 @@ func (d *driveri) ValidateSource(src *source.Source) (*source.Source, error) { // Ping implements driver.Driver. func (d *driveri) Ping(_ context.Context, src *source.Source) error { + // FIXME: Does Ping calling d.files.Open cause a full read? + // We probably just want to check that the file exists + // or is accessible. r, err := d.files.Open(src) if err != nil { return err diff --git a/drivers/csv/detect_type.go b/drivers/csv/detect_type.go index 32b17ad49..1acbf730d 100644 --- a/drivers/csv/detect_type.go +++ b/drivers/csv/detect_type.go @@ -70,9 +70,8 @@ const ( scoreYes float32 = 0.9 ) -// isCSV returns a score indicating the -// the confidence that cr is reading legitimate CSV, where -// a score <= 0 is not CSV, a score >= 1 is definitely CSV. +// isCSV returns a score indicating the confidence that cr is reading +// legitimate CSV, where a score <= 0 is not CSV, a score >= 1 is definitely CSV. 
func isCSV(ctx context.Context, cr *csv.Reader) (score float32) { const ( maxRecords int = 100 diff --git a/drivers/sqlite3/sqlite3.go b/drivers/sqlite3/sqlite3.go index 22ea8c7b9..b222211ca 100644 --- a/drivers/sqlite3/sqlite3.go +++ b/drivers/sqlite3/sqlite3.go @@ -145,6 +145,13 @@ func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, erro if err != nil { return nil, err } + + if strings.Contains(fp, "checksum") { + x := true + _ = x + + } + db, err := sql.Open(dbDrvr, fp) if err != nil { return nil, errz.Wrapf(errw(err), "failed to open sqlite3 source with DSN: %s", fp) diff --git a/go.mod b/go.mod index ef26e82fa..a57718480 100644 --- a/go.mod +++ b/go.mod @@ -68,6 +68,7 @@ require ( github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect github.com/muesli/mango v0.2.0 // indirect github.com/muesli/mango-pflag v0.1.0 // indirect + github.com/nightlyone/lockfile v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/richardlehane/mscfb v1.0.4 // indirect github.com/richardlehane/msoleps v1.0.3 // indirect diff --git a/go.sum b/go.sum index 85a26d4c4..5e9716077 100644 --- a/go.sum +++ b/go.sum @@ -126,6 +126,8 @@ github.com/neilotoole/shelleditor v0.4.1 h1:74LEw2mVo3jtNw2BjII6RSss9DXgEqAbmCQD github.com/neilotoole/shelleditor v0.4.1/go.mod h1:QanOZN4syDMp/L0SKwZb47Mh49mvLWX3ja5YfbYDDjo= github.com/neilotoole/slogt v1.1.0 h1:c7qE92sq+V0yvCuaxph+RQ2jOKL61c4hqS1Bv9W7FZE= github.com/neilotoole/slogt v1.1.0/go.mod h1:RCrGXkPc/hYybNulqQrMHRtvlQ7F6NktNVLuLwk6V+w= +github.com/nightlyone/lockfile v1.0.0 h1:RHep2cFKK4PonZJDdEl4GmkabuhbsRMgk/k3uAmxBiA= +github.com/nightlyone/lockfile v1.0.0/go.mod h1:rywoIealpdNse2r832aiD9jRk8ErCatROs6LzC841CI= github.com/otiai10/copy v1.14.0 h1:dCI/t1iTdYGtkvCuBG2BgR6KZa83PTclw4U5n2wAllU= github.com/otiai10/copy v1.14.0/go.mod h1:ECfuL02W+/FkTWZWgQqXPWZgW9oeKCSQ5qVfSc4qc4w= github.com/otiai10/mint v1.5.1 
h1:XaPLeE+9vGbuyEHem1JNk3bYc7KKqyI/na0/mLd/Kks= diff --git a/libsq/core/ioz/checksum.go b/libsq/core/ioz/checksum.go index 35ae17d95..7e170eadb 100644 --- a/libsq/core/ioz/checksum.go +++ b/libsq/core/ioz/checksum.go @@ -13,10 +13,13 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" ) +// Checksum is a checksum of a file. +type Checksum string + // FileChecksum returns a checksum of the file at path. // The checksum is based on the file's name, size, mode, and // modification time. File contents are not read. -func FileChecksum(path string) (string, error) { +func FileChecksum(path string) (Checksum, error) { fi, err := os.Stat(path) if err != nil { return "", errz.Wrap(err, "calculate file checksum") @@ -30,7 +33,7 @@ func FileChecksum(path string) (string, error) { buf.WriteString(strconv.FormatBool(fi.IsDir())) sum := sha256.Sum256(buf.Bytes()) - return fmt.Sprintf("%x", sum), nil + return Checksum(fmt.Sprintf("%x", sum)), nil } // WriteChecksum appends a checksum line to w, including @@ -41,7 +44,7 @@ func FileChecksum(path string) (string, error) { // // Use FileChecksum to calculate a checksum, and ReadChecksums // to read this format. -func WriteChecksum(w io.Writer, sum, name string) error { +func WriteChecksum(w io.Writer, sum Checksum, name string) error { _, err := fmt.Fprintf(w, "%s %s\n", sum, name) return errz.Err(err) } @@ -50,7 +53,7 @@ func WriteChecksum(w io.Writer, sum, name string) error { // the previous contents. // // See: WriteChecksum. -func WriteChecksumFile(path, sum, name string) error { +func WriteChecksumFile(path string, sum Checksum, name string) error { f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600) if err != nil { return errz.Wrap(err, "write checksum file") @@ -62,7 +65,7 @@ func WriteChecksumFile(path, sum, name string) error { // ReadChecksumsFile reads a checksum file from path. // // See ReadChecksums for details. 
-func ReadChecksumsFile(path string) (map[string]string, error) { +func ReadChecksumsFile(path string) (map[string]Checksum, error) { f, err := os.Open(path) if err != nil { return nil, errz.Err(err) @@ -77,8 +80,8 @@ func ReadChecksumsFile(path string) (map[string]string, error) { // of checksums keyed by name. Empty lines, and lines beginning // with "#" (comments) are ignored. This function is the // inverse of WriteChecksum. -func ReadChecksums(r io.Reader) (map[string]string, error) { - sums := map[string]string{} +func ReadChecksums(r io.Reader) (map[string]Checksum, error) { + sums := map[string]Checksum{} sc := bufio.NewScanner(r) for sc.Scan() { @@ -87,12 +90,17 @@ func ReadChecksums(r io.Reader) (map[string]string, error) { continue } + if strings.Contains(line, "INTEGER") { // FIXME: delete + x := true + _ = x + } + parts := strings.SplitN(line, " ", 2) if len(parts) != 2 { return nil, errz.Errorf("invalid checksum line: %q", line) } - sums[parts[1]] = parts[0] + sums[parts[1]] = Checksum(parts[0]) } return sums, errz.Wrap(sc.Err(), "read checksums") diff --git a/libsq/core/options/options.go b/libsq/core/options/options.go index 5dec9ad2d..597c95576 100644 --- a/libsq/core/options/options.go +++ b/libsq/core/options/options.go @@ -186,12 +186,13 @@ func (o Options) Hash() string { return "" } + keys := o.Keys() buf := bytes.Buffer{} - for k, v := range o { + for _, k := range keys { buf.WriteString(k) - if v != nil { - buf.WriteString(fmt.Sprintf("%v", v)) - } + v := o[k] + buf.WriteString(fmt.Sprintf("%v", v)) + } sum := sha256.Sum256(buf.Bytes()) return fmt.Sprintf("%x", sum) diff --git a/libsq/driver/driver.go b/libsq/driver/driver.go index 8cc8e5883..cea51c14f 100644 --- a/libsq/driver/driver.go +++ b/libsq/driver/driver.go @@ -2,19 +2,11 @@ package driver import ( "context" - "crypto/sha256" "database/sql" - "fmt" - "log/slog" - "path/filepath" - "strings" - "sync" "time" "github.com/neilotoole/sq/libsq/ast/render" - 
"github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/kind" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" @@ -418,322 +410,6 @@ type Metadata struct { DefaultPort int `json:"default_port" yaml:"default_port"` } -var ( - _ PoolOpener = (*Pools)(nil) - _ JoinPoolOpener = (*Pools)(nil) - _ ScratchPoolOpener = (*Pools)(nil) -) - -// ScratchSrcFunc is a function that returns a scratch source. -// The caller is responsible for invoking cleanFn. -type ScratchSrcFunc func(ctx context.Context, name string) (src *source.Source, cleanFn func() error, err error) - -// Pools provides a mechanism for getting Pool instances. -// Note that at this time instances returned by Open are cached -// and then closed by Close. This may be a bad approach. -// -// FIXME: Why not rename driver.Pools to driver.Sources? -type Pools struct { - log *slog.Logger - drvrs Provider - mu sync.Mutex - scratchSrcFn ScratchSrcFunc - files *source.Files - pools map[string]Pool - clnup *cleanup.Cleanup -} - -// NewPools returns a Pools instances. -func NewPools(log *slog.Logger, drvrs Provider, - files *source.Files, scratchSrcFn ScratchSrcFunc, -) *Pools { - return &Pools{ - log: log, - drvrs: drvrs, - mu: sync.Mutex{}, - scratchSrcFn: scratchSrcFn, - files: files, - pools: map[string]Pool{}, - clnup: cleanup.New(), - } -} - -// Open returns an opened Pool for src. The returned Pool -// may be cached and returned on future invocations for the -// same source (where each source fields is identical). -// Thus, the caller should typically not close -// the Pool: it will be closed via d.Close. -// -// NOTE: This entire logic re caching/not-closing is a bit sketchy, -// and needs to be revisited. -// -// Open implements PoolOpener. 
-func (d *Pools) Open(ctx context.Context, src *source.Source) (Pool, error) { - lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) - - d.mu.Lock() - defer d.mu.Unlock() - - key := src.Handle + "_" + hashSource(src) - - pool, ok := d.pools[key] - if ok { - return pool, nil - } - - drvr, err := d.drvrs.DriverFor(src.Type) - if err != nil { - return nil, err - } - - baseOptions := options.FromContext(ctx) - o := options.Merge(baseOptions, src.Options) - - ctx = options.NewContext(ctx, o) - pool, err = drvr.Open(ctx, src) - if err != nil { - return nil, err - } - - d.clnup.AddC(pool) - - d.pools[key] = pool - return pool, nil -} - -// OpenScratchFor returns a scratch database instance. It is not -// necessary for the caller to close the returned Pool as -// its Close method will be invoked by d.Close. -// -// OpenScratchFor implements ScratchPoolOpener. -// -// REVISIT: do we really need to pass a source here? Just a string should do. -// -// FIXME: the problem is with passing src? -// -// FIXME: Add cacheAllowed bool? 
-func (d *Pools) OpenScratchFor(ctx context.Context, src *source.Source) (Pool, error) { - const msgCloseScratch = "Close scratch db" - - _, _, srcCacheFilepath, err := d.getCachePaths(src) - if err != nil { - return nil, err - } - - scratchSrc, cleanFn, err := d.scratchSrcFn(ctx, srcCacheFilepath) - if err != nil { - // if err is non-nil, cleanup is guaranteed to be nil - return nil, err - } - d.log.Debug("Opening scratch src", lga.Src, scratchSrc) - - backingDrvr, err := d.drvrs.DriverFor(scratchSrc.Type) - if err != nil { - lg.WarnIfFuncError(d.log, msgCloseScratch, cleanFn) - return nil, err - } - - var backingPool Pool - backingPool, err = backingDrvr.Open(ctx, scratchSrc) - if err != nil { - lg.WarnIfFuncError(d.log, msgCloseScratch, cleanFn) - return nil, err - } - - allowCache := OptIngestCache.Get(options.FromContext(ctx)) - if !allowCache { - // If the ingest cache is disabled, we add the cleanup func - // so the scratch DB is deleted when the session ends. - d.clnup.AddE(cleanFn) - } - - return backingPool, nil -} - -// OpenIngest implements driver.ScratchPoolOpener. 
-func (d *Pools) OpenIngest(ctx context.Context, src *source.Source, - ingestFn func(ctx context.Context, destPool Pool) error, allowCache bool, -) (Pool, error) { - log := lg.FromContext(ctx) - - ingestFilePath, err := d.files.Filepath(src) - if err != nil { - return nil, err - } - - var checksumsPath string - var cacheDir string - if cacheDir, _, checksumsPath, err = d.getCachePaths(src); err != nil { - return nil, err - } - - log.Debug("Using cache dir", lga.Path, cacheDir) - - if allowCache { - var ( - impl Pool - foundCached bool - ) - if impl, foundCached, err = d.OpenCachedFor(ctx, src); err != nil { - return nil, err - } - if foundCached { - log.Debug("Ingest cache HIT: found cached copy of source", - lga.Src, src, "cached", impl.Source(), - ) - return impl, nil - } - - log.Debug("Ingest cache MISS: no cache for source", lga.Src, src) - } - - impl, err := d.OpenScratchFor(ctx, src) - if err != nil { - return nil, err - } - - start := time.Now() - err = ingestFn(ctx, impl) - elapsed := time.Since(start) - - if err != nil { - log.Error("Ingest failed", - lga.Src, src, lga.Dest, impl.Source(), - lga.Elapsed, elapsed, lga.Err, err, - ) - return nil, err - } - - log.Debug("Ingest completed", lga.Src, src, lga.Dest, impl.Source(), lga.Elapsed, elapsed) - - if allowCache { - // Write the checksums file. - var sum string - if sum, err = ioz.FileChecksum(ingestFilePath); err != nil { - log.Warn("Failed to compute checksum for source file; caching not in effect", - lga.Src, src, lga.Dest, impl.Source(), lga.Path, ingestFilePath, lga.Err, err) - return impl, nil - } - - if err = ioz.WriteChecksumFile(checksumsPath, sum, ingestFilePath); err != nil { - log.Warn("Failed to write checksum; file caching not in effect", - lga.Src, src, lga.Dest, impl.Source(), lga.Path, ingestFilePath, lga.Err, err) - } - } - - return impl, nil -} - -// getCachePaths returns the paths to the cache files for src. -// There is no guarantee that these files exist, or are accessible. 
-// It's just the paths. -func (d *Pools) getCachePaths(src *source.Source) (dir, cacheDB, checksums string, err error) { //nolint:unparam - if dir, err = source.CacheDirFor(src); err != nil { - return "", "", "", err - } - - checksums = filepath.Join(dir, "checksums.txt") - cacheDB = filepath.Join(dir, "cached.db") - return dir, cacheDB, checksums, nil -} - -// OpenCachedFor implements ScratchPoolOpener. -func (d *Pools) OpenCachedFor(ctx context.Context, src *source.Source) (Pool, bool, error) { - _, cacheDBPath, checksumsPath, err := d.getCachePaths(src) - if err != nil { - return nil, false, err - } - - if !ioz.FileAccessible(checksumsPath) { - return nil, false, nil - } - - mChecksums, err := ioz.ReadChecksumsFile(checksumsPath) - if err != nil { - return nil, false, err - } - - drvr, err := d.drvrs.DriverFor(src.Type) - if err != nil { - return nil, false, err - } - - if drvr.DriverMetadata().IsSQL { - return nil, false, errz.Errorf("open file cache for source %s: driver {%s} is SQL, not document", - src.Handle, src.Type) - } - - srcFilepath, err := d.files.Filepath(src) - if err != nil { - return nil, false, err - } - - cachedChecksum, ok := mChecksums[srcFilepath] - if !ok { - return nil, false, nil - } - - srcChecksum, err := ioz.FileChecksum(srcFilepath) - if err != nil { - return nil, false, err - } - - if srcChecksum != cachedChecksum { - return nil, false, nil - } - - // The checksums match, so we can use the cached DB, - // if it exists. 
- if !ioz.FileAccessible(cacheDBPath) { - return nil, false, nil - } - - backingType, err := d.files.DriverType(ctx, cacheDBPath) - if err != nil { - return nil, false, err - } - - backingSrc := &source.Source{ - Handle: src.Handle + "_cached", - Location: cacheDBPath, - Type: backingType, - } - - backingPool, err := d.Open(ctx, backingSrc) - if err != nil { - return nil, false, errz.Wrapf(err, "open cached DB for source %s", src.Handle) - } - - return backingPool, true, nil -} - -// OpenJoin opens an appropriate database for use as -// a work DB for joining across sources. -// -// Note: There is much work to be done on this method. At this time, only -// two sources are supported. Ultimately OpenJoin should be able to -// inspect the join srcs and use heuristics to determine the best -// location for the join to occur (to minimize copying of data for -// the join etc.). Currently the implementation simply delegates -// to OpenScratch. -// -// OpenJoin implements JoinPoolOpener. -func (d *Pools) OpenJoin(ctx context.Context, srcs ...*source.Source) (Pool, error) { - var names []string - for _, src := range srcs { - names = append(names, src.Handle[1:]) - } - - d.log.Debug("OpenJoin", "sources", strings.Join(names, ",")) - return d.OpenScratchFor(ctx, srcs[0]) -} - -// Close closes d, invoking Close on any instances opened via d.Open. -func (d *Pools) Close() error { - d.log.Debug("Closing databases(s)...", lga.Count, d.clnup.Len()) - return d.clnup.Run() -} - // OpeningPing is a standardized mechanism to ping db using // driver.OptConnOpenTimeout. This should be invoked by each SQL // driver impl in its Open method. If the ping fails, db is closed. @@ -753,26 +429,3 @@ func OpeningPing(ctx context.Context, src *source.Source, db *sql.DB) error { return nil } - -// hashSource computes a hash for src. If src is nil, empty string is returned. 
-func hashSource(src *source.Source) string { - if src == nil { - return "" - } - - h := sha256.New() - h.Write([]byte(src.Handle)) - h.Write([]byte(src.Location)) - h.Write([]byte(src.Type)) - - if len(src.Options) > 0 { - keys := src.Options.Keys() - for _, k := range keys { - v := src.Options[k] - h.Write([]byte(fmt.Sprintf("%s:%v", k, v))) - } - } - - b := h.Sum(nil) - return string(b) -} diff --git a/libsq/driver/sources.go b/libsq/driver/sources.go new file mode 100644 index 000000000..b90bbcd0e --- /dev/null +++ b/libsq/driver/sources.go @@ -0,0 +1,431 @@ +package driver + +import ( + "context" + "errors" + "log/slog" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/neilotoole/sq/libsq/core/cleanup" + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/retry" + "github.com/neilotoole/sq/libsq/source" + "github.com/nightlyone/lockfile" +) + +var ( + _ PoolOpener = (*Pools)(nil) + _ JoinPoolOpener = (*Pools)(nil) + _ ScratchPoolOpener = (*Pools)(nil) +) + +// ScratchSrcFunc is a function that returns a scratch source. +// The caller is responsible for invoking cleanFn. +type ScratchSrcFunc func(ctx context.Context, name string) (src *source.Source, cleanFn func() error, err error) + +// Pools provides a mechanism for getting Pool instances. +// Note that at this time instances returned by Open are cached +// and then closed by Close. This may be a bad approach. +// +// FIXME: Why not rename driver.Pools to driver.Sources? +type Pools struct { + log *slog.Logger + drvrs Provider + mu sync.Mutex + scratchSrcFn ScratchSrcFunc + files *source.Files + pools map[string]Pool + clnup *cleanup.Cleanup +} + +// NewPools returns a Pools instances. 
+func NewPools(log *slog.Logger, drvrs Provider, + files *source.Files, scratchSrcFn ScratchSrcFunc, +) *Pools { + return &Pools{ + log: log, + drvrs: drvrs, + mu: sync.Mutex{}, + scratchSrcFn: scratchSrcFn, + files: files, + pools: map[string]Pool{}, + clnup: cleanup.New(), + } +} + +// Open returns an opened Pool for src. The returned Pool +// may be cached and returned on future invocations for the +// same source (where each source fields is identical). +// Thus, the caller should typically not close +// the Pool: it will be closed via d.Close. +// +// NOTE: This entire logic re caching/not-closing is a bit sketchy, +// and needs to be revisited. +// +// Open implements PoolOpener. +func (d *Pools) Open(ctx context.Context, src *source.Source) (Pool, error) { + lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) + d.mu.Lock() + defer d.mu.Unlock() + return d.doOpen(ctx, src) +} + +func (d *Pools) doOpen(ctx context.Context, src *source.Source) (Pool, error) { + lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) + key := src.Handle + "_" + src.Hash() + + pool, ok := d.pools[key] + if ok { + return pool, nil + } + + drvr, err := d.drvrs.DriverFor(src.Type) + if err != nil { + return nil, err + } + + baseOptions := options.FromContext(ctx) + o := options.Merge(baseOptions, src.Options) + + ctx = options.NewContext(ctx, o) + pool, err = drvr.Open(ctx, src) + if err != nil { + return nil, err + } + + d.clnup.AddC(pool) + + d.pools[key] = pool + return pool, nil +} + +// OpenScratchFor returns a scratch database instance. It is not +// necessary for the caller to close the returned Pool as +// its Close method will be invoked by d.Close. +// +// OpenScratchFor implements ScratchPoolOpener. +// +// REVISIT: do we really need to pass a source here? Just a string should do. +// +// FIXME: the problem is with passing src? +// +// FIXME: Add cacheAllowed bool? 
+func (d *Pools) OpenScratchFor(ctx context.Context, src *source.Source) (Pool, error) { + const msgCloseScratch = "Close scratch db" + + _, srcCacheDBFilepath, _, err := d.getCachePaths(src) + if err != nil { + return nil, err + } + + scratchSrc, cleanFn, err := d.scratchSrcFn(ctx, srcCacheDBFilepath) + if err != nil { + // if err is non-nil, cleanup is guaranteed to be nil + return nil, err + } + d.log.Debug("Opening scratch src", lga.Src, scratchSrc) + + backingDrvr, err := d.drvrs.DriverFor(scratchSrc.Type) + if err != nil { + lg.WarnIfFuncError(d.log, msgCloseScratch, cleanFn) + return nil, err + } + + var backingPool Pool + backingPool, err = backingDrvr.Open(ctx, scratchSrc) + if err != nil { + lg.WarnIfFuncError(d.log, msgCloseScratch, cleanFn) + return nil, err + } + + allowCache := OptIngestCache.Get(options.FromContext(ctx)) + if !allowCache { + // If the ingest cache is disabled, we add the cleanup func + // so the scratch DB is deleted when the session ends. + d.clnup.AddE(cleanFn) + } + + return backingPool, nil +} + +// OpenIngest implements driver.ScratchPoolOpener. +func (d *Pools) OpenIngest(ctx context.Context, src *source.Source, + ingestFn func(ctx context.Context, destPool Pool) error, allowCache bool, +) (Pool, error) { + if !allowCache || src.Handle == source.StdinHandle { + // We don't currently cache stdin. 
+ return d.openIngestNoCache(ctx, src, ingestFn) + } + + return d.openIngestCache(ctx, src, ingestFn) +} + +func (d *Pools) openIngestNoCache(ctx context.Context, src *source.Source, + ingestFn func(ctx context.Context, destPool Pool) error, +) (Pool, error) { + log := lg.FromContext(ctx) + impl, err := d.OpenScratchFor(ctx, src) + if err != nil { + return nil, err + } + + start := time.Now() + err = ingestFn(ctx, impl) + elapsed := time.Since(start) + + if err != nil { + log.Error("Ingest failed", + lga.Src, src, lga.Dest, impl.Source(), + lga.Elapsed, elapsed, lga.Err, err, + ) + } + + d.log.Debug("Ingest completed", + lga.Src, src, lga.Dest, impl.Source(), + lga.Elapsed, elapsed) + return impl, nil +} + +func (d *Pools) openIngestCache(ctx context.Context, src *source.Source, + ingestFn func(ctx context.Context, destPool Pool) error, +) (Pool, error) { + log := lg.FromContext(ctx) + + lock, err := d.acquireLock(ctx, src) + if err != nil { + return nil, err + } + defer func() { + log.Debug("About to release cache lock...", "lock", lock) + if err := lock.Unlock(); err != nil { + log.Warn("Failed to release cache lock", "lock", lock, lga.Err, err) + } else { + log.Debug("Released cache lock", "lock", lock) + } + }() + + cacheDir, _, checksumsPath, err := d.getCachePaths(src) + if err != nil { + return nil, err + } + + log.Debug("Using cache dir", lga.Path, cacheDir) + + ingestFilePath, err := d.files.Filepath(ctx, src) + if err != nil { + return nil, err + } + + var ( + impl Pool + foundCached bool + ) + if impl, foundCached, err = d.OpenCachedFor(ctx, src); err != nil { + return nil, err + } + if foundCached { + log.Debug("Ingest cache HIT: found cached copy of source", + lga.Src, src, "cached", impl.Source(), + ) + return impl, nil + } + + log.Debug("Ingest cache MISS: no cache for source", lga.Src, src) + + impl, err = d.OpenScratchFor(ctx, src) + if err != nil { + return nil, err + } + + start := time.Now() + err = ingestFn(ctx, impl) + elapsed := 
time.Since(start) + + if err != nil { + log.Error("Ingest failed", + lga.Src, src, lga.Dest, impl.Source(), + lga.Elapsed, elapsed, lga.Err, err, + ) + return nil, err + } + + log.Debug("Ingest completed", lga.Src, src, lga.Dest, impl.Source(), lga.Elapsed, elapsed) + + // Write the checksums file. + var sum ioz.Checksum + if sum, err = ioz.FileChecksum(ingestFilePath); err != nil { + log.Warn("Failed to compute checksum for source file; caching not in effect", + lga.Src, src, lga.Dest, impl.Source(), lga.Path, ingestFilePath, lga.Err, err) + return impl, nil + } + + if err = ioz.WriteChecksumFile(checksumsPath, sum, ingestFilePath); err != nil { + log.Warn("Failed to write checksum; file caching not in effect", + lga.Src, src, lga.Dest, impl.Source(), lga.Path, ingestFilePath, lga.Err, err) + } + + return impl, nil +} + +// getCachePaths returns the paths to the cache files for src. +// There is no guarantee that these files exist, or are accessible. +// It's just the paths. +func (d *Pools) getCachePaths(src *source.Source) (srcCacheDir, cacheDB, checksums string, err error) { //nolint:unparam + if srcCacheDir, err = source.CacheDirFor(src); err != nil { + return "", "", "", err + } + + checksums = filepath.Join(srcCacheDir, "checksums.txt") + cacheDB = filepath.Join(srcCacheDir, "cached.db") + return srcCacheDir, cacheDB, checksums, nil +} + +// acquireLock acquires a lock for src. The caller +// is responsible for unlocking the lock, e.g.: +// +// defer lg.WarnIfFuncError(d.log, "failed to unlock cache lock", lock.Unlock) +// +// The lock acquisition process is retried with backoff. 
+func (d *Pools) acquireLock(ctx context.Context, src *source.Source) (lockfile.Lockfile, error) { + lock, err := d.getLockfileFor(src) + if err != nil { + return "", err + } + + err = retry.Do(ctx, time.Second*5, + lock.TryLock, + func(err error) bool { + var temporaryError lockfile.TemporaryError + return errors.As(err, &temporaryError) + }, + ) + if err != nil { + return "", errz.Wrap(err, "failed to get lock") + } + + lg.FromContext(ctx).Debug("Acquired cache lock", "lock", lock) + return lock, nil +} + +// getLockfileFor returns a lockfile for src. It doesn't +// actually acquire the lock. +func (d *Pools) getLockfileFor(src *source.Source) (lockfile.Lockfile, error) { + srcCacheDir, _, _, err := d.getCachePaths(src) + if err != nil { + return "", err + } + + if err = os.MkdirAll(srcCacheDir, 0o750); err != nil { + return "", errz.Err(err) + } + lockPath := filepath.Join(srcCacheDir, "pid.lock") + return lockfile.New(lockPath) +} + +// OpenCachedFor implements ScratchPoolOpener. +func (d *Pools) OpenCachedFor(ctx context.Context, src *source.Source) (Pool, bool, error) { + _, cacheDBPath, checksumsPath, err := d.getCachePaths(src) + if err != nil { + return nil, false, err + } + + if !ioz.FileAccessible(checksumsPath) { + return nil, false, nil + } + + mChecksums, err := ioz.ReadChecksumsFile(checksumsPath) + if err != nil { + return nil, false, err + } + + drvr, err := d.drvrs.DriverFor(src.Type) + if err != nil { + return nil, false, err + } + + if drvr.DriverMetadata().IsSQL { + return nil, false, errz.Errorf("open file cache for source %s: driver {%s} is SQL, not document", + src.Handle, src.Type) + } + + srcFilepath, err := d.files.Filepath(ctx, src) + if err != nil { + return nil, false, err + } + d.log.Debug("Got srcFilepath for src", + lga.Src, src, lga.Path, srcFilepath) + + cachedChecksum, ok := mChecksums[srcFilepath] + if !ok { + return nil, false, nil + } + + srcChecksum, err := ioz.FileChecksum(srcFilepath) + if err != nil { + return nil, false, 
err + } + + if srcChecksum != cachedChecksum { + return nil, false, nil + } + + // The checksums match, so we can use the cached DB, + // if it exists. + if !ioz.FileAccessible(cacheDBPath) { + return nil, false, nil + } + + backingType, err := d.files.DriverType(ctx, cacheDBPath) + if err != nil { + return nil, false, err + } + + backingSrc := &source.Source{ + Handle: src.Handle + "_cached", + Location: "sqlite3://" + cacheDBPath, + Type: backingType, + } + + backingPool, err := d.doOpen(ctx, backingSrc) + if err != nil { + return nil, false, errz.Wrapf(err, "open cached DB for source %s", src.Handle) + } + + return backingPool, true, nil +} + +// OpenJoin opens an appropriate database for use as +// a work DB for joining across sources. +// +// Note: There is much work to be done on this method. At this time, only +// two sources are supported. Ultimately OpenJoin should be able to +// inspect the join srcs and use heuristics to determine the best +// location for the join to occur (to minimize copying of data for +// the join etc.). Currently the implementation simply delegates +// to OpenScratch. +// +// OpenJoin implements JoinPoolOpener. +func (d *Pools) OpenJoin(ctx context.Context, srcs ...*source.Source) (Pool, error) { + var names []string + for _, src := range srcs { + names = append(names, src.Handle[1:]) + } + + d.log.Debug("OpenJoin", "sources", strings.Join(names, ",")) + return d.OpenScratchFor(ctx, srcs[0]) +} + +// Close closes d, invoking Close on any instances opened via d.Open. +func (d *Pools) Close() error { + d.log.Debug("Closing databases(s)...", lga.Count, d.clnup.Len()) + return d.clnup.Run() +} diff --git a/libsq/source/files.go b/libsq/source/files.go index aba4e21ed..5557656ce 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -165,20 +165,76 @@ func (fs *Files) addFile(f *os.File, key string) (fscache.ReadAtCloser, error) { } // Filepath returns the file path of src.Location. 
-func (fs *Files) Filepath(src *Source) (string, error) { +// An error is returned the source's driver type +// is not a file type (i.e. it is a SQL driver). +func (fs *Files) Filepath(_ context.Context, src *Source) (string, error) { fs.mu.Lock() defer fs.mu.Unlock() + loc := src.Location - // cache miss - f, err := fs.openLocation(src.Location) - if err != nil { - return "", err + if fp, ok := isFpath(loc); ok { + return fp, nil } - if err = f.Close(); err != nil { - return "", errz.Err(err) + u, ok := httpURL(loc) + if !ok { + return "", errz.Errorf("not a valid file location: %s", loc) } - return f.Name(), nil + + _ = u + // It's a remote file. We really should download it here. + // FIXME: implement downloading. + return "", errz.Errorf("Filepath not implemented for remote files: %s", loc) + // + //if ; !ok { + // // It's not a filepath, and it's not a http URL, + // // so we need to download it. + // + // + //} + // + //return "", + // + // + // + //typ, err := fs.DriverType(ctx, src.Location) + //if err != nil { + // return "", err + //} + // + //if !fs.fcache.Exists(loc) { + // // cache miss + // f, err := fs.openLocation(loc) + // if err != nil { + // return "", err + // } + // + // // Note that addFile closes f + // _, err = fs.addFile(f, loc) + // if err != nil { + // return "", err + // } + // return f.Name(), nil + //} + // + //return loc, nil + //r, _, err := fs.fcache.Get(loc) + //if err != nil { + // return "", err + //} + // + //return r, nil + + // // cache miss + // f, err := fs.openLocation(src.Location) + // if err != nil { + // return "", err + // } + // + // if err = f.Close(); err != nil { + // return "", errz.Err(err) + // } + // return f.Name(), nil } // Open returns a new io.ReadCloser for src.Location. 
From c8fbccd0d2e619e188199a3066471846f7339903 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 26 Nov 2023 10:29:34 -0700 Subject: [PATCH 004/195] wip: adding fscache --- go.mod | 2 ++ go.sum | 4 ++++ libsq/core/ioz/fscache | 1 + 3 files changed, 7 insertions(+) create mode 160000 libsq/core/ioz/fscache diff --git a/go.mod b/go.mod index a57718480..ebda24ba9 100644 --- a/go.mod +++ b/go.mod @@ -52,6 +52,8 @@ require ( github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.2.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/djherbis/atime v1.1.0 // indirect + github.com/djherbis/stream v1.4.0 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/huandu/xstrings v1.4.0 // indirect diff --git a/go.sum b/go.sum index 5e9716077..753a9ee1a 100644 --- a/go.sum +++ b/go.sum @@ -32,8 +32,12 @@ github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/djherbis/atime v1.1.0 h1:rgwVbP/5by8BvvjBNrbh64Qz33idKT3pSnMSJsxhi0g= +github.com/djherbis/atime v1.1.0/go.mod h1:28OF6Y8s3NQWwacXc5eZTsEsiMzp7LF8MbXE+XJPdBE= github.com/djherbis/fscache v0.10.1 h1:hDv+RGyvD+UDKyRYuLoVNbuRTnf2SrA2K3VyR1br9lk= github.com/djherbis/fscache v0.10.1/go.mod h1:yyPYtkNnnPXsW+81lAcQS6yab3G2CRfnPLotBvtbf0c= +github.com/djherbis/stream v1.4.0 h1:aVD46WZUiq5kJk55yxJAyw6Kuera6kmC3i2vEQyW/AE= +github.com/djherbis/stream v1.4.0/go.mod h1:cqjC1ZRq3FFwkGmUtHwcldbnW8f0Q4YuVsGW1eAFtOk= github.com/ecnepsnai/osquery v1.0.1 h1:i96n/3uqcafKZtRYmXVNqekKbfrIm66q179mWZ/Y2Aw= github.com/ecnepsnai/osquery v1.0.1/go.mod 
h1:vxsezNRznmkLa8UjVh88tlJiRbgW7iwinkjyg/Xc2RU= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= diff --git a/libsq/core/ioz/fscache b/libsq/core/ioz/fscache new file mode 160000 index 000000000..2909c9509 --- /dev/null +++ b/libsq/core/ioz/fscache @@ -0,0 +1 @@ +Subproject commit 2909c950912d2b24c4ea99dfd3cf4a7c2bdc38a2 From 7ccd90545ebcb8460bbcf0af3be275a627e61574 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 26 Nov 2023 10:30:56 -0700 Subject: [PATCH 005/195] wip: adding fscache --- libsq/core/ioz/fscache | 1 - 1 file changed, 1 deletion(-) delete mode 160000 libsq/core/ioz/fscache diff --git a/libsq/core/ioz/fscache b/libsq/core/ioz/fscache deleted file mode 160000 index 2909c9509..000000000 --- a/libsq/core/ioz/fscache +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 2909c950912d2b24c4ea99dfd3cf4a7c2bdc38a2 From c9b0106c1166c0e12610644a1e9406817cdfa198 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 26 Nov 2023 10:31:49 -0700 Subject: [PATCH 006/195] wip: adding fscache --- libsq/core/ioz/fscache/LICENSE | 22 + libsq/core/ioz/fscache/README.md | 93 ++++ libsq/core/ioz/fscache/distrib.go | 85 ++++ libsq/core/ioz/fscache/example_test.go | 69 +++ libsq/core/ioz/fscache/fileinfo.go | 52 +++ libsq/core/ioz/fscache/fs.go | 266 ++++++++++++ libsq/core/ioz/fscache/fscache.go | 373 ++++++++++++++++ libsq/core/ioz/fscache/fscache_test.go | 579 +++++++++++++++++++++++++ libsq/core/ioz/fscache/handler.go | 41 ++ libsq/core/ioz/fscache/haunter.go | 92 ++++ libsq/core/ioz/fscache/layers.go | 128 ++++++ libsq/core/ioz/fscache/lruhaunter.go | 137 ++++++ libsq/core/ioz/fscache/memfs.go | 147 +++++++ libsq/core/ioz/fscache/reaper.go | 37 ++ libsq/core/ioz/fscache/server.go | 206 +++++++++ libsq/core/ioz/fscache/stream.go | 72 +++ 16 files changed, 2399 insertions(+) create mode 100644 libsq/core/ioz/fscache/LICENSE create mode 100644 libsq/core/ioz/fscache/README.md create mode 100644 libsq/core/ioz/fscache/distrib.go create mode 
100644 libsq/core/ioz/fscache/example_test.go create mode 100644 libsq/core/ioz/fscache/fileinfo.go create mode 100644 libsq/core/ioz/fscache/fs.go create mode 100644 libsq/core/ioz/fscache/fscache.go create mode 100644 libsq/core/ioz/fscache/fscache_test.go create mode 100644 libsq/core/ioz/fscache/handler.go create mode 100644 libsq/core/ioz/fscache/haunter.go create mode 100644 libsq/core/ioz/fscache/layers.go create mode 100644 libsq/core/ioz/fscache/lruhaunter.go create mode 100644 libsq/core/ioz/fscache/memfs.go create mode 100644 libsq/core/ioz/fscache/reaper.go create mode 100644 libsq/core/ioz/fscache/server.go create mode 100644 libsq/core/ioz/fscache/stream.go diff --git a/libsq/core/ioz/fscache/LICENSE b/libsq/core/ioz/fscache/LICENSE new file mode 100644 index 000000000..1e7b7cc09 --- /dev/null +++ b/libsq/core/ioz/fscache/LICENSE @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2015 Dustin H + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ diff --git a/libsq/core/ioz/fscache/README.md b/libsq/core/ioz/fscache/README.md new file mode 100644 index 000000000..78b57ef35 --- /dev/null +++ b/libsq/core/ioz/fscache/README.md @@ -0,0 +1,93 @@ +fscache +========== + +[![GoDoc](https://godoc.org/github.com/djherbis/fscache?status.svg)](https://godoc.org/github.com/djherbis/fscache) +[![Release](https://img.shields.io/github/release/djherbis/fscache.svg)](https://github.com/djherbis/fscache/releases/latest) +[![Software License](https://img.shields.io/badge/license-MIT-brightgreen.svg)](LICENSE.txt) +[![go test](https://github.com/djherbis/fscache/actions/workflows/go-test.yml/badge.svg)](https://github.com/djherbis/fscache/actions/workflows/go-test.yml) +[![Coverage Status](https://coveralls.io/repos/djherbis/fscache/badge.svg?branch=master)](https://coveralls.io/r/djherbis/fscache?branch=master) +[![Go Report Card](https://goreportcard.com/badge/github.com/djherbis/fscache)](https://goreportcard.com/report/github.com/djherbis/fscache) + +Usage +------------ +Streaming File Cache for #golang + +fscache allows multiple readers to read from a cache while its being written to. [blog post](https://djherbis.github.io/post/fscache/) + +Using the Cache directly: + +```go +package main + +import ( + "io" + "log" + "os" + "time" + + "gopkg.in/djherbis/fscache.v0" +) + +func main() { + + // create the cache, keys expire after 1 hour. + c, err := fscache.New("./cache", 0755, time.Hour) + if err != nil { + log.Fatal(err.Error()) + } + + // wipe the cache when done + defer c.Clean() + + // Get() and it's streams can be called concurrently but just for example: + for i := 0; i < 3; i++ { + r, w, err := c.Get("stream") + if err != nil { + log.Fatal(err.Error()) + } + + if w != nil { // a new stream, write to it. 
+ go func(){ + w.Write([]byte("hello world\n")) + w.Close() + }() + } + + // the stream has started, read from it + io.Copy(os.Stdout, r) + r.Close() + } +} +``` + +A Caching Middle-ware: + +```go +package main + +import( + "net/http" + "time" + + "gopkg.in/djherbis/fscache.v0" +) + +func main(){ + c, err := fscache.New("./cache", 0700, 0) + if err != nil { + log.Fatal(err.Error()) + } + + handler := func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "%v: %s", time.Now(), "hello world") + } + + http.ListenAndServe(":8080", fscache.Handler(c, http.HandlerFunc(handler))) +} +``` + +Installation +------------ +```sh +go get gopkg.in/djherbis/fscache.v0 +``` diff --git a/libsq/core/ioz/fscache/distrib.go b/libsq/core/ioz/fscache/distrib.go new file mode 100644 index 000000000..60994cc58 --- /dev/null +++ b/libsq/core/ioz/fscache/distrib.go @@ -0,0 +1,85 @@ +package fscache + +import ( + "bytes" + "crypto/sha1" + "encoding/binary" + "io" +) + +// Distributor provides a way to partition keys into Caches. +type Distributor interface { + + // GetCache will always return the same Cache for the same key. + GetCache(key string) Cache + + // Clean should wipe all the caches this Distributor manages + Clean() error +} + +// stdDistribution distributes the keyspace evenly. +func stdDistribution(key string, n uint64) uint64 { + h := sha1.New() + io.WriteString(h, key) + buf := bytes.NewBuffer(h.Sum(nil)[:8]) + i, _ := binary.ReadUvarint(buf) + return i % n +} + +// NewDistributor returns a Distributor which evenly distributes the keyspace +// into the passed caches. 
+func NewDistributor(caches ...Cache) Distributor { + if len(caches) == 0 { + return nil + } + return &distrib{ + distribution: stdDistribution, + caches: caches, + size: uint64(len(caches)), + } +} + +type distrib struct { + distribution func(key string, n uint64) uint64 + caches []Cache + size uint64 +} + +func (d *distrib) GetCache(key string) Cache { + return d.caches[d.distribution(key, d.size)] +} + +// BUG(djherbis): Return an error if cleaning fails +func (d *distrib) Clean() error { + for _, c := range d.caches { + c.Clean() + } + return nil +} + +// NewPartition returns a Cache which uses the Caches defined by the passed Distributor. +func NewPartition(d Distributor) Cache { + return &partition{ + distributor: d, + } +} + +type partition struct { + distributor Distributor +} + +func (p *partition) Get(key string) (ReadAtCloser, io.WriteCloser, error) { + return p.distributor.GetCache(key).Get(key) +} + +func (p *partition) Remove(key string) error { + return p.distributor.GetCache(key).Remove(key) +} + +func (p *partition) Exists(key string) bool { + return p.distributor.GetCache(key).Exists(key) +} + +func (p *partition) Clean() error { + return p.distributor.Clean() +} diff --git a/libsq/core/ioz/fscache/example_test.go b/libsq/core/ioz/fscache/example_test.go new file mode 100644 index 000000000..5aa9e7266 --- /dev/null +++ b/libsq/core/ioz/fscache/example_test.go @@ -0,0 +1,69 @@ +package fscache + +import ( + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "os" + "time" +) + +func Example() { + // create the cache, keys expire after 1 hour. + c, err := New("./cache", 0755, time.Hour) + if err != nil { + log.Fatal(err.Error()) + } + + // wipe the cache when done + defer c.Clean() + + // Get() and it's streams can be called concurrently but just for example: + for i := 0; i < 3; i++ { + r, w, err := c.Get("stream") + if err != nil { + log.Fatal(err.Error()) + } + + if w != nil { // a new stream, write to it. 
+ go func() { + w.Write([]byte("hello world\n")) + w.Close() + }() + } + + // the stream has started, read from it + io.Copy(os.Stdout, r) + r.Close() + } + // Output: + // hello world + // hello world + // hello world +} + +func ExampleHandler() { + c, err := New("./server", 0700, 0) + if err != nil { + log.Fatal(err.Error()) + } + defer c.Clean() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello Client") + }) + + ts := httptest.NewServer(Handler(c, handler)) + defer ts.Close() + + resp, err := http.Get(ts.URL) + if err != nil { + log.Fatal(err.Error()) + } + io.Copy(os.Stdout, resp.Body) + resp.Body.Close() + // Output: + // Hello Client +} diff --git a/libsq/core/ioz/fscache/fileinfo.go b/libsq/core/ioz/fscache/fileinfo.go new file mode 100644 index 000000000..445fcfd50 --- /dev/null +++ b/libsq/core/ioz/fscache/fileinfo.go @@ -0,0 +1,52 @@ +package fscache + +import ( + "os" + "time" +) + +// FileInfo is just a wrapper around os.FileInfo which includes atime. +type FileInfo struct { + os.FileInfo + Atime time.Time +} + +type fileInfo struct { + name string + size int64 + fileMode os.FileMode + isDir bool + sys interface{} + wt time.Time +} + +func (f *fileInfo) Name() string { + return f.name +} + +func (f *fileInfo) Size() int64 { + return f.size +} + +func (f *fileInfo) Mode() os.FileMode { + return f.fileMode +} + +func (f *fileInfo) ModTime() time.Time { + return f.wt +} + +func (f *fileInfo) IsDir() bool { + return f.isDir +} + +func (f *fileInfo) Sys() interface{} { + return f.sys +} + +// AccessTime returns the last time the file was read. +// It will be used to check expiry of a file, and must be concurrent safe +// with modifications to the FileSystem (writes, reads etc.) 
+func (f *FileInfo) AccessTime() time.Time { + return f.Atime +} diff --git a/libsq/core/ioz/fscache/fs.go b/libsq/core/ioz/fscache/fs.go new file mode 100644 index 000000000..dad018382 --- /dev/null +++ b/libsq/core/ioz/fscache/fs.go @@ -0,0 +1,266 @@ +package fscache + +import ( + "bytes" + "crypto/md5" + "encoding/base64" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" + "time" + + "github.com/djherbis/atime" + "github.com/djherbis/stream" +) + +// FileSystemStater implementers can provide FileInfo data about a named resource. +type FileSystemStater interface { + // Stat takes a File.Name() and returns FileInfo interface + Stat(name string) (FileInfo, error) +} + +// FileSystem is used as the source for a Cache. +type FileSystem interface { + // Stream FileSystem + stream.FileSystem + + FileSystemStater + + // Reload should look through the FileSystem and call the supplied fn + // with the key/filename pairs that are found. + Reload(func(key, name string)) error + + // RemoveAll should empty the FileSystem of all files. + RemoveAll() error +} + +// StandardFS is an implemenation of FileSystem which writes to the os Filesystem. +type StandardFS struct { + root string + init func() error + + // EncodeKey takes a 'name' given to Create and converts it into a + // the Filename that should be used. It should return 'true' if + // DecodeKey can convert the returned string back to the original 'name' + // and false otherwise. + // This must be set before the first call to Create. + EncodeKey func(string) (string, bool) + + // DecodeKey should convert a given Filename into the original 'name' given to + // EncodeKey, and return true if this conversion was possible. Returning false + // will cause it to try and lookup a stored 'encodedName.key' file which holds + // the original name. 
+ DecodeKey func(string) (string, bool) +} + +// IdentityCodeKey works as both an EncodeKey and a DecodeKey func, which just returns +// it's given argument and true. This is expected to be used when your FSCache +// uses SetKeyMapper to ensure its internal km(key) value is already a valid filename path. +func IdentityCodeKey(key string) (string, bool) { return key, true } + +// NewFs returns a FileSystem rooted at directory dir. +// Dir is created with perms if it doesn't exist. +// This also uses the default EncodeKey/DecodeKey functions B64ORMD5HashEncodeKey/B64DecodeKey. +func NewFs(dir string, mode os.FileMode) (*StandardFS, error) { + fs := &StandardFS{ + root: dir, + init: func() error { + return os.MkdirAll(dir, mode) + }, + EncodeKey: B64OrMD5HashEncodeKey, + DecodeKey: B64DecodeKey, + } + return fs, fs.init() +} + +// Reload looks through the dir given to NewFs and returns every key, name pair (Create(key) => name = File.Name()) +// that is managed by this FileSystem. +func (fs *StandardFS) Reload(add func(key, name string)) error { + files, err := ioutil.ReadDir(fs.root) + if err != nil { + return err + } + + addfiles := make(map[string]struct { + os.FileInfo + key string + }) + + for _, f := range files { + + if strings.HasSuffix(f.Name(), ".key") { + continue + } + + key, err := fs.getKey(f.Name()) + if err != nil { + fs.Remove(filepath.Join(fs.root, f.Name())) + continue + } + fi, ok := addfiles[key] + + if !ok || fi.ModTime().Before(f.ModTime()) { + if ok { + fs.Remove(fi.Name()) + } + addfiles[key] = struct { + os.FileInfo + key string + }{ + FileInfo: f, + key: key, + } + } else { + fs.Remove(f.Name()) + } + + } + + for _, f := range addfiles { + path, err := filepath.Abs(filepath.Join(fs.root, f.Name())) + if err != nil { + return err + } + add(f.key, path) + } + + return nil +} + +// Create creates a File for the given 'name', it may not use the given name on the +// os filesystem, that depends on the implementation of EncodeKey used. 
+func (fs *StandardFS) Create(name string) (stream.File, error) { + name, err := fs.makeName(name) + if err != nil { + return nil, err + } + return fs.create(name) +} + +func (fs *StandardFS) create(name string) (stream.File, error) { + return os.OpenFile(filepath.Join(fs.root, name), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) +} + +// Open opens a stream.File for the given File.Name() returned by Create(). +func (fs *StandardFS) Open(name string) (stream.File, error) { + return os.Open(name) +} + +// Remove removes a stream.File for the given File.Name() returned by Create(). +func (fs *StandardFS) Remove(name string) error { + os.Remove(fmt.Sprintf("%s.key", name)) + return os.Remove(name) +} + +// RemoveAll deletes all files in the directory managed by this StandardFS. +// Warning that if you put files in this directory that were not created by +// StandardFS they will also be deleted. +func (fs *StandardFS) RemoveAll() error { + if err := os.RemoveAll(fs.root); err != nil { + return err + } + return fs.init() +} + +// AccessTimes returns atime and mtime for the given File.Name() returned by Create(). +func (fs *StandardFS) AccessTimes(name string) (rt, wt time.Time, err error) { + fi, err := os.Stat(name) + if err != nil { + return rt, wt, err + } + return atime.Get(fi), fi.ModTime(), nil +} + +// Stat returns FileInfo for the given File.Name() returned by Create(). +func (fs *StandardFS) Stat(name string) (FileInfo, error) { + stat, err := os.Stat(name) + if err != nil { + return FileInfo{}, err + } + + return FileInfo{FileInfo: stat, Atime: atime.Get(stat)}, nil +} + +const ( + saltSize = 8 + salt = "xxxxxxxx" // this is only important for sizing now. 
+ maxShort = 20 + shortPrefix = "s" + longPrefix = "l" +) + +func tob64(s string) string { + buf := bytes.NewBufferString("") + enc := base64.NewEncoder(base64.URLEncoding, buf) + enc.Write([]byte(s)) + enc.Close() + return buf.String() +} + +func fromb64(s string) string { + buf := bytes.NewBufferString(s) + dec := base64.NewDecoder(base64.URLEncoding, buf) + out := bytes.NewBufferString("") + io.Copy(out, dec) + return out.String() +} + +// B64OrMD5HashEncodeKey converts a given key into a filesystem name-safe string +// and returns true iff it can be reversed with B64DecodeKey. +func B64OrMD5HashEncodeKey(key string) (string, bool) { + b64key := tob64(key) + // short name + if len(b64key) < maxShort { + return fmt.Sprintf("%s%s%s", shortPrefix, salt, b64key), true + } + + // long name + hash := md5.Sum([]byte(key)) + return fmt.Sprintf("%s%s%x", longPrefix, salt, hash[:]), false +} + +func (fs *StandardFS) makeName(key string) (string, error) { + name, decodable := fs.EncodeKey(key) + if decodable { + return name, nil + } + + // Name is not decodeable, store it. + f, err := fs.create(fmt.Sprintf("%s.key", name)) + if err != nil { + return "", err + } + _, err = f.Write([]byte(key)) + f.Close() + return name, err +} + +// B64DecodeKey converts a string y into x st. y, ok = B64OrMD5HashEncodeKey(x), and ok = true. +// Basically it should reverse B64OrMD5HashEncodeKey if B64OrMD5HashEncodeKey returned true. 
+func B64DecodeKey(name string) (string, bool) { + if strings.HasPrefix(name, shortPrefix) { + return fromb64(strings.TrimPrefix(name, shortPrefix)[saltSize:]), true + } + return "", false +} + +func (fs *StandardFS) getKey(name string) (string, error) { + if key, ok := fs.DecodeKey(name); ok { + return key, nil + } + + // long name + f, err := fs.Open(filepath.Join(fs.root, fmt.Sprintf("%s.key", name))) + if err != nil { + return "", err + } + defer f.Close() + key, err := ioutil.ReadAll(f) + if err != nil { + return "", err + } + return string(key), nil +} diff --git a/libsq/core/ioz/fscache/fscache.go b/libsq/core/ioz/fscache/fscache.go new file mode 100644 index 000000000..6de40a3b8 --- /dev/null +++ b/libsq/core/ioz/fscache/fscache.go @@ -0,0 +1,373 @@ +package fscache + +import ( + "fmt" + "io" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/djherbis/stream" +) + +// Cache works like a concurrent-safe map for streams. +type Cache interface { + // Get manages access to the streams in the cache. + // If the key does not exist, w != nil and you can start writing to the stream. + // If the key does exist, w == nil. + // r will always be non-nil as long as err == nil and you must close r when you're done reading. + // Get can be called concurrently, and writing and reading is concurrent safe. + Get(key string) (ReadAtCloser, io.WriteCloser, error) + + // Remove deletes the stream from the cache, blocking until the underlying + // file can be deleted (all active streams finish with it). + // It is safe to call Remove concurrently with Get. + Remove(key string) error + + // Exists checks if a key is in the cache. + // It is safe to call Exists concurrently with Get. + Exists(key string) bool + + // Clean will empty the cache and delete the cache folder. + // Clean is not safe to call while streams are being read/written. + Clean() error +} + +// FSCache is a Cache which uses a Filesystem to read/write cached data. 
+type FSCache struct { + mu sync.RWMutex + files map[string]fileStream + km func(string) string + fs FileSystem + haunter Haunter +} + +// SetKeyMapper will use the given function to transform any given Cache key into the result of km(key). +// This means that internally, the cache will only track km(key), and forget the original key. The consequences +// of this are that Enumerate will return km(key) instead of key, and Filesystem will give km(key) to Create +// and expect Reload() to return km(key). +// The purpose of this function is so that the internally managed key can be converted to a string that is +// allowed as a filesystem path. +func (c *FSCache) SetKeyMapper(km func(string) string) *FSCache { + c.mu.Lock() + defer c.mu.Unlock() + c.km = km + return c +} + +func (c *FSCache) mapKey(key string) string { + if c.km == nil { + return key + } + return c.km(key) +} + +// ReadAtCloser is an io.ReadCloser, and an io.ReaderAt. It supports both so that Range +// Requests are possible. +type ReadAtCloser interface { + io.ReadCloser + io.ReaderAt +} + +type fileStream interface { + next() (*CacheReader, error) + InUse() bool + io.WriteCloser + remove() error + Name() string +} + +// New creates a new Cache using NewFs(dir, perms). +// expiry is the duration after which an un-accessed key will be removed from +// the cache, a zero value expiro means never expire. +func New(dir string, perms os.FileMode, expiry time.Duration) (*FSCache, error) { + fs, err := NewFs(dir, perms) + if err != nil { + return nil, err + } + var grim Reaper + if expiry > 0 { + grim = &reaper{ + expiry: expiry, + period: expiry, + } + } + return NewCache(fs, grim) +} + +// NewCache creates a new Cache based on FileSystem fs. +// fs.Files() are loaded using the name they were created with as a key. +// Reaper is used to determine when files expire, nil means never expire. 
+func NewCache(fs FileSystem, grim Reaper) (*FSCache, error) { + if grim != nil { + return NewCacheWithHaunter(fs, NewReaperHaunterStrategy(grim)) + } + + return NewCacheWithHaunter(fs, nil) +} + +// NewCacheWithHaunter create a new Cache based on FileSystem fs. +// fs.Files() are loaded using the name they were created with as a key. +// Haunter is used to determine when files expire, nil means never expire. +func NewCacheWithHaunter(fs FileSystem, haunter Haunter) (*FSCache, error) { + c := &FSCache{ + files: make(map[string]fileStream), + haunter: haunter, + fs: fs, + } + err := c.load() + if err != nil { + return nil, err + } + if haunter != nil { + c.scheduleHaunt() + } + + return c, nil +} + +func (c *FSCache) scheduleHaunt() { + c.haunt() + time.AfterFunc(c.haunter.Next(), c.scheduleHaunt) +} + +func (c *FSCache) haunt() { + c.mu.Lock() + defer c.mu.Unlock() + + c.haunter.Haunt(&accessor{c: c}) +} + +func (c *FSCache) load() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.fs.Reload(func(key, name string) { + c.files[key] = c.oldFile(name) + }) +} + +// Exists returns true iff this key is in the Cache (may not be finished streaming). +func (c *FSCache) Exists(key string) bool { + c.mu.RLock() + defer c.mu.RUnlock() + _, ok := c.files[c.mapKey(key)] + return ok +} + +// Get obtains a ReadAtCloser for the given key, and may return a WriteCloser to write the original cache data +// if this is a cache-miss. 
+func (c *FSCache) Get(key string) (r ReadAtCloser, w io.WriteCloser, err error) { + c.mu.RLock() + key = c.mapKey(key) + f, ok := c.files[key] + if ok { + r, err = f.next() + c.mu.RUnlock() + return r, nil, err + } + c.mu.RUnlock() + + c.mu.Lock() + defer c.mu.Unlock() + + f, ok = c.files[key] + if ok { + r, err = f.next() + return r, nil, err + } + + f, err = c.newFile(key) + if err != nil { + return nil, nil, err + } + + r, err = f.next() + if err != nil { + f.Close() + c.fs.Remove(f.Name()) + return nil, nil, err + } + + c.files[key] = f + + return r, f, err +} + +// Remove removes the specified key from the cache. +func (c *FSCache) Remove(key string) error { + c.mu.Lock() + key = c.mapKey(key) + f, ok := c.files[key] + delete(c.files, key) + c.mu.Unlock() + + if ok { + return f.remove() + } + return nil +} + +// Clean resets the cache removing all keys and data. +func (c *FSCache) Clean() error { + c.mu.Lock() + defer c.mu.Unlock() + c.files = make(map[string]fileStream) + return c.fs.RemoveAll() +} + +type accessor struct { + c *FSCache +} + +func (a *accessor) Stat(name string) (FileInfo, error) { + return a.c.fs.Stat(name) +} + +func (a *accessor) EnumerateEntries(enumerator func(key string, e Entry) bool) { + for k, f := range a.c.files { + if !enumerator(k, Entry{name: f.Name(), inUse: f.InUse()}) { + break + } + } +} + +func (a *accessor) RemoveFile(key string) { + key = a.c.mapKey(key) + f, ok := a.c.files[key] + delete(a.c.files, key) + if ok { + a.c.fs.Remove(f.Name()) + } +} + +type cachedFile struct { + handleCounter + stream *stream.Stream +} + +func (c *FSCache) newFile(name string) (fileStream, error) { + s, err := stream.NewStream(name, c.fs) + if err != nil { + return nil, err + } + cf := &cachedFile{ + stream: s, + } + cf.inc() + return cf, nil +} + +func (c *FSCache) oldFile(name string) fileStream { + return &reloadedFile{ + fs: c.fs, + name: name, + } +} + +type reloadedFile struct { + handleCounter + fs FileSystem + name string + 
io.WriteCloser // nop Write & Close methods. will never be called. +} + +func (f *reloadedFile) Name() string { return f.name } + +func (f *reloadedFile) remove() error { + f.waitUntilFree() + return f.fs.Remove(f.name) +} + +func (f *reloadedFile) next() (*CacheReader, error) { + r, err := f.fs.Open(f.name) + if err == nil { + f.inc() + } + return &CacheReader{ + ReadAtCloser: r, + cnt: &f.handleCounter, + }, err +} + +func (f *cachedFile) Name() string { return f.stream.Name() } + +func (f *cachedFile) remove() error { return f.stream.Remove() } + +func (f *cachedFile) next() (*CacheReader, error) { + reader, err := f.stream.NextReader() + if err != nil { + return nil, err + } + f.inc() + return &CacheReader{ + ReadAtCloser: reader, + cnt: &f.handleCounter, + }, nil +} + +func (f *cachedFile) Write(p []byte) (int, error) { + return f.stream.Write(p) +} + +func (f *cachedFile) Close() error { + defer f.dec() + return f.stream.Close() +} + +// CacheReader is a ReadAtCloser for a Cache key that also tracks open readers. +type CacheReader struct { + ReadAtCloser + cnt *handleCounter +} + +// Close frees the underlying ReadAtCloser and updates the open reader counter. +func (r *CacheReader) Close() error { + defer r.cnt.dec() + return r.ReadAtCloser.Close() +} + +// Size returns the current size of the stream being read, the boolean it +// returns is true iff the stream is done being written (otherwise Size may change). +// An error is returned if the Size fails to be computed or is not supported +// by the underlying filesystem. 
+func (r *CacheReader) Size() (int64, bool, error) { + switch v := r.ReadAtCloser.(type) { + case *stream.Reader: + size, done := v.Size() + return size, done, nil + + case interface{ Stat() (os.FileInfo, error) }: + fi, err := v.Stat() + if err != nil { + return 0, false, err + } + return fi.Size(), true, nil + + default: + return 0, false, fmt.Errorf("reader does not support stat") + } +} + +type handleCounter struct { + cnt int64 + grp sync.WaitGroup +} + +func (h *handleCounter) inc() { + h.grp.Add(1) + atomic.AddInt64(&h.cnt, 1) +} + +func (h *handleCounter) dec() { + atomic.AddInt64(&h.cnt, -1) + h.grp.Done() +} + +func (h *handleCounter) InUse() bool { + return atomic.LoadInt64(&h.cnt) > 0 +} + +func (h *handleCounter) waitUntilFree() { + h.grp.Wait() +} diff --git a/libsq/core/ioz/fscache/fscache_test.go b/libsq/core/ioz/fscache/fscache_test.go new file mode 100644 index 000000000..de125299a --- /dev/null +++ b/libsq/core/ioz/fscache/fscache_test.go @@ -0,0 +1,579 @@ +package fscache + +import ( + "bytes" + "crypto/md5" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" +) + +func createFile(name string) (*os.File, error) { + return os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) +} + +func init() { + c, _ := NewCache(NewMemFs(), nil) + go ListenAndServe(c, "localhost:10000") +} + +func testCaches(t *testing.T, run func(c Cache)) { + c, err := New("./cache", 0700, 1*time.Hour) + if err != nil { + t.Error(err.Error()) + return + } + run(c) + + c, err = NewCache(NewMemFs(), NewReaper(time.Hour, time.Hour)) + if err != nil { + t.Error(err.Error()) + return + } + run(c) + + c2, _ := NewCache(NewMemFs(), nil) + run(NewPartition(NewDistributor(c, c2))) + + lc := NewLayered(c, c2) + run(lc) + + rc := NewRemote("localhost:10000") + run(rc) + + fs, _ := NewFs("./cachex", 0700) + fs.EncodeKey = IdentityCodeKey + fs.DecodeKey = IdentityCodeKey + ck, _ := NewCache(fs, NewReaper(time.Hour, time.Hour)) 
+ ck.SetKeyMapper(func(key string) string { + name, _ := B64OrMD5HashEncodeKey(key) + return name + }) + run(ck) +} + +func TestHandler(t *testing.T) { + testCaches(t, func(c Cache) { + defer c.Clean() + ts := httptest.NewServer(Handler(c, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello Client") + }))) + defer ts.Close() + + for i := 0; i < 3; i++ { + res, err := http.Get(ts.URL) + if err != nil { + t.Error(err.Error()) + t.FailNow() + } + p, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if !bytes.Equal([]byte("Hello Client\n"), p) { + t.Errorf("unexpected response %s", string(p)) + } + } + }) +} + +func TestMemFs(t *testing.T) { + fs := NewMemFs() + fs.Reload(func(key, name string) {}) // nop + if _, err := fs.Open("test"); err == nil { + t.Errorf("stream shouldn't exist") + } + fs.Remove("test") + + f, err := fs.Create("test") + if err != nil { + t.Errorf("failed to create test") + } + f.Write([]byte("hello")) + f.Close() + + r, err := fs.Open("test") + if err != nil { + t.Errorf("failed Open: %v", err) + } + p, err := ioutil.ReadAll(r) + if err != nil { + t.Errorf("failed ioutil.ReadAll: %v", err) + } + r.Close() + if !bytes.Equal(p, []byte("hello")) { + t.Errorf("expected hello, got %s", string(p)) + } + fs.RemoveAll() +} + +func TestLoadCleanup1(t *testing.T) { + os.Mkdir("./cache6", 0700) + f, err := createFile(filepath.Join("./cache6", "s11111111"+tob64("test"))) + if err != nil { + t.Error(err.Error()) + } + f.Close() + <-time.After(time.Second) + f, err = createFile(filepath.Join("./cache6", "s22222222"+tob64("test"))) + if err != nil { + t.Error(err.Error()) + } + f.Close() + + c, err := New("./cache6", 0700, 0) + if err != nil { + t.Error(err.Error()) + return + } + defer c.Clean() + + if !c.Exists("test") { + t.Errorf("expected test to exist") + } +} + +const longString = ` + 0123456789 0123456789 + 0123456789 0123456789 + 0123456789 0123456789 + 0123456789 0123456789 + 
0123456789 0123456789 + 0123456789 0123456789 + 0123456789 0123456789 + 0123456789 0123456789 + 0123456789 0123456789 + 0123456789 0123456789 +` + +func TestLoadCleanup2(t *testing.T) { + hash := md5.Sum([]byte(longString)) + name2 := fmt.Sprintf("%s%s%x", longPrefix, "22222222", hash[:]) + name1 := fmt.Sprintf("%s%s%x", longPrefix, "11111111", hash[:]) + + os.Mkdir("./cache7", 0700) + f, err := createFile(filepath.Join("./cache7", name2)) + if err != nil { + t.Error(err.Error()) + } + f.Close() + f, err = createFile(filepath.Join("./cache7", fmt.Sprintf("%s.key", name2))) + if err != nil { + t.Error(err.Error()) + } + f.Write([]byte(longString)) + f.Close() + <-time.After(time.Second) + f, err = createFile(filepath.Join("./cache7", name1)) + if err != nil { + t.Error(err.Error()) + } + f.Close() + f, err = createFile(filepath.Join("./cache7", fmt.Sprintf("%s.key", name1))) + if err != nil { + t.Error(err.Error()) + } + f.Write([]byte(longString)) + f.Close() + + c, err := New("./cache7", 0700, 0) + if err != nil { + t.Error(err.Error()) + return + } + defer c.Clean() + + if !c.Exists(longString) { + t.Errorf("expected test to exist") + } +} + +func TestReload(t *testing.T) { + dir, err := ioutil.TempDir("", "cache5") + if err != nil { + t.Fatalf("Failed to create TempDir: %v", err) + } + c, err := New(dir, 0700, 0) + if err != nil { + t.Error(err.Error()) + return + } + r, w, err := c.Get("stream") + if err != nil { + t.Error(err.Error()) + return + } + r.Close() + data := []byte("hello world\n") + w.Write(data) + w.Close() + + nc, err := New(dir, 0700, 0) + if err != nil { + t.Error(err.Error()) + return + } + defer nc.Clean() + + if !nc.Exists("stream") { + t.Fatalf("expected stream to be reloaded") + } + + r, w, err = nc.Get("stream") + if err != nil { + t.Fatal(err) + } + if w != nil { + t.Fatal("expected reloaded stream to not be writable") + } + + cr, ok := r.(*CacheReader) + if !ok { + t.Fatalf("CacheReader should be supported by a normal FS") + } + size, 
closed, err := cr.Size() + if err != nil { + t.Fatalf("Failed to get Size: %v", err) + } + if !closed { + t.Errorf("Expected stream to be closed.") + } + if size != int64(len(data)) { + t.Errorf("Expected size to be %v, but got %v", len(data), size) + } + + r.Close() + nc.Remove("stream") + if nc.Exists("stream") { + t.Errorf("expected stream to be removed") + } +} + +func TestLRUHaunterMaxItems(t *testing.T) { + + fs, err := NewFs("./cache1", 0700) + if err != nil { + t.Error(err.Error()) + t.FailNow() + } + + c, err := NewCacheWithHaunter(fs, NewLRUHaunterStrategy(NewLRUHaunter(3, 0, 400*time.Millisecond))) + + if err != nil { + t.Error(err.Error()) + return + } + defer c.Clean() + + for i := 0; i < 5; i++ { + name := fmt.Sprintf("stream-%v", i) + r, w, _ := c.Get(name) + w.Write([]byte("hello")) + w.Close() + io.Copy(ioutil.Discard, r) + + if !c.Exists(name) { + t.Errorf(name + " should exist") + } + + <-time.After(10 * time.Millisecond) + + err := r.Close() + if err != nil { + t.Error(err) + } + } + + <-time.After(400 * time.Millisecond) + + if c.Exists("stream-0") { + t.Errorf("stream-0 should have been scrubbed") + } + + if c.Exists("stream-1") { + t.Errorf("stream-1 should have been scrubbed") + } + + files, err := ioutil.ReadDir("./cache1") + if err != nil { + t.Error(err.Error()) + return + } + + if len(files) != 3 { + t.Errorf("expected 3 items in directory") + } +} + +func TestLRUHaunterMaxSize(t *testing.T) { + + fs, err := NewFs("./cache1", 0700) + if err != nil { + t.Error(err.Error()) + t.FailNow() + } + + c, err := NewCacheWithHaunter(fs, NewLRUHaunterStrategy(NewLRUHaunter(0, 24, 400*time.Millisecond))) + + if err != nil { + t.Error(err.Error()) + return + } + defer c.Clean() + + for i := 0; i < 5; i++ { + name := fmt.Sprintf("stream-%v", i) + r, w, _ := c.Get(name) + w.Write([]byte("hello")) + w.Close() + io.Copy(ioutil.Discard, r) + + if !c.Exists(name) { + t.Errorf(name + " should exist") + } + + <-time.After(10 * time.Millisecond) + + err := 
r.Close() + if err != nil { + t.Error(err) + } + } + + <-time.After(400 * time.Millisecond) + + if c.Exists("stream-0") { + t.Errorf("stream-0 should have been scrubbed") + } + + files, err := ioutil.ReadDir("./cache1") + if err != nil { + t.Error(err.Error()) + return + } + + if len(files) != 4 { + t.Errorf("expected 4 items in directory") + } +} + +func TestReaper(t *testing.T) { + fs, err := NewFs("./cache1", 0700) + if err != nil { + t.Error(err.Error()) + t.FailNow() + } + + c, err := NewCache(fs, NewReaper(0*time.Second, 100*time.Millisecond)) + if err != nil { + t.Fatal(err) + } + defer c.Clean() + + r, w, err := c.Get("stream") + if err != nil { + t.Fatal(err) + } + w.Write([]byte("hello")) + w.Close() + io.Copy(ioutil.Discard, r) + + if !c.Exists("stream") { + t.Errorf("stream should exist") + } + + <-time.After(200 * time.Millisecond) + + if !c.Exists("stream") { + t.Errorf("a file expired while in use, fail!") + } + r.Close() + + <-time.After(200 * time.Millisecond) + + if c.Exists("stream") { + t.Errorf("stream should have been reaped") + } + + files, err := ioutil.ReadDir("./cache1") + if err != nil { + t.Error(err.Error()) + return + } + + if len(files) > 0 { + t.Errorf("expected empty directory") + } +} + +func TestReaperNoExpire(t *testing.T) { + testCaches(t, func(c Cache) { + defer c.Clean() + r, w, err := c.Get("stream") + if err != nil { + t.Error(err.Error()) + t.FailNow() + } + w.Write([]byte("hello")) + w.Close() + io.Copy(ioutil.Discard, r) + r.Close() + + if !c.Exists("stream") { + t.Errorf("stream should exist") + } + + if lc, ok := c.(*FSCache); ok { + lc.haunt() + if !c.Exists("stream") { + t.Errorf("stream shouldn't have been reaped") + } + } + }) +} + +func TestSanity(t *testing.T) { + atLeastOneCacheReader := false + testCaches(t, func(c Cache) { + defer c.Clean() + + r, w, err := c.Get(longString) + if err != nil { + t.Error(err.Error()) + return + } + defer r.Close() + + want := []byte("hello world\n") + first := want[:5] + 
w.Write(first) + + cr, ok := r.(*CacheReader) + if ok { + atLeastOneCacheReader = true + size, closed, _ := cr.Size() + if closed { + t.Errorf("Expected stream to be open.") + } + if size != int64(len(first)) { + t.Errorf("Expected size to be %v, but got %v", len(first), size) + } + } + + second := want[5:] + w.Write(second) + + if ok { + atLeastOneCacheReader = true + size, closed, _ := cr.Size() + if closed { + t.Errorf("Expected stream to be open.") + } + if size != int64(len(want)) { + t.Errorf("Expected size to be %v, but got %v", len(want), size) + } + } + + w.Close() + + if ok { + atLeastOneCacheReader = true + size, closed, _ := cr.Size() + if !closed { + t.Errorf("Expected stream to be closed.") + } + if size != int64(len(want)) { + t.Errorf("Expected size to be %v, but got %v", len(want), size) + } + } + + buf := bytes.NewBuffer(nil) + _, err = io.Copy(buf, r) + if err != nil { + t.Error(err.Error()) + return + } + if !bytes.Equal(buf.Bytes(), want) { + t.Errorf("unexpected output %s", buf.Bytes()) + } + }) + if !atLeastOneCacheReader { + t.Errorf("None of the cache tests covered CacheReader!") + } +} + +func TestConcurrent(t *testing.T) { + testCaches(t, func(c Cache) { + defer c.Clean() + + r, w, err := c.Get("stream") + r.Close() + if err != nil { + t.Error(err.Error()) + return + } + go func() { + w.Write([]byte("hello")) + <-time.After(100 * time.Millisecond) + w.Write([]byte("world")) + w.Close() + }() + + if c.Exists("stream") { + r, _, err := c.Get("stream") + if err != nil { + t.Error(err.Error()) + return + } + buf := bytes.NewBuffer(nil) + io.Copy(buf, r) + r.Close() + if !bytes.Equal(buf.Bytes(), []byte("helloworld")) { + t.Errorf("unexpected output %s", buf.Bytes()) + } + } + }) +} + +func TestReuse(t *testing.T) { + testCaches(t, func(c Cache) { + for i := 0; i < 10; i++ { + r, w, err := c.Get(longString) + if err != nil { + t.Error(err.Error()) + return + } + + data := fmt.Sprintf("hello %d", i) + + if w != nil { + w.Write([]byte(data)) + 
w.Close() + } + + check(t, r, data) + r.Close() + + c.Clean() + } + }) +} + +func check(t *testing.T, r io.Reader, data string) { + buf := bytes.NewBuffer(nil) + _, err := io.Copy(buf, r) + if err != nil { + t.Error(err.Error()) + return + } + if !bytes.Equal(buf.Bytes(), []byte(data)) { + t.Errorf("unexpected output %q, want %q", buf.String(), data) + } +} diff --git a/libsq/core/ioz/fscache/handler.go b/libsq/core/ioz/fscache/handler.go new file mode 100644 index 000000000..8df85400c --- /dev/null +++ b/libsq/core/ioz/fscache/handler.go @@ -0,0 +1,41 @@ +package fscache + +import ( + "io" + "net/http" +) + +// Handler is a caching middle-ware for http Handlers. +// It responds to http requests via the passed http.Handler, and caches the response +// using the passed cache. The cache key for the request is the req.URL.String(). +// Note: It does not cache http headers. It is more efficient to set them yourself. +func Handler(c Cache, h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + url := req.URL.String() + r, w, err := c.Get(url) + if err != nil { + h.ServeHTTP(rw, req) + return + } + defer r.Close() + if w != nil { + go func() { + defer w.Close() + h.ServeHTTP(&respWrapper{ + ResponseWriter: rw, + Writer: w, + }, req) + }() + } + io.Copy(rw, r) + }) +} + +type respWrapper struct { + http.ResponseWriter + io.Writer +} + +func (r *respWrapper) Write(p []byte) (int, error) { + return r.Writer.Write(p) +} diff --git a/libsq/core/ioz/fscache/haunter.go b/libsq/core/ioz/fscache/haunter.go new file mode 100644 index 000000000..a8d038ce9 --- /dev/null +++ b/libsq/core/ioz/fscache/haunter.go @@ -0,0 +1,92 @@ +package fscache + +import ( + "time" +) + +// Entry represents a cached item. +type Entry struct { + name string + inUse bool +} + +// InUse returns if this Cache entry is in use. +func (e *Entry) InUse() bool { + return e.inUse +} + +// Name returns the File.Name() of this entry. 
+func (e *Entry) Name() string { + return e.name +} + +// CacheAccessor implementors provide ways to observe and interact with +// the cached entries, mainly used for cache-eviction. +type CacheAccessor interface { + FileSystemStater + EnumerateEntries(enumerator func(key string, e Entry) bool) + RemoveFile(key string) +} + +// Haunter implementors are used to perform cache-eviction (Next is how long to wait +// until next evication, Haunt preforms the eviction). +type Haunter interface { + Haunt(c CacheAccessor) + Next() time.Duration +} + +type reaperHaunterStrategy struct { + reaper Reaper +} + +type lruHaunterStrategy struct { + haunter LRUHaunter +} + +// NewLRUHaunterStrategy returns a simple scheduleHaunt which provides an implementation LRUHaunter strategy +func NewLRUHaunterStrategy(haunter LRUHaunter) Haunter { + return &lruHaunterStrategy{ + haunter: haunter, + } +} + +func (h *lruHaunterStrategy) Haunt(c CacheAccessor) { + for _, key := range h.haunter.Scrub(c) { + c.RemoveFile(key) + } + +} + +func (h *lruHaunterStrategy) Next() time.Duration { + return h.haunter.Next() +} + +// NewReaperHaunterStrategy returns a simple scheduleHaunt which provides an implementation Reaper strategy +func NewReaperHaunterStrategy(reaper Reaper) Haunter { + return &reaperHaunterStrategy{ + reaper: reaper, + } +} + +func (h *reaperHaunterStrategy) Haunt(c CacheAccessor) { + c.EnumerateEntries(func(key string, e Entry) bool { + if e.InUse() { + return true + } + + fileInfo, err := c.Stat(e.Name()) + if err != nil { + return true + } + + if h.reaper.Reap(key, fileInfo.AccessTime(), fileInfo.ModTime()) { + c.RemoveFile(key) + } + + return true + }) +} + +func (h *reaperHaunterStrategy) Next() time.Duration { + return h.reaper.Next() +} diff --git a/libsq/core/ioz/fscache/layers.go b/libsq/core/ioz/fscache/layers.go new file mode 100644 index 000000000..b0b283106 --- /dev/null +++ b/libsq/core/ioz/fscache/layers.go @@ -0,0 +1,128 @@ +package fscache + +import ( + "errors" + 
"io" + "sync" +) + +type layeredCache struct { + layers []Cache +} + +// NewLayered returns a Cache which stores its data in all the passed +// caches, when a key is requested it is loaded into all the caches above the first hit. +func NewLayered(caches ...Cache) Cache { + return &layeredCache{layers: caches} +} + +func (l *layeredCache) Get(key string) (r ReadAtCloser, w io.WriteCloser, err error) { + var last ReadAtCloser + var writers []io.WriteCloser + + for i, layer := range l.layers { + r, w, err = layer.Get(key) + if err != nil { + if len(writers) > 0 { + last.Close() + multiWC(writers...).Close() + } + return nil, nil, err + } + + // hit + if w == nil { + if len(writers) > 0 { + go func(r io.ReadCloser) { + wc := multiWC(writers...) + defer r.Close() + defer wc.Close() + io.Copy(wc, r) + }(r) + return last, nil, nil + } + return r, nil, nil + } + + // miss + writers = append(writers, w) + + if i == len(l.layers)-1 { + if last != nil { + last.Close() + } + return r, multiWC(writers...), nil + } + + if last != nil { + last.Close() + } + last = r + } + + return nil, nil, errors.New("no caches") +} + +func (l *layeredCache) Remove(key string) error { + var grp sync.WaitGroup + // walk upwards so that lower layers don't + // restore upper layers on Get() + for i := len(l.layers) - 1; i >= 0; i-- { + grp.Add(1) + go func(layer Cache) { + defer grp.Done() + layer.Remove(key) + }(l.layers[i]) + } + grp.Wait() + return nil +} + +func (l *layeredCache) Exists(key string) bool { + for _, layer := range l.layers { + if layer.Exists(key) { + return true + } + } + return false +} + +func (l *layeredCache) Clean() (error) { + for _, layer := range l.layers { + if err := layer.Clean(); err != nil { + return err + } + } + return nil +} + +func multiWC(wc ...io.WriteCloser) io.WriteCloser { + if len(wc) == 0 { + return nil + } + + return &multiWriteCloser{ + writers: wc, + } +} + +type multiWriteCloser struct { + writers []io.WriteCloser +} + +func (t *multiWriteCloser) 
Write(p []byte) (n int, err error) { + for _, w := range t.writers { + n, err = w.Write(p) + if err != nil { + return + } + } + return len(p), nil +} + +func (t *multiWriteCloser) Close() error { + for _, w := range t.writers { + w.Close() + } + return nil +} diff --git a/libsq/core/ioz/fscache/lruhaunter.go b/libsq/core/ioz/fscache/lruhaunter.go new file mode 100644 index 000000000..7b90ef3a7 --- /dev/null +++ b/libsq/core/ioz/fscache/lruhaunter.go @@ -0,0 +1,137 @@ +package fscache + +import ( + "sort" + "time" +) + +type lruHaunterKV struct { + Key string + Value Entry +} + +// LRUHaunter is used to control when there are too many streams +// or the size of the streams is too big. +// It is called once right after loading, and then it is run +// again after every Next() period of time. +type LRUHaunter interface { + // Returns the amount of time to wait before the next scheduled Reaping. + Next() time.Duration + + // Given a CacheAccessor, return keys to reap list. + Scrub(c CacheAccessor) []string +} + +// NewLRUHaunter returns a simple haunter which runs every "period" +// and scrubs older files when the total file size is over maxSize or +// total item count is over maxItems. 
+// If maxItems or maxSize are 0, they won't be checked +func NewLRUHaunter(maxItems int, maxSize int64, period time.Duration) LRUHaunter { + return &lruHaunter{ + period: period, + maxItems: maxItems, + maxSize: maxSize, + } +} + +type lruHaunter struct { + period time.Duration + maxItems int + maxSize int64 +} + +func (j *lruHaunter) Next() time.Duration { + return j.period +} + +func (j *lruHaunter) Scrub(c CacheAccessor) (keysToReap []string) { + var count int + var size int64 + var okFiles []lruHaunterKV + + c.EnumerateEntries(func(key string, e Entry) bool { + if e.InUse() { + return true + } + + fileInfo, err := c.Stat(e.Name()) + if err != nil { + return true + } + + count++ + size = size + fileInfo.Size() + okFiles = append(okFiles, lruHaunterKV{ + Key: key, + Value: e, + }) + + return true + }) + + sort.Slice(okFiles, func(i, j int) bool { + iFileInfo, err := c.Stat(okFiles[i].Value.Name()) + if err != nil { + return false + } + + iLastRead := iFileInfo.AccessTime() + + jFileInfo, err := c.Stat(okFiles[j].Value.Name()) + if err != nil { + return false + } + + jLastRead := jFileInfo.AccessTime() + + return iLastRead.Before(jLastRead) + }) + + collectKeysToReapFn := func() bool { + var key *string + var err error + key, count, size, err = j.removeFirst(c, &okFiles, count, size) + if err != nil { + return false + } + if key != nil { + keysToReap = append(keysToReap, *key) + } + + return true + } + + if j.maxItems > 0 { + for count > j.maxItems { + if !collectKeysToReapFn() { + break + } + } + } + + if j.maxSize > 0 { + for size > j.maxSize { + if !collectKeysToReapFn() { + break + } + } + } + + return keysToReap +} + +func (j *lruHaunter) removeFirst(fsStater FileSystemStater, items *[]lruHaunterKV, count int, size int64) (*string, int, int64, error) { + var f lruHaunterKV + + f, *items = (*items)[0], (*items)[1:] + + fileInfo, err := fsStater.Stat(f.Value.Name()) + if err != nil { + return nil, count, size, err + } + + count-- + size = size - 
fileInfo.Size() + + return &f.Key, count, size, nil +} diff --git a/libsq/core/ioz/fscache/memfs.go b/libsq/core/ioz/fscache/memfs.go new file mode 100644 index 000000000..ddfb92cf3 --- /dev/null +++ b/libsq/core/ioz/fscache/memfs.go @@ -0,0 +1,147 @@ +package fscache + +import ( + "bytes" + "errors" + "io" + "os" + "sync" + "time" + + "github.com/djherbis/stream" +) + +type memFS struct { + mu sync.RWMutex + files map[string]*memFile +} + +// NewMemFs creates an in-memory FileSystem. +// It does not support persistence (Reload is a nop). +func NewMemFs() FileSystem { + return &memFS{ + files: make(map[string]*memFile), + } +} + +func (fs *memFS) Stat(name string) (FileInfo, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + f, ok := fs.files[name] + if !ok { + return FileInfo{}, errors.New("file has not been read") + } + + size := int64(len(f.Bytes())) + + return FileInfo{ + FileInfo: &fileInfo{ + name: name, + size: size, + fileMode: os.ModeIrregular, + isDir: false, + sys: nil, + wt: f.wt, + }, + Atime: f.rt, + }, nil +} + +func (fs *memFS) Reload(add func(key, name string)) error { + return nil +} + +func (fs *memFS) Create(key string) (stream.File, error) { + fs.mu.Lock() + defer fs.mu.Unlock() + if _, ok := fs.files[key]; ok { + return nil, errors.New("file exists") + } + file := &memFile{ + name: key, + r: bytes.NewBuffer(nil), + wt: time.Now(), + } + file.memReader.memFile = file + fs.files[key] = file + return file, nil +} + +func (fs *memFS) Open(name string) (stream.File, error) { + fs.mu.Lock() + defer fs.mu.Unlock() + if f, ok := fs.files[name]; ok { + f.rt = time.Now() + return &memReader{memFile: f}, nil + } + return nil, errors.New("file does not exist") +} + +func (fs *memFS) Remove(key string) error { + fs.mu.Lock() + defer fs.mu.Unlock() + delete(fs.files, key) + return nil +} + +func (fs *memFS) RemoveAll() error { + fs.mu.Lock() + defer fs.mu.Unlock() + fs.files = make(map[string]*memFile) + return nil +} + +type memFile struct { + mu 
sync.RWMutex + name string + r *bytes.Buffer + memReader + rt, wt time.Time +} + +func (f *memFile) Name() string { + return f.name +} + +func (f *memFile) Write(p []byte) (int, error) { + if len(p) > 0 { + f.mu.Lock() + defer f.mu.Unlock() + return f.r.Write(p) + } + return len(p), nil +} + +func (f *memFile) Bytes() []byte { + f.mu.RLock() + defer f.mu.RUnlock() + return f.r.Bytes() +} + +func (f *memFile) Close() error { + return nil +} + +type memReader struct { + *memFile + n int +} + +func (r *memReader) ReadAt(p []byte, off int64) (n int, err error) { + data := r.Bytes() + if int64(len(data)) < off { + return 0, io.EOF + } + n, err = bytes.NewReader(data[off:]).ReadAt(p, 0) + return n, err +} + +func (r *memReader) Read(p []byte) (n int, err error) { + n, err = bytes.NewReader(r.Bytes()[r.n:]).Read(p) + r.n += n + return n, err +} + +func (r *memReader) Close() error { + return nil +} diff --git a/libsq/core/ioz/fscache/reaper.go b/libsq/core/ioz/fscache/reaper.go new file mode 100644 index 000000000..d801202a7 --- /dev/null +++ b/libsq/core/ioz/fscache/reaper.go @@ -0,0 +1,37 @@ +package fscache + +import "time" + +// Reaper is used to control when streams expire from the cache. +// It is called once right after loading, and then it is run +// again after every Next() period of time. +type Reaper interface { + // Returns the amount of time to wait before the next scheduled Reaping. + Next() time.Duration + + // Given a key and the last r/w times of a file, return true + // to remove the file from the cache, false to keep it. + Reap(key string, lastRead, lastWrite time.Time) bool +} + +// NewReaper returns a simple reaper which runs every "Period" +// and reaps files which are older than "expiry". 
+func NewReaper(expiry, period time.Duration) Reaper { + return &reaper{ + expiry: expiry, + period: period, + } +} + +type reaper struct { + period time.Duration + expiry time.Duration +} + +func (g *reaper) Next() time.Duration { + return g.period +} + +func (g *reaper) Reap(key string, lastRead, lastWrite time.Time) bool { + return lastRead.Before(time.Now().Add(-g.expiry)) +} diff --git a/libsq/core/ioz/fscache/server.go b/libsq/core/ioz/fscache/server.go new file mode 100644 index 000000000..dba74aad3 --- /dev/null +++ b/libsq/core/ioz/fscache/server.go @@ -0,0 +1,206 @@ +package fscache + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" +) + +// ListenAndServe hosts a Cache for access via NewRemote +func ListenAndServe(c Cache, addr string) error { + return (&server{c: c}).ListenAndServe(addr) +} + +// NewRemote returns a Cache run via ListenAndServe +func NewRemote(raddr string) Cache { + return &remote{raddr: raddr} +} + +type server struct { + c Cache +} + +func (s *server) ListenAndServe(addr string) error { + l, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + for { + c, err := l.Accept() + if err != nil { + return err + } + + go s.Serve(c) + } +} + +const ( + actionGet = iota + actionRemove = iota + actionExists = iota + actionClean = iota +) + +func getKey(r io.Reader) string { + dec := newDecoder(r) + buf := bytes.NewBufferString("") + io.Copy(buf, dec) + return buf.String() +} + +func sendKey(w io.Writer, key string) { + enc := newEncoder(w) + enc.Write([]byte(key)) + enc.Close() +} + +func (s *server) Serve(c net.Conn) { + var action int + fmt.Fscanf(c, "%d\n", &action) + + switch action { + case actionGet: + s.get(c, getKey(c)) + case actionRemove: + s.c.Remove(getKey(c)) + case actionExists: + s.exists(c, getKey(c)) + case actionClean: + s.c.Clean() + } +} + +func (s *server) exists(c net.Conn, key string) { + if s.c.Exists(key) { + fmt.Fprintf(c, "%d\n", 1) + } else { + fmt.Fprintf(c, "%d\n", 0) + } +} + +func (s 
*server) get(c net.Conn, key string) { + r, w, err := s.c.Get(key) + if err != nil { + return // handle this better + } + defer r.Close() + + if w != nil { + go func() { + fmt.Fprintf(c, "%d\n", 1) + io.Copy(w, newDecoder(c)) + w.Close() + }() + } else { + fmt.Fprintf(c, "%d\n", 0) + } + + enc := newEncoder(c) + io.Copy(enc, r) + enc.Close() +} + +type remote struct { + raddr string +} + +func (rmt *remote) Get(key string) (r ReadAtCloser, w io.WriteCloser, err error) { + c, err := net.Dial("tcp", rmt.raddr) + if err != nil { + return nil, nil, err + } + fmt.Fprintf(c, "%d\n", actionGet) + sendKey(c, key) + + var i int + fmt.Fscanf(c, "%d\n", &i) + + var ch chan struct{} + + switch i { + case 0: + ch = make(chan struct{}) // close net.Conn on reader close + case 1: + ch = make(chan struct{}, 1) // two closes before net.Conn close + + w = &safeCloser{ + c: c, + ch: ch, + w: newEncoder(c), + } + default: + return nil, nil, errors.New("bad bad bad") + } + + r = &safeCloser{ + c: c, + ch: ch, + r: newDecoder(c), + } + + return r, w, nil +} + +type safeCloser struct { + c net.Conn + ch chan<- struct{} + r ReadAtCloser + w io.WriteCloser +} + +func (s *safeCloser) ReadAt(p []byte, off int64) (int, error) { + return s.r.ReadAt(p, off) +} +func (s *safeCloser) Read(p []byte) (int, error) { return s.r.Read(p) } +func (s *safeCloser) Write(p []byte) (int, error) { return s.w.Write(p) } + +// Close only closes the underlying connection when ch is full. 
+func (s *safeCloser) Close() (err error) { + if s.r != nil { + err = s.r.Close() + } else if s.w != nil { + err = s.w.Close() + } + + select { + case s.ch <- struct{}{}: + return err + default: + return s.c.Close() + } +} + +func (rmt *remote) Exists(key string) bool { + c, err := net.Dial("tcp", rmt.raddr) + if err != nil { + return false + } + fmt.Fprintf(c, "%d\n", actionExists) + sendKey(c, key) + var i int + fmt.Fscanf(c, "%d\n", &i) + return i == 1 +} + +func (rmt *remote) Remove(key string) error { + c, err := net.Dial("tcp", rmt.raddr) + if err != nil { + return err + } + fmt.Fprintf(c, "%d\n", actionRemove) + sendKey(c, key) + return nil +} + +func (rmt *remote) Clean() error { + c, err := net.Dial("tcp", rmt.raddr) + if err != nil { + return err + } + fmt.Fprintf(c, "%d\n", actionClean) + return nil +} diff --git a/libsq/core/ioz/fscache/stream.go b/libsq/core/ioz/fscache/stream.go new file mode 100644 index 000000000..9cccb2483 --- /dev/null +++ b/libsq/core/ioz/fscache/stream.go @@ -0,0 +1,72 @@ +package fscache + +import ( + "encoding/json" + "errors" + "io" +) + +type decoder interface { + Decode(interface{}) error +} + +type encoder interface { + Encode(interface{}) error +} + +type pktReader struct { + dec decoder +} + +type pktWriter struct { + enc encoder +} + +type packet struct { + Err int + Data []byte +} + +const eof = 1 + +func (t *pktReader) ReadAt(p []byte, off int64) (n int, err error) { + // TODO not implemented + return 0, errors.New("not implemented") +} + +func (t *pktReader) Read(p []byte) (int, error) { + var pkt packet + err := t.dec.Decode(&pkt) + if err != nil { + return 0, err + } + if pkt.Err == eof { + return 0, io.EOF + } + return copy(p, pkt.Data), nil +} + +func (t *pktReader) Close() error { + return nil +} + +func (t *pktWriter) Write(p []byte) (int, error) { + pkt := packet{Data: p} + err := t.enc.Encode(pkt) + if err != nil { + return 0, err + } + return len(p), nil +} + +func (t *pktWriter) Close() error { + return 
t.enc.Encode(packet{Err: eof}) +} + +func newEncoder(w io.Writer) io.WriteCloser { + return &pktWriter{enc: json.NewEncoder(w)} +} + +func newDecoder(r io.Reader) ReadAtCloser { + return &pktReader{dec: json.NewDecoder(r)} +} From a52c12623e9805808ae6cfa4a71501e201f8f14a Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 26 Nov 2023 10:33:02 -0700 Subject: [PATCH 007/195] wip: adding fscache --- libsq/core/ioz/fscache/LICENSE | 22 - libsq/core/ioz/fscache/README.md | 93 ---- libsq/core/ioz/fscache/distrib.go | 85 ---- libsq/core/ioz/fscache/example_test.go | 69 --- libsq/core/ioz/fscache/fileinfo.go | 52 --- libsq/core/ioz/fscache/fs.go | 266 ------------ libsq/core/ioz/fscache/fscache.go | 373 ---------------- libsq/core/ioz/fscache/fscache_test.go | 579 ------------------------- libsq/core/ioz/fscache/handler.go | 41 -- libsq/core/ioz/fscache/haunter.go | 92 ---- libsq/core/ioz/fscache/layers.go | 128 ------ libsq/core/ioz/fscache/lruhaunter.go | 137 ------ libsq/core/ioz/fscache/memfs.go | 147 ------- libsq/core/ioz/fscache/reaper.go | 37 -- libsq/core/ioz/fscache/server.go | 206 --------- libsq/core/ioz/fscache/stream.go | 72 --- 16 files changed, 2399 deletions(-) delete mode 100644 libsq/core/ioz/fscache/LICENSE delete mode 100644 libsq/core/ioz/fscache/README.md delete mode 100644 libsq/core/ioz/fscache/distrib.go delete mode 100644 libsq/core/ioz/fscache/example_test.go delete mode 100644 libsq/core/ioz/fscache/fileinfo.go delete mode 100644 libsq/core/ioz/fscache/fs.go delete mode 100644 libsq/core/ioz/fscache/fscache.go delete mode 100644 libsq/core/ioz/fscache/fscache_test.go delete mode 100644 libsq/core/ioz/fscache/handler.go delete mode 100644 libsq/core/ioz/fscache/haunter.go delete mode 100644 libsq/core/ioz/fscache/layers.go delete mode 100644 libsq/core/ioz/fscache/lruhaunter.go delete mode 100644 libsq/core/ioz/fscache/memfs.go delete mode 100644 libsq/core/ioz/fscache/reaper.go delete mode 100644 libsq/core/ioz/fscache/server.go delete mode 
100644 libsq/core/ioz/fscache/stream.go diff --git a/libsq/core/ioz/fscache/LICENSE b/libsq/core/ioz/fscache/LICENSE deleted file mode 100644 index 1e7b7cc09..000000000 --- a/libsq/core/ioz/fscache/LICENSE +++ /dev/null @@ -1,22 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2015 Dustin H - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. 
- diff --git a/libsq/core/ioz/fscache/README.md b/libsq/core/ioz/fscache/README.md deleted file mode 100644 index 78b57ef35..000000000 --- a/libsq/core/ioz/fscache/README.md +++ /dev/null @@ -1,93 +0,0 @@ -fscache -========== - -[![GoDoc](https://godoc.org/github.com/djherbis/fscache?status.svg)](https://godoc.org/github.com/djherbis/fscache) -[![Release](https://img.shields.io/github/release/djherbis/fscache.svg)](https://github.com/djherbis/fscache/releases/latest) -[![Software License](https://img.shields.io/badge/license-MIT-brightgreen.svg)](LICENSE.txt) -[![go test](https://github.com/djherbis/fscache/actions/workflows/go-test.yml/badge.svg)](https://github.com/djherbis/fscache/actions/workflows/go-test.yml) -[![Coverage Status](https://coveralls.io/repos/djherbis/fscache/badge.svg?branch=master)](https://coveralls.io/r/djherbis/fscache?branch=master) -[![Go Report Card](https://goreportcard.com/badge/github.com/djherbis/fscache)](https://goreportcard.com/report/github.com/djherbis/fscache) - -Usage ------------- -Streaming File Cache for #golang - -fscache allows multiple readers to read from a cache while its being written to. [blog post](https://djherbis.github.io/post/fscache/) - -Using the Cache directly: - -```go -package main - -import ( - "io" - "log" - "os" - "time" - - "gopkg.in/djherbis/fscache.v0" -) - -func main() { - - // create the cache, keys expire after 1 hour. - c, err := fscache.New("./cache", 0755, time.Hour) - if err != nil { - log.Fatal(err.Error()) - } - - // wipe the cache when done - defer c.Clean() - - // Get() and it's streams can be called concurrently but just for example: - for i := 0; i < 3; i++ { - r, w, err := c.Get("stream") - if err != nil { - log.Fatal(err.Error()) - } - - if w != nil { // a new stream, write to it. 
- go func(){ - w.Write([]byte("hello world\n")) - w.Close() - }() - } - - // the stream has started, read from it - io.Copy(os.Stdout, r) - r.Close() - } -} -``` - -A Caching Middle-ware: - -```go -package main - -import( - "net/http" - "time" - - "gopkg.in/djherbis/fscache.v0" -) - -func main(){ - c, err := fscache.New("./cache", 0700, 0) - if err != nil { - log.Fatal(err.Error()) - } - - handler := func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "%v: %s", time.Now(), "hello world") - } - - http.ListenAndServe(":8080", fscache.Handler(c, http.HandlerFunc(handler))) -} -``` - -Installation ------------- -```sh -go get gopkg.in/djherbis/fscache.v0 -``` diff --git a/libsq/core/ioz/fscache/distrib.go b/libsq/core/ioz/fscache/distrib.go deleted file mode 100644 index 60994cc58..000000000 --- a/libsq/core/ioz/fscache/distrib.go +++ /dev/null @@ -1,85 +0,0 @@ -package fscache - -import ( - "bytes" - "crypto/sha1" - "encoding/binary" - "io" -) - -// Distributor provides a way to partition keys into Caches. -type Distributor interface { - - // GetCache will always return the same Cache for the same key. - GetCache(key string) Cache - - // Clean should wipe all the caches this Distributor manages - Clean() error -} - -// stdDistribution distributes the keyspace evenly. -func stdDistribution(key string, n uint64) uint64 { - h := sha1.New() - io.WriteString(h, key) - buf := bytes.NewBuffer(h.Sum(nil)[:8]) - i, _ := binary.ReadUvarint(buf) - return i % n -} - -// NewDistributor returns a Distributor which evenly distributes the keyspace -// into the passed caches. 
-func NewDistributor(caches ...Cache) Distributor { - if len(caches) == 0 { - return nil - } - return &distrib{ - distribution: stdDistribution, - caches: caches, - size: uint64(len(caches)), - } -} - -type distrib struct { - distribution func(key string, n uint64) uint64 - caches []Cache - size uint64 -} - -func (d *distrib) GetCache(key string) Cache { - return d.caches[d.distribution(key, d.size)] -} - -// BUG(djherbis): Return an error if cleaning fails -func (d *distrib) Clean() error { - for _, c := range d.caches { - c.Clean() - } - return nil -} - -// NewPartition returns a Cache which uses the Caches defined by the passed Distributor. -func NewPartition(d Distributor) Cache { - return &partition{ - distributor: d, - } -} - -type partition struct { - distributor Distributor -} - -func (p *partition) Get(key string) (ReadAtCloser, io.WriteCloser, error) { - return p.distributor.GetCache(key).Get(key) -} - -func (p *partition) Remove(key string) error { - return p.distributor.GetCache(key).Remove(key) -} - -func (p *partition) Exists(key string) bool { - return p.distributor.GetCache(key).Exists(key) -} - -func (p *partition) Clean() error { - return p.distributor.Clean() -} diff --git a/libsq/core/ioz/fscache/example_test.go b/libsq/core/ioz/fscache/example_test.go deleted file mode 100644 index 5aa9e7266..000000000 --- a/libsq/core/ioz/fscache/example_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package fscache - -import ( - "fmt" - "io" - "log" - "net/http" - "net/http/httptest" - "os" - "time" -) - -func Example() { - // create the cache, keys expire after 1 hour. - c, err := New("./cache", 0755, time.Hour) - if err != nil { - log.Fatal(err.Error()) - } - - // wipe the cache when done - defer c.Clean() - - // Get() and it's streams can be called concurrently but just for example: - for i := 0; i < 3; i++ { - r, w, err := c.Get("stream") - if err != nil { - log.Fatal(err.Error()) - } - - if w != nil { // a new stream, write to it. 
- go func() { - w.Write([]byte("hello world\n")) - w.Close() - }() - } - - // the stream has started, read from it - io.Copy(os.Stdout, r) - r.Close() - } - // Output: - // hello world - // hello world - // hello world -} - -func ExampleHandler() { - c, err := New("./server", 0700, 0) - if err != nil { - log.Fatal(err.Error()) - } - defer c.Clean() - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "Hello Client") - }) - - ts := httptest.NewServer(Handler(c, handler)) - defer ts.Close() - - resp, err := http.Get(ts.URL) - if err != nil { - log.Fatal(err.Error()) - } - io.Copy(os.Stdout, resp.Body) - resp.Body.Close() - // Output: - // Hello Client -} diff --git a/libsq/core/ioz/fscache/fileinfo.go b/libsq/core/ioz/fscache/fileinfo.go deleted file mode 100644 index 445fcfd50..000000000 --- a/libsq/core/ioz/fscache/fileinfo.go +++ /dev/null @@ -1,52 +0,0 @@ -package fscache - -import ( - "os" - "time" -) - -// FileInfo is just a wrapper around os.FileInfo which includes atime. -type FileInfo struct { - os.FileInfo - Atime time.Time -} - -type fileInfo struct { - name string - size int64 - fileMode os.FileMode - isDir bool - sys interface{} - wt time.Time -} - -func (f *fileInfo) Name() string { - return f.name -} - -func (f *fileInfo) Size() int64 { - return f.size -} - -func (f *fileInfo) Mode() os.FileMode { - return f.fileMode -} - -func (f *fileInfo) ModTime() time.Time { - return f.wt -} - -func (f *fileInfo) IsDir() bool { - return f.isDir -} - -func (f *fileInfo) Sys() interface{} { - return f.sys -} - -// AccessTime returns the last time the file was read. -// It will be used to check expiry of a file, and must be concurrent safe -// with modifications to the FileSystem (writes, reads etc.) 
-func (f *FileInfo) AccessTime() time.Time { - return f.Atime -} diff --git a/libsq/core/ioz/fscache/fs.go b/libsq/core/ioz/fscache/fs.go deleted file mode 100644 index dad018382..000000000 --- a/libsq/core/ioz/fscache/fs.go +++ /dev/null @@ -1,266 +0,0 @@ -package fscache - -import ( - "bytes" - "crypto/md5" - "encoding/base64" - "fmt" - "io" - "io/ioutil" - "os" - "path/filepath" - "strings" - "time" - - "github.com/djherbis/atime" - "github.com/djherbis/stream" -) - -// FileSystemStater implementers can provide FileInfo data about a named resource. -type FileSystemStater interface { - // Stat takes a File.Name() and returns FileInfo interface - Stat(name string) (FileInfo, error) -} - -// FileSystem is used as the source for a Cache. -type FileSystem interface { - // Stream FileSystem - stream.FileSystem - - FileSystemStater - - // Reload should look through the FileSystem and call the supplied fn - // with the key/filename pairs that are found. - Reload(func(key, name string)) error - - // RemoveAll should empty the FileSystem of all files. - RemoveAll() error -} - -// StandardFS is an implemenation of FileSystem which writes to the os Filesystem. -type StandardFS struct { - root string - init func() error - - // EncodeKey takes a 'name' given to Create and converts it into a - // the Filename that should be used. It should return 'true' if - // DecodeKey can convert the returned string back to the original 'name' - // and false otherwise. - // This must be set before the first call to Create. - EncodeKey func(string) (string, bool) - - // DecodeKey should convert a given Filename into the original 'name' given to - // EncodeKey, and return true if this conversion was possible. Returning false - // will cause it to try and lookup a stored 'encodedName.key' file which holds - // the original name. 
- DecodeKey func(string) (string, bool) -} - -// IdentityCodeKey works as both an EncodeKey and a DecodeKey func, which just returns -// it's given argument and true. This is expected to be used when your FSCache -// uses SetKeyMapper to ensure its internal km(key) value is already a valid filename path. -func IdentityCodeKey(key string) (string, bool) { return key, true } - -// NewFs returns a FileSystem rooted at directory dir. -// Dir is created with perms if it doesn't exist. -// This also uses the default EncodeKey/DecodeKey functions B64ORMD5HashEncodeKey/B64DecodeKey. -func NewFs(dir string, mode os.FileMode) (*StandardFS, error) { - fs := &StandardFS{ - root: dir, - init: func() error { - return os.MkdirAll(dir, mode) - }, - EncodeKey: B64OrMD5HashEncodeKey, - DecodeKey: B64DecodeKey, - } - return fs, fs.init() -} - -// Reload looks through the dir given to NewFs and returns every key, name pair (Create(key) => name = File.Name()) -// that is managed by this FileSystem. -func (fs *StandardFS) Reload(add func(key, name string)) error { - files, err := ioutil.ReadDir(fs.root) - if err != nil { - return err - } - - addfiles := make(map[string]struct { - os.FileInfo - key string - }) - - for _, f := range files { - - if strings.HasSuffix(f.Name(), ".key") { - continue - } - - key, err := fs.getKey(f.Name()) - if err != nil { - fs.Remove(filepath.Join(fs.root, f.Name())) - continue - } - fi, ok := addfiles[key] - - if !ok || fi.ModTime().Before(f.ModTime()) { - if ok { - fs.Remove(fi.Name()) - } - addfiles[key] = struct { - os.FileInfo - key string - }{ - FileInfo: f, - key: key, - } - } else { - fs.Remove(f.Name()) - } - - } - - for _, f := range addfiles { - path, err := filepath.Abs(filepath.Join(fs.root, f.Name())) - if err != nil { - return err - } - add(f.key, path) - } - - return nil -} - -// Create creates a File for the given 'name', it may not use the given name on the -// os filesystem, that depends on the implementation of EncodeKey used. 
-func (fs *StandardFS) Create(name string) (stream.File, error) { - name, err := fs.makeName(name) - if err != nil { - return nil, err - } - return fs.create(name) -} - -func (fs *StandardFS) create(name string) (stream.File, error) { - return os.OpenFile(filepath.Join(fs.root, name), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) -} - -// Open opens a stream.File for the given File.Name() returned by Create(). -func (fs *StandardFS) Open(name string) (stream.File, error) { - return os.Open(name) -} - -// Remove removes a stream.File for the given File.Name() returned by Create(). -func (fs *StandardFS) Remove(name string) error { - os.Remove(fmt.Sprintf("%s.key", name)) - return os.Remove(name) -} - -// RemoveAll deletes all files in the directory managed by this StandardFS. -// Warning that if you put files in this directory that were not created by -// StandardFS they will also be deleted. -func (fs *StandardFS) RemoveAll() error { - if err := os.RemoveAll(fs.root); err != nil { - return err - } - return fs.init() -} - -// AccessTimes returns atime and mtime for the given File.Name() returned by Create(). -func (fs *StandardFS) AccessTimes(name string) (rt, wt time.Time, err error) { - fi, err := os.Stat(name) - if err != nil { - return rt, wt, err - } - return atime.Get(fi), fi.ModTime(), nil -} - -// Stat returns FileInfo for the given File.Name() returned by Create(). -func (fs *StandardFS) Stat(name string) (FileInfo, error) { - stat, err := os.Stat(name) - if err != nil { - return FileInfo{}, err - } - - return FileInfo{FileInfo: stat, Atime: atime.Get(stat)}, nil -} - -const ( - saltSize = 8 - salt = "xxxxxxxx" // this is only important for sizing now. 
- maxShort = 20 - shortPrefix = "s" - longPrefix = "l" -) - -func tob64(s string) string { - buf := bytes.NewBufferString("") - enc := base64.NewEncoder(base64.URLEncoding, buf) - enc.Write([]byte(s)) - enc.Close() - return buf.String() -} - -func fromb64(s string) string { - buf := bytes.NewBufferString(s) - dec := base64.NewDecoder(base64.URLEncoding, buf) - out := bytes.NewBufferString("") - io.Copy(out, dec) - return out.String() -} - -// B64OrMD5HashEncodeKey converts a given key into a filesystem name-safe string -// and returns true iff it can be reversed with B64DecodeKey. -func B64OrMD5HashEncodeKey(key string) (string, bool) { - b64key := tob64(key) - // short name - if len(b64key) < maxShort { - return fmt.Sprintf("%s%s%s", shortPrefix, salt, b64key), true - } - - // long name - hash := md5.Sum([]byte(key)) - return fmt.Sprintf("%s%s%x", longPrefix, salt, hash[:]), false -} - -func (fs *StandardFS) makeName(key string) (string, error) { - name, decodable := fs.EncodeKey(key) - if decodable { - return name, nil - } - - // Name is not decodeable, store it. - f, err := fs.create(fmt.Sprintf("%s.key", name)) - if err != nil { - return "", err - } - _, err = f.Write([]byte(key)) - f.Close() - return name, err -} - -// B64DecodeKey converts a string y into x st. y, ok = B64OrMD5HashEncodeKey(x), and ok = true. -// Basically it should reverse B64OrMD5HashEncodeKey if B64OrMD5HashEncodeKey returned true. 
-func B64DecodeKey(name string) (string, bool) { - if strings.HasPrefix(name, shortPrefix) { - return fromb64(strings.TrimPrefix(name, shortPrefix)[saltSize:]), true - } - return "", false -} - -func (fs *StandardFS) getKey(name string) (string, error) { - if key, ok := fs.DecodeKey(name); ok { - return key, nil - } - - // long name - f, err := fs.Open(filepath.Join(fs.root, fmt.Sprintf("%s.key", name))) - if err != nil { - return "", err - } - defer f.Close() - key, err := ioutil.ReadAll(f) - if err != nil { - return "", err - } - return string(key), nil -} diff --git a/libsq/core/ioz/fscache/fscache.go b/libsq/core/ioz/fscache/fscache.go deleted file mode 100644 index 6de40a3b8..000000000 --- a/libsq/core/ioz/fscache/fscache.go +++ /dev/null @@ -1,373 +0,0 @@ -package fscache - -import ( - "fmt" - "io" - "os" - "sync" - "sync/atomic" - "time" - - "github.com/djherbis/stream" -) - -// Cache works like a concurrent-safe map for streams. -type Cache interface { - // Get manages access to the streams in the cache. - // If the key does not exist, w != nil and you can start writing to the stream. - // If the key does exist, w == nil. - // r will always be non-nil as long as err == nil and you must close r when you're done reading. - // Get can be called concurrently, and writing and reading is concurrent safe. - Get(key string) (ReadAtCloser, io.WriteCloser, error) - - // Remove deletes the stream from the cache, blocking until the underlying - // file can be deleted (all active streams finish with it). - // It is safe to call Remove concurrently with Get. - Remove(key string) error - - // Exists checks if a key is in the cache. - // It is safe to call Exists concurrently with Get. - Exists(key string) bool - - // Clean will empty the cache and delete the cache folder. - // Clean is not safe to call while streams are being read/written. - Clean() error -} - -// FSCache is a Cache which uses a Filesystem to read/write cached data. 
-type FSCache struct { - mu sync.RWMutex - files map[string]fileStream - km func(string) string - fs FileSystem - haunter Haunter -} - -// SetKeyMapper will use the given function to transform any given Cache key into the result of km(key). -// This means that internally, the cache will only track km(key), and forget the original key. The consequences -// of this are that Enumerate will return km(key) instead of key, and Filesystem will give km(key) to Create -// and expect Reload() to return km(key). -// The purpose of this function is so that the internally managed key can be converted to a string that is -// allowed as a filesystem path. -func (c *FSCache) SetKeyMapper(km func(string) string) *FSCache { - c.mu.Lock() - defer c.mu.Unlock() - c.km = km - return c -} - -func (c *FSCache) mapKey(key string) string { - if c.km == nil { - return key - } - return c.km(key) -} - -// ReadAtCloser is an io.ReadCloser, and an io.ReaderAt. It supports both so that Range -// Requests are possible. -type ReadAtCloser interface { - io.ReadCloser - io.ReaderAt -} - -type fileStream interface { - next() (*CacheReader, error) - InUse() bool - io.WriteCloser - remove() error - Name() string -} - -// New creates a new Cache using NewFs(dir, perms). -// expiry is the duration after which an un-accessed key will be removed from -// the cache, a zero value expiro means never expire. -func New(dir string, perms os.FileMode, expiry time.Duration) (*FSCache, error) { - fs, err := NewFs(dir, perms) - if err != nil { - return nil, err - } - var grim Reaper - if expiry > 0 { - grim = &reaper{ - expiry: expiry, - period: expiry, - } - } - return NewCache(fs, grim) -} - -// NewCache creates a new Cache based on FileSystem fs. -// fs.Files() are loaded using the name they were created with as a key. -// Reaper is used to determine when files expire, nil means never expire. 
-func NewCache(fs FileSystem, grim Reaper) (*FSCache, error) { - if grim != nil { - return NewCacheWithHaunter(fs, NewReaperHaunterStrategy(grim)) - } - - return NewCacheWithHaunter(fs, nil) -} - -// NewCacheWithHaunter create a new Cache based on FileSystem fs. -// fs.Files() are loaded using the name they were created with as a key. -// Haunter is used to determine when files expire, nil means never expire. -func NewCacheWithHaunter(fs FileSystem, haunter Haunter) (*FSCache, error) { - c := &FSCache{ - files: make(map[string]fileStream), - haunter: haunter, - fs: fs, - } - err := c.load() - if err != nil { - return nil, err - } - if haunter != nil { - c.scheduleHaunt() - } - - return c, nil -} - -func (c *FSCache) scheduleHaunt() { - c.haunt() - time.AfterFunc(c.haunter.Next(), c.scheduleHaunt) -} - -func (c *FSCache) haunt() { - c.mu.Lock() - defer c.mu.Unlock() - - c.haunter.Haunt(&accessor{c: c}) -} - -func (c *FSCache) load() error { - c.mu.Lock() - defer c.mu.Unlock() - return c.fs.Reload(func(key, name string) { - c.files[key] = c.oldFile(name) - }) -} - -// Exists returns true iff this key is in the Cache (may not be finished streaming). -func (c *FSCache) Exists(key string) bool { - c.mu.RLock() - defer c.mu.RUnlock() - _, ok := c.files[c.mapKey(key)] - return ok -} - -// Get obtains a ReadAtCloser for the given key, and may return a WriteCloser to write the original cache data -// if this is a cache-miss. 
-func (c *FSCache) Get(key string) (r ReadAtCloser, w io.WriteCloser, err error) { - c.mu.RLock() - key = c.mapKey(key) - f, ok := c.files[key] - if ok { - r, err = f.next() - c.mu.RUnlock() - return r, nil, err - } - c.mu.RUnlock() - - c.mu.Lock() - defer c.mu.Unlock() - - f, ok = c.files[key] - if ok { - r, err = f.next() - return r, nil, err - } - - f, err = c.newFile(key) - if err != nil { - return nil, nil, err - } - - r, err = f.next() - if err != nil { - f.Close() - c.fs.Remove(f.Name()) - return nil, nil, err - } - - c.files[key] = f - - return r, f, err -} - -// Remove removes the specified key from the cache. -func (c *FSCache) Remove(key string) error { - c.mu.Lock() - key = c.mapKey(key) - f, ok := c.files[key] - delete(c.files, key) - c.mu.Unlock() - - if ok { - return f.remove() - } - return nil -} - -// Clean resets the cache removing all keys and data. -func (c *FSCache) Clean() error { - c.mu.Lock() - defer c.mu.Unlock() - c.files = make(map[string]fileStream) - return c.fs.RemoveAll() -} - -type accessor struct { - c *FSCache -} - -func (a *accessor) Stat(name string) (FileInfo, error) { - return a.c.fs.Stat(name) -} - -func (a *accessor) EnumerateEntries(enumerator func(key string, e Entry) bool) { - for k, f := range a.c.files { - if !enumerator(k, Entry{name: f.Name(), inUse: f.InUse()}) { - break - } - } -} - -func (a *accessor) RemoveFile(key string) { - key = a.c.mapKey(key) - f, ok := a.c.files[key] - delete(a.c.files, key) - if ok { - a.c.fs.Remove(f.Name()) - } -} - -type cachedFile struct { - handleCounter - stream *stream.Stream -} - -func (c *FSCache) newFile(name string) (fileStream, error) { - s, err := stream.NewStream(name, c.fs) - if err != nil { - return nil, err - } - cf := &cachedFile{ - stream: s, - } - cf.inc() - return cf, nil -} - -func (c *FSCache) oldFile(name string) fileStream { - return &reloadedFile{ - fs: c.fs, - name: name, - } -} - -type reloadedFile struct { - handleCounter - fs FileSystem - name string - 
io.WriteCloser // nop Write & Close methods. will never be called. -} - -func (f *reloadedFile) Name() string { return f.name } - -func (f *reloadedFile) remove() error { - f.waitUntilFree() - return f.fs.Remove(f.name) -} - -func (f *reloadedFile) next() (*CacheReader, error) { - r, err := f.fs.Open(f.name) - if err == nil { - f.inc() - } - return &CacheReader{ - ReadAtCloser: r, - cnt: &f.handleCounter, - }, err -} - -func (f *cachedFile) Name() string { return f.stream.Name() } - -func (f *cachedFile) remove() error { return f.stream.Remove() } - -func (f *cachedFile) next() (*CacheReader, error) { - reader, err := f.stream.NextReader() - if err != nil { - return nil, err - } - f.inc() - return &CacheReader{ - ReadAtCloser: reader, - cnt: &f.handleCounter, - }, nil -} - -func (f *cachedFile) Write(p []byte) (int, error) { - return f.stream.Write(p) -} - -func (f *cachedFile) Close() error { - defer f.dec() - return f.stream.Close() -} - -// CacheReader is a ReadAtCloser for a Cache key that also tracks open readers. -type CacheReader struct { - ReadAtCloser - cnt *handleCounter -} - -// Close frees the underlying ReadAtCloser and updates the open reader counter. -func (r *CacheReader) Close() error { - defer r.cnt.dec() - return r.ReadAtCloser.Close() -} - -// Size returns the current size of the stream being read, the boolean it -// returns is true iff the stream is done being written (otherwise Size may change). -// An error is returned if the Size fails to be computed or is not supported -// by the underlying filesystem. 
-func (r *CacheReader) Size() (int64, bool, error) { - switch v := r.ReadAtCloser.(type) { - case *stream.Reader: - size, done := v.Size() - return size, done, nil - - case interface{ Stat() (os.FileInfo, error) }: - fi, err := v.Stat() - if err != nil { - return 0, false, err - } - return fi.Size(), true, nil - - default: - return 0, false, fmt.Errorf("reader does not support stat") - } -} - -type handleCounter struct { - cnt int64 - grp sync.WaitGroup -} - -func (h *handleCounter) inc() { - h.grp.Add(1) - atomic.AddInt64(&h.cnt, 1) -} - -func (h *handleCounter) dec() { - atomic.AddInt64(&h.cnt, -1) - h.grp.Done() -} - -func (h *handleCounter) InUse() bool { - return atomic.LoadInt64(&h.cnt) > 0 -} - -func (h *handleCounter) waitUntilFree() { - h.grp.Wait() -} diff --git a/libsq/core/ioz/fscache/fscache_test.go b/libsq/core/ioz/fscache/fscache_test.go deleted file mode 100644 index de125299a..000000000 --- a/libsq/core/ioz/fscache/fscache_test.go +++ /dev/null @@ -1,579 +0,0 @@ -package fscache - -import ( - "bytes" - "crypto/md5" - "fmt" - "io" - "io/ioutil" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "testing" - "time" -) - -func createFile(name string) (*os.File, error) { - return os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) -} - -func init() { - c, _ := NewCache(NewMemFs(), nil) - go ListenAndServe(c, "localhost:10000") -} - -func testCaches(t *testing.T, run func(c Cache)) { - c, err := New("./cache", 0700, 1*time.Hour) - if err != nil { - t.Error(err.Error()) - return - } - run(c) - - c, err = NewCache(NewMemFs(), NewReaper(time.Hour, time.Hour)) - if err != nil { - t.Error(err.Error()) - return - } - run(c) - - c2, _ := NewCache(NewMemFs(), nil) - run(NewPartition(NewDistributor(c, c2))) - - lc := NewLayered(c, c2) - run(lc) - - rc := NewRemote("localhost:10000") - run(rc) - - fs, _ := NewFs("./cachex", 0700) - fs.EncodeKey = IdentityCodeKey - fs.DecodeKey = IdentityCodeKey - ck, _ := NewCache(fs, NewReaper(time.Hour, 
time.Hour)) - ck.SetKeyMapper(func(key string) string { - name, _ := B64OrMD5HashEncodeKey(key) - return name - }) - run(ck) -} - -func TestHandler(t *testing.T) { - testCaches(t, func(c Cache) { - defer c.Clean() - ts := httptest.NewServer(Handler(c, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "Hello Client") - }))) - defer ts.Close() - - for i := 0; i < 3; i++ { - res, err := http.Get(ts.URL) - if err != nil { - t.Error(err.Error()) - t.FailNow() - } - p, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - if !bytes.Equal([]byte("Hello Client\n"), p) { - t.Errorf("unexpected response %s", string(p)) - } - } - }) -} - -func TestMemFs(t *testing.T) { - fs := NewMemFs() - fs.Reload(func(key, name string) {}) // nop - if _, err := fs.Open("test"); err == nil { - t.Errorf("stream shouldn't exist") - } - fs.Remove("test") - - f, err := fs.Create("test") - if err != nil { - t.Errorf("failed to create test") - } - f.Write([]byte("hello")) - f.Close() - - r, err := fs.Open("test") - if err != nil { - t.Errorf("failed Open: %v", err) - } - p, err := ioutil.ReadAll(r) - if err != nil { - t.Errorf("failed ioutil.ReadAll: %v", err) - } - r.Close() - if !bytes.Equal(p, []byte("hello")) { - t.Errorf("expected hello, got %s", string(p)) - } - fs.RemoveAll() -} - -func TestLoadCleanup1(t *testing.T) { - os.Mkdir("./cache6", 0700) - f, err := createFile(filepath.Join("./cache6", "s11111111"+tob64("test"))) - if err != nil { - t.Error(err.Error()) - } - f.Close() - <-time.After(time.Second) - f, err = createFile(filepath.Join("./cache6", "s22222222"+tob64("test"))) - if err != nil { - t.Error(err.Error()) - } - f.Close() - - c, err := New("./cache6", 0700, 0) - if err != nil { - t.Error(err.Error()) - return - } - defer c.Clean() - - if !c.Exists("test") { - t.Errorf("expected test to exist") - } -} - -const longString = ` - 0123456789 0123456789 - 0123456789 0123456789 - 0123456789 0123456789 - 0123456789 
0123456789 - 0123456789 0123456789 - 0123456789 0123456789 - 0123456789 0123456789 - 0123456789 0123456789 - 0123456789 0123456789 - 0123456789 0123456789 -` - -func TestLoadCleanup2(t *testing.T) { - hash := md5.Sum([]byte(longString)) - name2 := fmt.Sprintf("%s%s%x", longPrefix, "22222222", hash[:]) - name1 := fmt.Sprintf("%s%s%x", longPrefix, "11111111", hash[:]) - - os.Mkdir("./cache7", 0700) - f, err := createFile(filepath.Join("./cache7", name2)) - if err != nil { - t.Error(err.Error()) - } - f.Close() - f, err = createFile(filepath.Join("./cache7", fmt.Sprintf("%s.key", name2))) - if err != nil { - t.Error(err.Error()) - } - f.Write([]byte(longString)) - f.Close() - <-time.After(time.Second) - f, err = createFile(filepath.Join("./cache7", name1)) - if err != nil { - t.Error(err.Error()) - } - f.Close() - f, err = createFile(filepath.Join("./cache7", fmt.Sprintf("%s.key", name1))) - if err != nil { - t.Error(err.Error()) - } - f.Write([]byte(longString)) - f.Close() - - c, err := New("./cache7", 0700, 0) - if err != nil { - t.Error(err.Error()) - return - } - defer c.Clean() - - if !c.Exists(longString) { - t.Errorf("expected test to exist") - } -} - -func TestReload(t *testing.T) { - dir, err := ioutil.TempDir("", "cache5") - if err != nil { - t.Fatalf("Failed to create TempDir: %v", err) - } - c, err := New(dir, 0700, 0) - if err != nil { - t.Error(err.Error()) - return - } - r, w, err := c.Get("stream") - if err != nil { - t.Error(err.Error()) - return - } - r.Close() - data := []byte("hello world\n") - w.Write(data) - w.Close() - - nc, err := New(dir, 0700, 0) - if err != nil { - t.Error(err.Error()) - return - } - defer nc.Clean() - - if !nc.Exists("stream") { - t.Fatalf("expected stream to be reloaded") - } - - r, w, err = nc.Get("stream") - if err != nil { - t.Fatal(err) - } - if w != nil { - t.Fatal("expected reloaded stream to not be writable") - } - - cr, ok := r.(*CacheReader) - if !ok { - t.Fatalf("CacheReader should be supported by a normal FS") 
- } - size, closed, err := cr.Size() - if err != nil { - t.Fatalf("Failed to get Size: %v", err) - } - if !closed { - t.Errorf("Expected stream to be closed.") - } - if size != int64(len(data)) { - t.Errorf("Expected size to be %v, but got %v", len(data), size) - } - - r.Close() - nc.Remove("stream") - if nc.Exists("stream") { - t.Errorf("expected stream to be removed") - } -} - -func TestLRUHaunterMaxItems(t *testing.T) { - - fs, err := NewFs("./cache1", 0700) - if err != nil { - t.Error(err.Error()) - t.FailNow() - } - - c, err := NewCacheWithHaunter(fs, NewLRUHaunterStrategy(NewLRUHaunter(3, 0, 400*time.Millisecond))) - - if err != nil { - t.Error(err.Error()) - return - } - defer c.Clean() - - for i := 0; i < 5; i++ { - name := fmt.Sprintf("stream-%v", i) - r, w, _ := c.Get(name) - w.Write([]byte("hello")) - w.Close() - io.Copy(ioutil.Discard, r) - - if !c.Exists(name) { - t.Errorf(name + " should exist") - } - - <-time.After(10 * time.Millisecond) - - err := r.Close() - if err != nil { - t.Error(err) - } - } - - <-time.After(400 * time.Millisecond) - - if c.Exists("stream-0") { - t.Errorf("stream-0 should have been scrubbed") - } - - if c.Exists("stream-1") { - t.Errorf("stream-1 should have been scrubbed") - } - - files, err := ioutil.ReadDir("./cache1") - if err != nil { - t.Error(err.Error()) - return - } - - if len(files) != 3 { - t.Errorf("expected 3 items in directory") - } -} - -func TestLRUHaunterMaxSize(t *testing.T) { - - fs, err := NewFs("./cache1", 0700) - if err != nil { - t.Error(err.Error()) - t.FailNow() - } - - c, err := NewCacheWithHaunter(fs, NewLRUHaunterStrategy(NewLRUHaunter(0, 24, 400*time.Millisecond))) - - if err != nil { - t.Error(err.Error()) - return - } - defer c.Clean() - - for i := 0; i < 5; i++ { - name := fmt.Sprintf("stream-%v", i) - r, w, _ := c.Get(name) - w.Write([]byte("hello")) - w.Close() - io.Copy(ioutil.Discard, r) - - if !c.Exists(name) { - t.Errorf(name + " should exist") - } - - <-time.After(10 * time.Millisecond) - 
- err := r.Close() - if err != nil { - t.Error(err) - } - } - - <-time.After(400 * time.Millisecond) - - if c.Exists("stream-0") { - t.Errorf("stream-0 should have been scrubbed") - } - - files, err := ioutil.ReadDir("./cache1") - if err != nil { - t.Error(err.Error()) - return - } - - if len(files) != 4 { - t.Errorf("expected 4 items in directory") - } -} - -func TestReaper(t *testing.T) { - fs, err := NewFs("./cache1", 0700) - if err != nil { - t.Error(err.Error()) - t.FailNow() - } - - c, err := NewCache(fs, NewReaper(0*time.Second, 100*time.Millisecond)) - if err != nil { - t.Fatal(err) - } - defer c.Clean() - - r, w, err := c.Get("stream") - if err != nil { - t.Fatal(err) - } - w.Write([]byte("hello")) - w.Close() - io.Copy(ioutil.Discard, r) - - if !c.Exists("stream") { - t.Errorf("stream should exist") - } - - <-time.After(200 * time.Millisecond) - - if !c.Exists("stream") { - t.Errorf("a file expired while in use, fail!") - } - r.Close() - - <-time.After(200 * time.Millisecond) - - if c.Exists("stream") { - t.Errorf("stream should have been reaped") - } - - files, err := ioutil.ReadDir("./cache1") - if err != nil { - t.Error(err.Error()) - return - } - - if len(files) > 0 { - t.Errorf("expected empty directory") - } -} - -func TestReaperNoExpire(t *testing.T) { - testCaches(t, func(c Cache) { - defer c.Clean() - r, w, err := c.Get("stream") - if err != nil { - t.Error(err.Error()) - t.FailNow() - } - w.Write([]byte("hello")) - w.Close() - io.Copy(ioutil.Discard, r) - r.Close() - - if !c.Exists("stream") { - t.Errorf("stream should exist") - } - - if lc, ok := c.(*FSCache); ok { - lc.haunt() - if !c.Exists("stream") { - t.Errorf("stream shouldn't have been reaped") - } - } - }) -} - -func TestSanity(t *testing.T) { - atLeastOneCacheReader := false - testCaches(t, func(c Cache) { - defer c.Clean() - - r, w, err := c.Get(longString) - if err != nil { - t.Error(err.Error()) - return - } - defer r.Close() - - want := []byte("hello world\n") - first := want[:5] - 
w.Write(first) - - cr, ok := r.(*CacheReader) - if ok { - atLeastOneCacheReader = true - size, closed, _ := cr.Size() - if closed { - t.Errorf("Expected stream to be open.") - } - if size != int64(len(first)) { - t.Errorf("Expected size to be %v, but got %v", len(first), size) - } - } - - second := want[5:] - w.Write(second) - - if ok { - atLeastOneCacheReader = true - size, closed, _ := cr.Size() - if closed { - t.Errorf("Expected stream to be open.") - } - if size != int64(len(want)) { - t.Errorf("Expected size to be %v, but got %v", len(want), size) - } - } - - w.Close() - - if ok { - atLeastOneCacheReader = true - size, closed, _ := cr.Size() - if !closed { - t.Errorf("Expected stream to be closed.") - } - if size != int64(len(want)) { - t.Errorf("Expected size to be %v, but got %v", len(want), size) - } - } - - buf := bytes.NewBuffer(nil) - _, err = io.Copy(buf, r) - if err != nil { - t.Error(err.Error()) - return - } - if !bytes.Equal(buf.Bytes(), want) { - t.Errorf("unexpected output %s", buf.Bytes()) - } - }) - if !atLeastOneCacheReader { - t.Errorf("None of the cache tests covered CacheReader!") - } -} - -func TestConcurrent(t *testing.T) { - testCaches(t, func(c Cache) { - defer c.Clean() - - r, w, err := c.Get("stream") - r.Close() - if err != nil { - t.Error(err.Error()) - return - } - go func() { - w.Write([]byte("hello")) - <-time.After(100 * time.Millisecond) - w.Write([]byte("world")) - w.Close() - }() - - if c.Exists("stream") { - r, _, err := c.Get("stream") - if err != nil { - t.Error(err.Error()) - return - } - buf := bytes.NewBuffer(nil) - io.Copy(buf, r) - r.Close() - if !bytes.Equal(buf.Bytes(), []byte("helloworld")) { - t.Errorf("unexpected output %s", buf.Bytes()) - } - } - }) -} - -func TestReuse(t *testing.T) { - testCaches(t, func(c Cache) { - for i := 0; i < 10; i++ { - r, w, err := c.Get(longString) - if err != nil { - t.Error(err.Error()) - return - } - - data := fmt.Sprintf("hello %d", i) - - if w != nil { - w.Write([]byte(data)) - 
w.Close() - } - - check(t, r, data) - r.Close() - - c.Clean() - } - }) -} - -func check(t *testing.T, r io.Reader, data string) { - buf := bytes.NewBuffer(nil) - _, err := io.Copy(buf, r) - if err != nil { - t.Error(err.Error()) - return - } - if !bytes.Equal(buf.Bytes(), []byte(data)) { - t.Errorf("unexpected output %q, want %q", buf.String(), data) - } -} diff --git a/libsq/core/ioz/fscache/handler.go b/libsq/core/ioz/fscache/handler.go deleted file mode 100644 index 8df85400c..000000000 --- a/libsq/core/ioz/fscache/handler.go +++ /dev/null @@ -1,41 +0,0 @@ -package fscache - -import ( - "io" - "net/http" -) - -// Handler is a caching middle-ware for http Handlers. -// It responds to http requests via the passed http.Handler, and caches the response -// using the passed cache. The cache key for the request is the req.URL.String(). -// Note: It does not cache http headers. It is more efficient to set them yourself. -func Handler(c Cache, h http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - url := req.URL.String() - r, w, err := c.Get(url) - if err != nil { - h.ServeHTTP(rw, req) - return - } - defer r.Close() - if w != nil { - go func() { - defer w.Close() - h.ServeHTTP(&respWrapper{ - ResponseWriter: rw, - Writer: w, - }, req) - }() - } - io.Copy(rw, r) - }) -} - -type respWrapper struct { - http.ResponseWriter - io.Writer -} - -func (r *respWrapper) Write(p []byte) (int, error) { - return r.Writer.Write(p) -} diff --git a/libsq/core/ioz/fscache/haunter.go b/libsq/core/ioz/fscache/haunter.go deleted file mode 100644 index a8d038ce9..000000000 --- a/libsq/core/ioz/fscache/haunter.go +++ /dev/null @@ -1,92 +0,0 @@ -package fscache - -import ( - "time" -) - -// Entry represents a cached item. -type Entry struct { - name string - inUse bool -} - -// InUse returns if this Cache entry is in use. -func (e *Entry) InUse() bool { - return e.inUse -} - -// Name returns the File.Name() of this entry. 
-func (e *Entry) Name() string { - return e.name -} - -// CacheAccessor implementors provide ways to observe and interact with -// the cached entries, mainly used for cache-eviction. -type CacheAccessor interface { - FileSystemStater - EnumerateEntries(enumerator func(key string, e Entry) bool) - RemoveFile(key string) -} - -// Haunter implementors are used to perform cache-eviction (Next is how long to wait -// until next evication, Haunt preforms the eviction). -type Haunter interface { - Haunt(c CacheAccessor) - Next() time.Duration -} - -type reaperHaunterStrategy struct { - reaper Reaper -} - -type lruHaunterStrategy struct { - haunter LRUHaunter -} - -// NewLRUHaunterStrategy returns a simple scheduleHaunt which provides an implementation LRUHaunter strategy -func NewLRUHaunterStrategy(haunter LRUHaunter) Haunter { - return &lruHaunterStrategy{ - haunter: haunter, - } -} - -func (h *lruHaunterStrategy) Haunt(c CacheAccessor) { - for _, key := range h.haunter.Scrub(c) { - c.RemoveFile(key) - } - -} - -func (h *lruHaunterStrategy) Next() time.Duration { - return h.haunter.Next() -} - -// NewReaperHaunterStrategy returns a simple scheduleHaunt which provides an implementation Reaper strategy -func NewReaperHaunterStrategy(reaper Reaper) Haunter { - return &reaperHaunterStrategy{ - reaper: reaper, - } -} - -func (h *reaperHaunterStrategy) Haunt(c CacheAccessor) { - c.EnumerateEntries(func(key string, e Entry) bool { - if e.InUse() { - return true - } - - fileInfo, err := c.Stat(e.Name()) - if err != nil { - return true - } - - if h.reaper.Reap(key, fileInfo.AccessTime(), fileInfo.ModTime()) { - c.RemoveFile(key) - } - - return true - }) -} - -func (h *reaperHaunterStrategy) Next() time.Duration { - return h.reaper.Next() -} diff --git a/libsq/core/ioz/fscache/layers.go b/libsq/core/ioz/fscache/layers.go deleted file mode 100644 index b0b283106..000000000 --- a/libsq/core/ioz/fscache/layers.go +++ /dev/null @@ -1,128 +0,0 @@ -package fscache - -import ( - "errors" 
- "io" - "sync" -) - -type layeredCache struct { - layers []Cache -} - -// NewLayered returns a Cache which stores its data in all the passed -// caches, when a key is requested it is loaded into all the caches above the first hit. -func NewLayered(caches ...Cache) Cache { - return &layeredCache{layers: caches} -} - -func (l *layeredCache) Get(key string) (r ReadAtCloser, w io.WriteCloser, err error) { - var last ReadAtCloser - var writers []io.WriteCloser - - for i, layer := range l.layers { - r, w, err = layer.Get(key) - if err != nil { - if len(writers) > 0 { - last.Close() - multiWC(writers...).Close() - } - return nil, nil, err - } - - // hit - if w == nil { - if len(writers) > 0 { - go func(r io.ReadCloser) { - wc := multiWC(writers...) - defer r.Close() - defer wc.Close() - io.Copy(wc, r) - }(r) - return last, nil, nil - } - return r, nil, nil - } - - // miss - writers = append(writers, w) - - if i == len(l.layers)-1 { - if last != nil { - last.Close() - } - return r, multiWC(writers...), nil - } - - if last != nil { - last.Close() - } - last = r - } - - return nil, nil, errors.New("no caches") -} - -func (l *layeredCache) Remove(key string) error { - var grp sync.WaitGroup - // walk upwards so that lower layers don't - // restore upper layers on Get() - for i := len(l.layers) - 1; i >= 0; i-- { - grp.Add(1) - go func(layer Cache) { - defer grp.Done() - layer.Remove(key) - }(l.layers[i]) - } - grp.Wait() - return nil -} - -func (l *layeredCache) Exists(key string) bool { - for _, layer := range l.layers { - if layer.Exists(key) { - return true - } - } - return false -} - -func (l *layeredCache) Clean() (error) { - for _, layer := range l.layers { - if err := layer.Clean(); err != nil { - return err - } - } - return nil -} - -func multiWC(wc ...io.WriteCloser) io.WriteCloser { - if len(wc) == 0 { - return nil - } - - return &multiWriteCloser{ - writers: wc, - } -} - -type multiWriteCloser struct { - writers []io.WriteCloser -} - -func (t *multiWriteCloser) 
Write(p []byte) (n int, err error) { - for _, w := range t.writers { - n, err = w.Write(p) - if err != nil { - return - } - } - return len(p), nil -} - -func (t *multiWriteCloser) Close() error { - for _, w := range t.writers { - w.Close() - } - return nil -} diff --git a/libsq/core/ioz/fscache/lruhaunter.go b/libsq/core/ioz/fscache/lruhaunter.go deleted file mode 100644 index 7b90ef3a7..000000000 --- a/libsq/core/ioz/fscache/lruhaunter.go +++ /dev/null @@ -1,137 +0,0 @@ -package fscache - -import ( - "sort" - "time" -) - -type lruHaunterKV struct { - Key string - Value Entry -} - -// LRUHaunter is used to control when there are too many streams -// or the size of the streams is too big. -// It is called once right after loading, and then it is run -// again after every Next() period of time. -type LRUHaunter interface { - // Returns the amount of time to wait before the next scheduled Reaping. - Next() time.Duration - - // Given a CacheAccessor, return keys to reap list. - Scrub(c CacheAccessor) []string -} - -// NewLRUHaunter returns a simple haunter which runs every "period" -// and scrubs older files when the total file size is over maxSize or -// total item count is over maxItems. 
-// If maxItems or maxSize are 0, they won't be checked -func NewLRUHaunter(maxItems int, maxSize int64, period time.Duration) LRUHaunter { - return &lruHaunter{ - period: period, - maxItems: maxItems, - maxSize: maxSize, - } -} - -type lruHaunter struct { - period time.Duration - maxItems int - maxSize int64 -} - -func (j *lruHaunter) Next() time.Duration { - return j.period -} - -func (j *lruHaunter) Scrub(c CacheAccessor) (keysToReap []string) { - var count int - var size int64 - var okFiles []lruHaunterKV - - c.EnumerateEntries(func(key string, e Entry) bool { - if e.InUse() { - return true - } - - fileInfo, err := c.Stat(e.Name()) - if err != nil { - return true - } - - count++ - size = size + fileInfo.Size() - okFiles = append(okFiles, lruHaunterKV{ - Key: key, - Value: e, - }) - - return true - }) - - sort.Slice(okFiles, func(i, j int) bool { - iFileInfo, err := c.Stat(okFiles[i].Value.Name()) - if err != nil { - return false - } - - iLastRead := iFileInfo.AccessTime() - - jFileInfo, err := c.Stat(okFiles[j].Value.Name()) - if err != nil { - return false - } - - jLastRead := jFileInfo.AccessTime() - - return iLastRead.Before(jLastRead) - }) - - collectKeysToReapFn := func() bool { - var key *string - var err error - key, count, size, err = j.removeFirst(c, &okFiles, count, size) - if err != nil { - return false - } - if key != nil { - keysToReap = append(keysToReap, *key) - } - - return true - } - - if j.maxItems > 0 { - for count > j.maxItems { - if !collectKeysToReapFn() { - break - } - } - } - - if j.maxSize > 0 { - for size > j.maxSize { - if !collectKeysToReapFn() { - break - } - } - } - - return keysToReap -} - -func (j *lruHaunter) removeFirst(fsStater FileSystemStater, items *[]lruHaunterKV, count int, size int64) (*string, int, int64, error) { - var f lruHaunterKV - - f, *items = (*items)[0], (*items)[1:] - - fileInfo, err := fsStater.Stat(f.Value.Name()) - if err != nil { - return nil, count, size, err - } - - count-- - size = size - 
fileInfo.Size() - - return &f.Key, count, size, nil -} diff --git a/libsq/core/ioz/fscache/memfs.go b/libsq/core/ioz/fscache/memfs.go deleted file mode 100644 index ddfb92cf3..000000000 --- a/libsq/core/ioz/fscache/memfs.go +++ /dev/null @@ -1,147 +0,0 @@ -package fscache - -import ( - "bytes" - "errors" - "io" - "os" - "sync" - "time" - - "github.com/djherbis/stream" -) - -type memFS struct { - mu sync.RWMutex - files map[string]*memFile -} - -// NewMemFs creates an in-memory FileSystem. -// It does not support persistence (Reload is a nop). -func NewMemFs() FileSystem { - return &memFS{ - files: make(map[string]*memFile), - } -} - -func (fs *memFS) Stat(name string) (FileInfo, error) { - fs.mu.RLock() - defer fs.mu.RUnlock() - f, ok := fs.files[name] - if !ok { - return FileInfo{}, errors.New("file has not been read") - } - - size := int64(len(f.Bytes())) - - return FileInfo{ - FileInfo: &fileInfo{ - name: name, - size: size, - fileMode: os.ModeIrregular, - isDir: false, - sys: nil, - wt: f.wt, - }, - Atime: f.rt, - }, nil -} - -func (fs *memFS) Reload(add func(key, name string)) error { - return nil -} - -func (fs *memFS) Create(key string) (stream.File, error) { - fs.mu.Lock() - defer fs.mu.Unlock() - if _, ok := fs.files[key]; ok { - return nil, errors.New("file exists") - } - file := &memFile{ - name: key, - r: bytes.NewBuffer(nil), - wt: time.Now(), - } - file.memReader.memFile = file - fs.files[key] = file - return file, nil -} - -func (fs *memFS) Open(name string) (stream.File, error) { - fs.mu.Lock() - defer fs.mu.Unlock() - if f, ok := fs.files[name]; ok { - f.rt = time.Now() - return &memReader{memFile: f}, nil - } - return nil, errors.New("file does not exist") -} - -func (fs *memFS) Remove(key string) error { - fs.mu.Lock() - defer fs.mu.Unlock() - delete(fs.files, key) - return nil -} - -func (fs *memFS) RemoveAll() error { - fs.mu.Lock() - defer fs.mu.Unlock() - fs.files = make(map[string]*memFile) - return nil -} - -type memFile struct { - mu 
sync.RWMutex - name string - r *bytes.Buffer - memReader - rt, wt time.Time -} - -func (f *memFile) Name() string { - return f.name -} - -func (f *memFile) Write(p []byte) (int, error) { - if len(p) > 0 { - f.mu.Lock() - defer f.mu.Unlock() - return f.r.Write(p) - } - return len(p), nil -} - -func (f *memFile) Bytes() []byte { - f.mu.RLock() - defer f.mu.RUnlock() - return f.r.Bytes() -} - -func (f *memFile) Close() error { - return nil -} - -type memReader struct { - *memFile - n int -} - -func (r *memReader) ReadAt(p []byte, off int64) (n int, err error) { - data := r.Bytes() - if int64(len(data)) < off { - return 0, io.EOF - } - n, err = bytes.NewReader(data[off:]).ReadAt(p, 0) - return n, err -} - -func (r *memReader) Read(p []byte) (n int, err error) { - n, err = bytes.NewReader(r.Bytes()[r.n:]).Read(p) - r.n += n - return n, err -} - -func (r *memReader) Close() error { - return nil -} diff --git a/libsq/core/ioz/fscache/reaper.go b/libsq/core/ioz/fscache/reaper.go deleted file mode 100644 index d801202a7..000000000 --- a/libsq/core/ioz/fscache/reaper.go +++ /dev/null @@ -1,37 +0,0 @@ -package fscache - -import "time" - -// Reaper is used to control when streams expire from the cache. -// It is called once right after loading, and then it is run -// again after every Next() period of time. -type Reaper interface { - // Returns the amount of time to wait before the next scheduled Reaping. - Next() time.Duration - - // Given a key and the last r/w times of a file, return true - // to remove the file from the cache, false to keep it. - Reap(key string, lastRead, lastWrite time.Time) bool -} - -// NewReaper returns a simple reaper which runs every "Period" -// and reaps files which are older than "expiry". 
-func NewReaper(expiry, period time.Duration) Reaper { - return &reaper{ - expiry: expiry, - period: period, - } -} - -type reaper struct { - period time.Duration - expiry time.Duration -} - -func (g *reaper) Next() time.Duration { - return g.period -} - -func (g *reaper) Reap(key string, lastRead, lastWrite time.Time) bool { - return lastRead.Before(time.Now().Add(-g.expiry)) -} diff --git a/libsq/core/ioz/fscache/server.go b/libsq/core/ioz/fscache/server.go deleted file mode 100644 index dba74aad3..000000000 --- a/libsq/core/ioz/fscache/server.go +++ /dev/null @@ -1,206 +0,0 @@ -package fscache - -import ( - "bytes" - "errors" - "fmt" - "io" - "net" -) - -// ListenAndServe hosts a Cache for access via NewRemote -func ListenAndServe(c Cache, addr string) error { - return (&server{c: c}).ListenAndServe(addr) -} - -// NewRemote returns a Cache run via ListenAndServe -func NewRemote(raddr string) Cache { - return &remote{raddr: raddr} -} - -type server struct { - c Cache -} - -func (s *server) ListenAndServe(addr string) error { - l, err := net.Listen("tcp", addr) - if err != nil { - return err - } - - for { - c, err := l.Accept() - if err != nil { - return err - } - - go s.Serve(c) - } -} - -const ( - actionGet = iota - actionRemove = iota - actionExists = iota - actionClean = iota -) - -func getKey(r io.Reader) string { - dec := newDecoder(r) - buf := bytes.NewBufferString("") - io.Copy(buf, dec) - return buf.String() -} - -func sendKey(w io.Writer, key string) { - enc := newEncoder(w) - enc.Write([]byte(key)) - enc.Close() -} - -func (s *server) Serve(c net.Conn) { - var action int - fmt.Fscanf(c, "%d\n", &action) - - switch action { - case actionGet: - s.get(c, getKey(c)) - case actionRemove: - s.c.Remove(getKey(c)) - case actionExists: - s.exists(c, getKey(c)) - case actionClean: - s.c.Clean() - } -} - -func (s *server) exists(c net.Conn, key string) { - if s.c.Exists(key) { - fmt.Fprintf(c, "%d\n", 1) - } else { - fmt.Fprintf(c, "%d\n", 0) - } -} - -func (s 
*server) get(c net.Conn, key string) { - r, w, err := s.c.Get(key) - if err != nil { - return // handle this better - } - defer r.Close() - - if w != nil { - go func() { - fmt.Fprintf(c, "%d\n", 1) - io.Copy(w, newDecoder(c)) - w.Close() - }() - } else { - fmt.Fprintf(c, "%d\n", 0) - } - - enc := newEncoder(c) - io.Copy(enc, r) - enc.Close() -} - -type remote struct { - raddr string -} - -func (rmt *remote) Get(key string) (r ReadAtCloser, w io.WriteCloser, err error) { - c, err := net.Dial("tcp", rmt.raddr) - if err != nil { - return nil, nil, err - } - fmt.Fprintf(c, "%d\n", actionGet) - sendKey(c, key) - - var i int - fmt.Fscanf(c, "%d\n", &i) - - var ch chan struct{} - - switch i { - case 0: - ch = make(chan struct{}) // close net.Conn on reader close - case 1: - ch = make(chan struct{}, 1) // two closes before net.Conn close - - w = &safeCloser{ - c: c, - ch: ch, - w: newEncoder(c), - } - default: - return nil, nil, errors.New("bad bad bad") - } - - r = &safeCloser{ - c: c, - ch: ch, - r: newDecoder(c), - } - - return r, w, nil -} - -type safeCloser struct { - c net.Conn - ch chan<- struct{} - r ReadAtCloser - w io.WriteCloser -} - -func (s *safeCloser) ReadAt(p []byte, off int64) (int, error) { - return s.r.ReadAt(p, off) -} -func (s *safeCloser) Read(p []byte) (int, error) { return s.r.Read(p) } -func (s *safeCloser) Write(p []byte) (int, error) { return s.w.Write(p) } - -// Close only closes the underlying connection when ch is full. 
-func (s *safeCloser) Close() (err error) { - if s.r != nil { - err = s.r.Close() - } else if s.w != nil { - err = s.w.Close() - } - - select { - case s.ch <- struct{}{}: - return err - default: - return s.c.Close() - } -} - -func (rmt *remote) Exists(key string) bool { - c, err := net.Dial("tcp", rmt.raddr) - if err != nil { - return false - } - fmt.Fprintf(c, "%d\n", actionExists) - sendKey(c, key) - var i int - fmt.Fscanf(c, "%d\n", &i) - return i == 1 -} - -func (rmt *remote) Remove(key string) error { - c, err := net.Dial("tcp", rmt.raddr) - if err != nil { - return err - } - fmt.Fprintf(c, "%d\n", actionRemove) - sendKey(c, key) - return nil -} - -func (rmt *remote) Clean() error { - c, err := net.Dial("tcp", rmt.raddr) - if err != nil { - return err - } - fmt.Fprintf(c, "%d\n", actionClean) - return nil -} diff --git a/libsq/core/ioz/fscache/stream.go b/libsq/core/ioz/fscache/stream.go deleted file mode 100644 index 9cccb2483..000000000 --- a/libsq/core/ioz/fscache/stream.go +++ /dev/null @@ -1,72 +0,0 @@ -package fscache - -import ( - "encoding/json" - "errors" - "io" -) - -type decoder interface { - Decode(interface{}) error -} - -type encoder interface { - Encode(interface{}) error -} - -type pktReader struct { - dec decoder -} - -type pktWriter struct { - enc encoder -} - -type packet struct { - Err int - Data []byte -} - -const eof = 1 - -func (t *pktReader) ReadAt(p []byte, off int64) (n int, err error) { - // TODO not implemented - return 0, errors.New("not implemented") -} - -func (t *pktReader) Read(p []byte) (int, error) { - var pkt packet - err := t.dec.Decode(&pkt) - if err != nil { - return 0, err - } - if pkt.Err == eof { - return 0, io.EOF - } - return copy(p, pkt.Data), nil -} - -func (t *pktReader) Close() error { - return nil -} - -func (t *pktWriter) Write(p []byte) (int, error) { - pkt := packet{Data: p} - err := t.enc.Encode(pkt) - if err != nil { - return 0, err - } - return len(p), nil -} - -func (t *pktWriter) Close() error { - return 
t.enc.Encode(packet{Err: eof}) -} - -func newEncoder(w io.Writer) io.WriteCloser { - return &pktWriter{enc: json.NewEncoder(w)} -} - -func newDecoder(r io.Reader) ReadAtCloser { - return &pktReader{dec: json.NewDecoder(r)} -} From a4ca0d691fb1169bdef03cbb8a2c1157ef5f311c Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 26 Nov 2023 10:33:18 -0700 Subject: [PATCH 008/195] wip: adding fscache --- libsq/core/ioz/fscache/LICENSE | 22 + libsq/core/ioz/fscache/README.md | 93 ++++ libsq/core/ioz/fscache/distrib.go | 85 ++++ libsq/core/ioz/fscache/example_test.go | 69 +++ libsq/core/ioz/fscache/fileinfo.go | 52 +++ libsq/core/ioz/fscache/fs.go | 266 ++++++++++++ libsq/core/ioz/fscache/fscache.go | 373 ++++++++++++++++ libsq/core/ioz/fscache/fscache_test.go | 579 +++++++++++++++++++++++++ libsq/core/ioz/fscache/handler.go | 41 ++ libsq/core/ioz/fscache/haunter.go | 92 ++++ libsq/core/ioz/fscache/layers.go | 128 ++++++ libsq/core/ioz/fscache/lruhaunter.go | 137 ++++++ libsq/core/ioz/fscache/memfs.go | 147 +++++++ libsq/core/ioz/fscache/reaper.go | 37 ++ libsq/core/ioz/fscache/server.go | 206 +++++++++ libsq/core/ioz/fscache/stream.go | 72 +++ 16 files changed, 2399 insertions(+) create mode 100644 libsq/core/ioz/fscache/LICENSE create mode 100644 libsq/core/ioz/fscache/README.md create mode 100644 libsq/core/ioz/fscache/distrib.go create mode 100644 libsq/core/ioz/fscache/example_test.go create mode 100644 libsq/core/ioz/fscache/fileinfo.go create mode 100644 libsq/core/ioz/fscache/fs.go create mode 100644 libsq/core/ioz/fscache/fscache.go create mode 100644 libsq/core/ioz/fscache/fscache_test.go create mode 100644 libsq/core/ioz/fscache/handler.go create mode 100644 libsq/core/ioz/fscache/haunter.go create mode 100644 libsq/core/ioz/fscache/layers.go create mode 100644 libsq/core/ioz/fscache/lruhaunter.go create mode 100644 libsq/core/ioz/fscache/memfs.go create mode 100644 libsq/core/ioz/fscache/reaper.go create mode 100644 libsq/core/ioz/fscache/server.go create mode 
100644 libsq/core/ioz/fscache/stream.go diff --git a/libsq/core/ioz/fscache/LICENSE b/libsq/core/ioz/fscache/LICENSE new file mode 100644 index 000000000..1e7b7cc09 --- /dev/null +++ b/libsq/core/ioz/fscache/LICENSE @@ -0,0 +1,22 @@ +The MIT License (MIT) + +Copyright (c) 2015 Dustin H + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ diff --git a/libsq/core/ioz/fscache/README.md b/libsq/core/ioz/fscache/README.md new file mode 100644 index 000000000..78b57ef35 --- /dev/null +++ b/libsq/core/ioz/fscache/README.md @@ -0,0 +1,93 @@ +fscache +========== + +[![GoDoc](https://godoc.org/github.com/djherbis/fscache?status.svg)](https://godoc.org/github.com/djherbis/fscache) +[![Release](https://img.shields.io/github/release/djherbis/fscache.svg)](https://github.com/djherbis/fscache/releases/latest) +[![Software License](https://img.shields.io/badge/license-MIT-brightgreen.svg)](LICENSE.txt) +[![go test](https://github.com/djherbis/fscache/actions/workflows/go-test.yml/badge.svg)](https://github.com/djherbis/fscache/actions/workflows/go-test.yml) +[![Coverage Status](https://coveralls.io/repos/djherbis/fscache/badge.svg?branch=master)](https://coveralls.io/r/djherbis/fscache?branch=master) +[![Go Report Card](https://goreportcard.com/badge/github.com/djherbis/fscache)](https://goreportcard.com/report/github.com/djherbis/fscache) + +Usage +------------ +Streaming File Cache for #golang + +fscache allows multiple readers to read from a cache while its being written to. [blog post](https://djherbis.github.io/post/fscache/) + +Using the Cache directly: + +```go +package main + +import ( + "io" + "log" + "os" + "time" + + "gopkg.in/djherbis/fscache.v0" +) + +func main() { + + // create the cache, keys expire after 1 hour. + c, err := fscache.New("./cache", 0755, time.Hour) + if err != nil { + log.Fatal(err.Error()) + } + + // wipe the cache when done + defer c.Clean() + + // Get() and it's streams can be called concurrently but just for example: + for i := 0; i < 3; i++ { + r, w, err := c.Get("stream") + if err != nil { + log.Fatal(err.Error()) + } + + if w != nil { // a new stream, write to it. 
+ go func(){ + w.Write([]byte("hello world\n")) + w.Close() + }() + } + + // the stream has started, read from it + io.Copy(os.Stdout, r) + r.Close() + } +} +``` + +A Caching Middle-ware: + +```go +package main + +import( + "net/http" + "time" + + "gopkg.in/djherbis/fscache.v0" +) + +func main(){ + c, err := fscache.New("./cache", 0700, 0) + if err != nil { + log.Fatal(err.Error()) + } + + handler := func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "%v: %s", time.Now(), "hello world") + } + + http.ListenAndServe(":8080", fscache.Handler(c, http.HandlerFunc(handler))) +} +``` + +Installation +------------ +```sh +go get gopkg.in/djherbis/fscache.v0 +``` diff --git a/libsq/core/ioz/fscache/distrib.go b/libsq/core/ioz/fscache/distrib.go new file mode 100644 index 000000000..60994cc58 --- /dev/null +++ b/libsq/core/ioz/fscache/distrib.go @@ -0,0 +1,85 @@ +package fscache + +import ( + "bytes" + "crypto/sha1" + "encoding/binary" + "io" +) + +// Distributor provides a way to partition keys into Caches. +type Distributor interface { + + // GetCache will always return the same Cache for the same key. + GetCache(key string) Cache + + // Clean should wipe all the caches this Distributor manages + Clean() error +} + +// stdDistribution distributes the keyspace evenly. +func stdDistribution(key string, n uint64) uint64 { + h := sha1.New() + io.WriteString(h, key) + buf := bytes.NewBuffer(h.Sum(nil)[:8]) + i, _ := binary.ReadUvarint(buf) + return i % n +} + +// NewDistributor returns a Distributor which evenly distributes the keyspace +// into the passed caches. 
+func NewDistributor(caches ...Cache) Distributor { + if len(caches) == 0 { + return nil + } + return &distrib{ + distribution: stdDistribution, + caches: caches, + size: uint64(len(caches)), + } +} + +type distrib struct { + distribution func(key string, n uint64) uint64 + caches []Cache + size uint64 +} + +func (d *distrib) GetCache(key string) Cache { + return d.caches[d.distribution(key, d.size)] +} + +// BUG(djherbis): Return an error if cleaning fails +func (d *distrib) Clean() error { + for _, c := range d.caches { + c.Clean() + } + return nil +} + +// NewPartition returns a Cache which uses the Caches defined by the passed Distributor. +func NewPartition(d Distributor) Cache { + return &partition{ + distributor: d, + } +} + +type partition struct { + distributor Distributor +} + +func (p *partition) Get(key string) (ReadAtCloser, io.WriteCloser, error) { + return p.distributor.GetCache(key).Get(key) +} + +func (p *partition) Remove(key string) error { + return p.distributor.GetCache(key).Remove(key) +} + +func (p *partition) Exists(key string) bool { + return p.distributor.GetCache(key).Exists(key) +} + +func (p *partition) Clean() error { + return p.distributor.Clean() +} diff --git a/libsq/core/ioz/fscache/example_test.go b/libsq/core/ioz/fscache/example_test.go new file mode 100644 index 000000000..5aa9e7266 --- /dev/null +++ b/libsq/core/ioz/fscache/example_test.go @@ -0,0 +1,69 @@ +package fscache + +import ( + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "os" + "time" +) + +func Example() { + // create the cache, keys expire after 1 hour. + c, err := New("./cache", 0755, time.Hour) + if err != nil { + log.Fatal(err.Error()) + } + + // wipe the cache when done + defer c.Clean() + + // Get() and it's streams can be called concurrently but just for example: + for i := 0; i < 3; i++ { + r, w, err := c.Get("stream") + if err != nil { + log.Fatal(err.Error()) + } + + if w != nil { // a new stream, write to it. 
+ go func() { + w.Write([]byte("hello world\n")) + w.Close() + }() + } + + // the stream has started, read from it + io.Copy(os.Stdout, r) + r.Close() + } + // Output: + // hello world + // hello world + // hello world +} + +func ExampleHandler() { + c, err := New("./server", 0700, 0) + if err != nil { + log.Fatal(err.Error()) + } + defer c.Clean() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello Client") + }) + + ts := httptest.NewServer(Handler(c, handler)) + defer ts.Close() + + resp, err := http.Get(ts.URL) + if err != nil { + log.Fatal(err.Error()) + } + io.Copy(os.Stdout, resp.Body) + resp.Body.Close() + // Output: + // Hello Client +} diff --git a/libsq/core/ioz/fscache/fileinfo.go b/libsq/core/ioz/fscache/fileinfo.go new file mode 100644 index 000000000..445fcfd50 --- /dev/null +++ b/libsq/core/ioz/fscache/fileinfo.go @@ -0,0 +1,52 @@ +package fscache + +import ( + "os" + "time" +) + +// FileInfo is just a wrapper around os.FileInfo which includes atime. +type FileInfo struct { + os.FileInfo + Atime time.Time +} + +type fileInfo struct { + name string + size int64 + fileMode os.FileMode + isDir bool + sys interface{} + wt time.Time +} + +func (f *fileInfo) Name() string { + return f.name +} + +func (f *fileInfo) Size() int64 { + return f.size +} + +func (f *fileInfo) Mode() os.FileMode { + return f.fileMode +} + +func (f *fileInfo) ModTime() time.Time { + return f.wt +} + +func (f *fileInfo) IsDir() bool { + return f.isDir +} + +func (f *fileInfo) Sys() interface{} { + return f.sys +} + +// AccessTime returns the last time the file was read. +// It will be used to check expiry of a file, and must be concurrent safe +// with modifications to the FileSystem (writes, reads etc.) 
+func (f *FileInfo) AccessTime() time.Time { + return f.Atime +} diff --git a/libsq/core/ioz/fscache/fs.go b/libsq/core/ioz/fscache/fs.go new file mode 100644 index 000000000..dad018382 --- /dev/null +++ b/libsq/core/ioz/fscache/fs.go @@ -0,0 +1,266 @@ +package fscache + +import ( + "bytes" + "crypto/md5" + "encoding/base64" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" + "time" + + "github.com/djherbis/atime" + "github.com/djherbis/stream" +) + +// FileSystemStater implementers can provide FileInfo data about a named resource. +type FileSystemStater interface { + // Stat takes a File.Name() and returns FileInfo interface + Stat(name string) (FileInfo, error) +} + +// FileSystem is used as the source for a Cache. +type FileSystem interface { + // Stream FileSystem + stream.FileSystem + + FileSystemStater + + // Reload should look through the FileSystem and call the supplied fn + // with the key/filename pairs that are found. + Reload(func(key, name string)) error + + // RemoveAll should empty the FileSystem of all files. + RemoveAll() error +} + +// StandardFS is an implemenation of FileSystem which writes to the os Filesystem. +type StandardFS struct { + root string + init func() error + + // EncodeKey takes a 'name' given to Create and converts it into a + // the Filename that should be used. It should return 'true' if + // DecodeKey can convert the returned string back to the original 'name' + // and false otherwise. + // This must be set before the first call to Create. + EncodeKey func(string) (string, bool) + + // DecodeKey should convert a given Filename into the original 'name' given to + // EncodeKey, and return true if this conversion was possible. Returning false + // will cause it to try and lookup a stored 'encodedName.key' file which holds + // the original name. 
+ DecodeKey func(string) (string, bool) +} + +// IdentityCodeKey works as both an EncodeKey and a DecodeKey func, which just returns +// it's given argument and true. This is expected to be used when your FSCache +// uses SetKeyMapper to ensure its internal km(key) value is already a valid filename path. +func IdentityCodeKey(key string) (string, bool) { return key, true } + +// NewFs returns a FileSystem rooted at directory dir. +// Dir is created with perms if it doesn't exist. +// This also uses the default EncodeKey/DecodeKey functions B64ORMD5HashEncodeKey/B64DecodeKey. +func NewFs(dir string, mode os.FileMode) (*StandardFS, error) { + fs := &StandardFS{ + root: dir, + init: func() error { + return os.MkdirAll(dir, mode) + }, + EncodeKey: B64OrMD5HashEncodeKey, + DecodeKey: B64DecodeKey, + } + return fs, fs.init() +} + +// Reload looks through the dir given to NewFs and returns every key, name pair (Create(key) => name = File.Name()) +// that is managed by this FileSystem. +func (fs *StandardFS) Reload(add func(key, name string)) error { + files, err := ioutil.ReadDir(fs.root) + if err != nil { + return err + } + + addfiles := make(map[string]struct { + os.FileInfo + key string + }) + + for _, f := range files { + + if strings.HasSuffix(f.Name(), ".key") { + continue + } + + key, err := fs.getKey(f.Name()) + if err != nil { + fs.Remove(filepath.Join(fs.root, f.Name())) + continue + } + fi, ok := addfiles[key] + + if !ok || fi.ModTime().Before(f.ModTime()) { + if ok { + fs.Remove(fi.Name()) + } + addfiles[key] = struct { + os.FileInfo + key string + }{ + FileInfo: f, + key: key, + } + } else { + fs.Remove(f.Name()) + } + + } + + for _, f := range addfiles { + path, err := filepath.Abs(filepath.Join(fs.root, f.Name())) + if err != nil { + return err + } + add(f.key, path) + } + + return nil +} + +// Create creates a File for the given 'name', it may not use the given name on the +// os filesystem, that depends on the implementation of EncodeKey used. 
+func (fs *StandardFS) Create(name string) (stream.File, error) { + name, err := fs.makeName(name) + if err != nil { + return nil, err + } + return fs.create(name) +} + +func (fs *StandardFS) create(name string) (stream.File, error) { + return os.OpenFile(filepath.Join(fs.root, name), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) +} + +// Open opens a stream.File for the given File.Name() returned by Create(). +func (fs *StandardFS) Open(name string) (stream.File, error) { + return os.Open(name) +} + +// Remove removes a stream.File for the given File.Name() returned by Create(). +func (fs *StandardFS) Remove(name string) error { + os.Remove(fmt.Sprintf("%s.key", name)) + return os.Remove(name) +} + +// RemoveAll deletes all files in the directory managed by this StandardFS. +// Warning that if you put files in this directory that were not created by +// StandardFS they will also be deleted. +func (fs *StandardFS) RemoveAll() error { + if err := os.RemoveAll(fs.root); err != nil { + return err + } + return fs.init() +} + +// AccessTimes returns atime and mtime for the given File.Name() returned by Create(). +func (fs *StandardFS) AccessTimes(name string) (rt, wt time.Time, err error) { + fi, err := os.Stat(name) + if err != nil { + return rt, wt, err + } + return atime.Get(fi), fi.ModTime(), nil +} + +// Stat returns FileInfo for the given File.Name() returned by Create(). +func (fs *StandardFS) Stat(name string) (FileInfo, error) { + stat, err := os.Stat(name) + if err != nil { + return FileInfo{}, err + } + + return FileInfo{FileInfo: stat, Atime: atime.Get(stat)}, nil +} + +const ( + saltSize = 8 + salt = "xxxxxxxx" // this is only important for sizing now. 
+ maxShort = 20 + shortPrefix = "s" + longPrefix = "l" +) + +func tob64(s string) string { + buf := bytes.NewBufferString("") + enc := base64.NewEncoder(base64.URLEncoding, buf) + enc.Write([]byte(s)) + enc.Close() + return buf.String() +} + +func fromb64(s string) string { + buf := bytes.NewBufferString(s) + dec := base64.NewDecoder(base64.URLEncoding, buf) + out := bytes.NewBufferString("") + io.Copy(out, dec) + return out.String() +} + +// B64OrMD5HashEncodeKey converts a given key into a filesystem name-safe string +// and returns true iff it can be reversed with B64DecodeKey. +func B64OrMD5HashEncodeKey(key string) (string, bool) { + b64key := tob64(key) + // short name + if len(b64key) < maxShort { + return fmt.Sprintf("%s%s%s", shortPrefix, salt, b64key), true + } + + // long name + hash := md5.Sum([]byte(key)) + return fmt.Sprintf("%s%s%x", longPrefix, salt, hash[:]), false +} + +func (fs *StandardFS) makeName(key string) (string, error) { + name, decodable := fs.EncodeKey(key) + if decodable { + return name, nil + } + + // Name is not decodeable, store it. + f, err := fs.create(fmt.Sprintf("%s.key", name)) + if err != nil { + return "", err + } + _, err = f.Write([]byte(key)) + f.Close() + return name, err +} + +// B64DecodeKey converts a string y into x st. y, ok = B64OrMD5HashEncodeKey(x), and ok = true. +// Basically it should reverse B64OrMD5HashEncodeKey if B64OrMD5HashEncodeKey returned true. 
+func B64DecodeKey(name string) (string, bool) { + if strings.HasPrefix(name, shortPrefix) { + return fromb64(strings.TrimPrefix(name, shortPrefix)[saltSize:]), true + } + return "", false +} + +func (fs *StandardFS) getKey(name string) (string, error) { + if key, ok := fs.DecodeKey(name); ok { + return key, nil + } + + // long name + f, err := fs.Open(filepath.Join(fs.root, fmt.Sprintf("%s.key", name))) + if err != nil { + return "", err + } + defer f.Close() + key, err := ioutil.ReadAll(f) + if err != nil { + return "", err + } + return string(key), nil +} diff --git a/libsq/core/ioz/fscache/fscache.go b/libsq/core/ioz/fscache/fscache.go new file mode 100644 index 000000000..6de40a3b8 --- /dev/null +++ b/libsq/core/ioz/fscache/fscache.go @@ -0,0 +1,373 @@ +package fscache + +import ( + "fmt" + "io" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/djherbis/stream" +) + +// Cache works like a concurrent-safe map for streams. +type Cache interface { + // Get manages access to the streams in the cache. + // If the key does not exist, w != nil and you can start writing to the stream. + // If the key does exist, w == nil. + // r will always be non-nil as long as err == nil and you must close r when you're done reading. + // Get can be called concurrently, and writing and reading is concurrent safe. + Get(key string) (ReadAtCloser, io.WriteCloser, error) + + // Remove deletes the stream from the cache, blocking until the underlying + // file can be deleted (all active streams finish with it). + // It is safe to call Remove concurrently with Get. + Remove(key string) error + + // Exists checks if a key is in the cache. + // It is safe to call Exists concurrently with Get. + Exists(key string) bool + + // Clean will empty the cache and delete the cache folder. + // Clean is not safe to call while streams are being read/written. + Clean() error +} + +// FSCache is a Cache which uses a Filesystem to read/write cached data. 
+type FSCache struct { + mu sync.RWMutex + files map[string]fileStream + km func(string) string + fs FileSystem + haunter Haunter +} + +// SetKeyMapper will use the given function to transform any given Cache key into the result of km(key). +// This means that internally, the cache will only track km(key), and forget the original key. The consequences +// of this are that Enumerate will return km(key) instead of key, and Filesystem will give km(key) to Create +// and expect Reload() to return km(key). +// The purpose of this function is so that the internally managed key can be converted to a string that is +// allowed as a filesystem path. +func (c *FSCache) SetKeyMapper(km func(string) string) *FSCache { + c.mu.Lock() + defer c.mu.Unlock() + c.km = km + return c +} + +func (c *FSCache) mapKey(key string) string { + if c.km == nil { + return key + } + return c.km(key) +} + +// ReadAtCloser is an io.ReadCloser, and an io.ReaderAt. It supports both so that Range +// Requests are possible. +type ReadAtCloser interface { + io.ReadCloser + io.ReaderAt +} + +type fileStream interface { + next() (*CacheReader, error) + InUse() bool + io.WriteCloser + remove() error + Name() string +} + +// New creates a new Cache using NewFs(dir, perms). +// expiry is the duration after which an un-accessed key will be removed from +// the cache, a zero value expiro means never expire. +func New(dir string, perms os.FileMode, expiry time.Duration) (*FSCache, error) { + fs, err := NewFs(dir, perms) + if err != nil { + return nil, err + } + var grim Reaper + if expiry > 0 { + grim = &reaper{ + expiry: expiry, + period: expiry, + } + } + return NewCache(fs, grim) +} + +// NewCache creates a new Cache based on FileSystem fs. +// fs.Files() are loaded using the name they were created with as a key. +// Reaper is used to determine when files expire, nil means never expire. 
+func NewCache(fs FileSystem, grim Reaper) (*FSCache, error) { + if grim != nil { + return NewCacheWithHaunter(fs, NewReaperHaunterStrategy(grim)) + } + + return NewCacheWithHaunter(fs, nil) +} + +// NewCacheWithHaunter create a new Cache based on FileSystem fs. +// fs.Files() are loaded using the name they were created with as a key. +// Haunter is used to determine when files expire, nil means never expire. +func NewCacheWithHaunter(fs FileSystem, haunter Haunter) (*FSCache, error) { + c := &FSCache{ + files: make(map[string]fileStream), + haunter: haunter, + fs: fs, + } + err := c.load() + if err != nil { + return nil, err + } + if haunter != nil { + c.scheduleHaunt() + } + + return c, nil +} + +func (c *FSCache) scheduleHaunt() { + c.haunt() + time.AfterFunc(c.haunter.Next(), c.scheduleHaunt) +} + +func (c *FSCache) haunt() { + c.mu.Lock() + defer c.mu.Unlock() + + c.haunter.Haunt(&accessor{c: c}) +} + +func (c *FSCache) load() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.fs.Reload(func(key, name string) { + c.files[key] = c.oldFile(name) + }) +} + +// Exists returns true iff this key is in the Cache (may not be finished streaming). +func (c *FSCache) Exists(key string) bool { + c.mu.RLock() + defer c.mu.RUnlock() + _, ok := c.files[c.mapKey(key)] + return ok +} + +// Get obtains a ReadAtCloser for the given key, and may return a WriteCloser to write the original cache data +// if this is a cache-miss. 
+func (c *FSCache) Get(key string) (r ReadAtCloser, w io.WriteCloser, err error) { + c.mu.RLock() + key = c.mapKey(key) + f, ok := c.files[key] + if ok { + r, err = f.next() + c.mu.RUnlock() + return r, nil, err + } + c.mu.RUnlock() + + c.mu.Lock() + defer c.mu.Unlock() + + f, ok = c.files[key] + if ok { + r, err = f.next() + return r, nil, err + } + + f, err = c.newFile(key) + if err != nil { + return nil, nil, err + } + + r, err = f.next() + if err != nil { + f.Close() + c.fs.Remove(f.Name()) + return nil, nil, err + } + + c.files[key] = f + + return r, f, err +} + +// Remove removes the specified key from the cache. +func (c *FSCache) Remove(key string) error { + c.mu.Lock() + key = c.mapKey(key) + f, ok := c.files[key] + delete(c.files, key) + c.mu.Unlock() + + if ok { + return f.remove() + } + return nil +} + +// Clean resets the cache removing all keys and data. +func (c *FSCache) Clean() error { + c.mu.Lock() + defer c.mu.Unlock() + c.files = make(map[string]fileStream) + return c.fs.RemoveAll() +} + +type accessor struct { + c *FSCache +} + +func (a *accessor) Stat(name string) (FileInfo, error) { + return a.c.fs.Stat(name) +} + +func (a *accessor) EnumerateEntries(enumerator func(key string, e Entry) bool) { + for k, f := range a.c.files { + if !enumerator(k, Entry{name: f.Name(), inUse: f.InUse()}) { + break + } + } +} + +func (a *accessor) RemoveFile(key string) { + key = a.c.mapKey(key) + f, ok := a.c.files[key] + delete(a.c.files, key) + if ok { + a.c.fs.Remove(f.Name()) + } +} + +type cachedFile struct { + handleCounter + stream *stream.Stream +} + +func (c *FSCache) newFile(name string) (fileStream, error) { + s, err := stream.NewStream(name, c.fs) + if err != nil { + return nil, err + } + cf := &cachedFile{ + stream: s, + } + cf.inc() + return cf, nil +} + +func (c *FSCache) oldFile(name string) fileStream { + return &reloadedFile{ + fs: c.fs, + name: name, + } +} + +type reloadedFile struct { + handleCounter + fs FileSystem + name string + 
io.WriteCloser // nop Write & Close methods. will never be called. +} + +func (f *reloadedFile) Name() string { return f.name } + +func (f *reloadedFile) remove() error { + f.waitUntilFree() + return f.fs.Remove(f.name) +} + +func (f *reloadedFile) next() (*CacheReader, error) { + r, err := f.fs.Open(f.name) + if err == nil { + f.inc() + } + return &CacheReader{ + ReadAtCloser: r, + cnt: &f.handleCounter, + }, err +} + +func (f *cachedFile) Name() string { return f.stream.Name() } + +func (f *cachedFile) remove() error { return f.stream.Remove() } + +func (f *cachedFile) next() (*CacheReader, error) { + reader, err := f.stream.NextReader() + if err != nil { + return nil, err + } + f.inc() + return &CacheReader{ + ReadAtCloser: reader, + cnt: &f.handleCounter, + }, nil +} + +func (f *cachedFile) Write(p []byte) (int, error) { + return f.stream.Write(p) +} + +func (f *cachedFile) Close() error { + defer f.dec() + return f.stream.Close() +} + +// CacheReader is a ReadAtCloser for a Cache key that also tracks open readers. +type CacheReader struct { + ReadAtCloser + cnt *handleCounter +} + +// Close frees the underlying ReadAtCloser and updates the open reader counter. +func (r *CacheReader) Close() error { + defer r.cnt.dec() + return r.ReadAtCloser.Close() +} + +// Size returns the current size of the stream being read, the boolean it +// returns is true iff the stream is done being written (otherwise Size may change). +// An error is returned if the Size fails to be computed or is not supported +// by the underlying filesystem. 
+func (r *CacheReader) Size() (int64, bool, error) { + switch v := r.ReadAtCloser.(type) { + case *stream.Reader: + size, done := v.Size() + return size, done, nil + + case interface{ Stat() (os.FileInfo, error) }: + fi, err := v.Stat() + if err != nil { + return 0, false, err + } + return fi.Size(), true, nil + + default: + return 0, false, fmt.Errorf("reader does not support stat") + } +} + +type handleCounter struct { + cnt int64 + grp sync.WaitGroup +} + +func (h *handleCounter) inc() { + h.grp.Add(1) + atomic.AddInt64(&h.cnt, 1) +} + +func (h *handleCounter) dec() { + atomic.AddInt64(&h.cnt, -1) + h.grp.Done() +} + +func (h *handleCounter) InUse() bool { + return atomic.LoadInt64(&h.cnt) > 0 +} + +func (h *handleCounter) waitUntilFree() { + h.grp.Wait() +} diff --git a/libsq/core/ioz/fscache/fscache_test.go b/libsq/core/ioz/fscache/fscache_test.go new file mode 100644 index 000000000..de125299a --- /dev/null +++ b/libsq/core/ioz/fscache/fscache_test.go @@ -0,0 +1,579 @@ +package fscache + +import ( + "bytes" + "crypto/md5" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" +) + +func createFile(name string) (*os.File, error) { + return os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) +} + +func init() { + c, _ := NewCache(NewMemFs(), nil) + go ListenAndServe(c, "localhost:10000") +} + +func testCaches(t *testing.T, run func(c Cache)) { + c, err := New("./cache", 0700, 1*time.Hour) + if err != nil { + t.Error(err.Error()) + return + } + run(c) + + c, err = NewCache(NewMemFs(), NewReaper(time.Hour, time.Hour)) + if err != nil { + t.Error(err.Error()) + return + } + run(c) + + c2, _ := NewCache(NewMemFs(), nil) + run(NewPartition(NewDistributor(c, c2))) + + lc := NewLayered(c, c2) + run(lc) + + rc := NewRemote("localhost:10000") + run(rc) + + fs, _ := NewFs("./cachex", 0700) + fs.EncodeKey = IdentityCodeKey + fs.DecodeKey = IdentityCodeKey + ck, _ := NewCache(fs, NewReaper(time.Hour, time.Hour)) 
+ ck.SetKeyMapper(func(key string) string { + name, _ := B64OrMD5HashEncodeKey(key) + return name + }) + run(ck) +} + +func TestHandler(t *testing.T) { + testCaches(t, func(c Cache) { + defer c.Clean() + ts := httptest.NewServer(Handler(c, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello Client") + }))) + defer ts.Close() + + for i := 0; i < 3; i++ { + res, err := http.Get(ts.URL) + if err != nil { + t.Error(err.Error()) + t.FailNow() + } + p, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if !bytes.Equal([]byte("Hello Client\n"), p) { + t.Errorf("unexpected response %s", string(p)) + } + } + }) +} + +func TestMemFs(t *testing.T) { + fs := NewMemFs() + fs.Reload(func(key, name string) {}) // nop + if _, err := fs.Open("test"); err == nil { + t.Errorf("stream shouldn't exist") + } + fs.Remove("test") + + f, err := fs.Create("test") + if err != nil { + t.Errorf("failed to create test") + } + f.Write([]byte("hello")) + f.Close() + + r, err := fs.Open("test") + if err != nil { + t.Errorf("failed Open: %v", err) + } + p, err := ioutil.ReadAll(r) + if err != nil { + t.Errorf("failed ioutil.ReadAll: %v", err) + } + r.Close() + if !bytes.Equal(p, []byte("hello")) { + t.Errorf("expected hello, got %s", string(p)) + } + fs.RemoveAll() +} + +func TestLoadCleanup1(t *testing.T) { + os.Mkdir("./cache6", 0700) + f, err := createFile(filepath.Join("./cache6", "s11111111"+tob64("test"))) + if err != nil { + t.Error(err.Error()) + } + f.Close() + <-time.After(time.Second) + f, err = createFile(filepath.Join("./cache6", "s22222222"+tob64("test"))) + if err != nil { + t.Error(err.Error()) + } + f.Close() + + c, err := New("./cache6", 0700, 0) + if err != nil { + t.Error(err.Error()) + return + } + defer c.Clean() + + if !c.Exists("test") { + t.Errorf("expected test to exist") + } +} + +const longString = ` + 0123456789 0123456789 + 0123456789 0123456789 + 0123456789 0123456789 + 0123456789 0123456789 + 
0123456789 0123456789 + 0123456789 0123456789 + 0123456789 0123456789 + 0123456789 0123456789 + 0123456789 0123456789 + 0123456789 0123456789 +` + +func TestLoadCleanup2(t *testing.T) { + hash := md5.Sum([]byte(longString)) + name2 := fmt.Sprintf("%s%s%x", longPrefix, "22222222", hash[:]) + name1 := fmt.Sprintf("%s%s%x", longPrefix, "11111111", hash[:]) + + os.Mkdir("./cache7", 0700) + f, err := createFile(filepath.Join("./cache7", name2)) + if err != nil { + t.Error(err.Error()) + } + f.Close() + f, err = createFile(filepath.Join("./cache7", fmt.Sprintf("%s.key", name2))) + if err != nil { + t.Error(err.Error()) + } + f.Write([]byte(longString)) + f.Close() + <-time.After(time.Second) + f, err = createFile(filepath.Join("./cache7", name1)) + if err != nil { + t.Error(err.Error()) + } + f.Close() + f, err = createFile(filepath.Join("./cache7", fmt.Sprintf("%s.key", name1))) + if err != nil { + t.Error(err.Error()) + } + f.Write([]byte(longString)) + f.Close() + + c, err := New("./cache7", 0700, 0) + if err != nil { + t.Error(err.Error()) + return + } + defer c.Clean() + + if !c.Exists(longString) { + t.Errorf("expected test to exist") + } +} + +func TestReload(t *testing.T) { + dir, err := ioutil.TempDir("", "cache5") + if err != nil { + t.Fatalf("Failed to create TempDir: %v", err) + } + c, err := New(dir, 0700, 0) + if err != nil { + t.Error(err.Error()) + return + } + r, w, err := c.Get("stream") + if err != nil { + t.Error(err.Error()) + return + } + r.Close() + data := []byte("hello world\n") + w.Write(data) + w.Close() + + nc, err := New(dir, 0700, 0) + if err != nil { + t.Error(err.Error()) + return + } + defer nc.Clean() + + if !nc.Exists("stream") { + t.Fatalf("expected stream to be reloaded") + } + + r, w, err = nc.Get("stream") + if err != nil { + t.Fatal(err) + } + if w != nil { + t.Fatal("expected reloaded stream to not be writable") + } + + cr, ok := r.(*CacheReader) + if !ok { + t.Fatalf("CacheReader should be supported by a normal FS") + } + size, 
closed, err := cr.Size() + if err != nil { + t.Fatalf("Failed to get Size: %v", err) + } + if !closed { + t.Errorf("Expected stream to be closed.") + } + if size != int64(len(data)) { + t.Errorf("Expected size to be %v, but got %v", len(data), size) + } + + r.Close() + nc.Remove("stream") + if nc.Exists("stream") { + t.Errorf("expected stream to be removed") + } +} + +func TestLRUHaunterMaxItems(t *testing.T) { + + fs, err := NewFs("./cache1", 0700) + if err != nil { + t.Error(err.Error()) + t.FailNow() + } + + c, err := NewCacheWithHaunter(fs, NewLRUHaunterStrategy(NewLRUHaunter(3, 0, 400*time.Millisecond))) + + if err != nil { + t.Error(err.Error()) + return + } + defer c.Clean() + + for i := 0; i < 5; i++ { + name := fmt.Sprintf("stream-%v", i) + r, w, _ := c.Get(name) + w.Write([]byte("hello")) + w.Close() + io.Copy(ioutil.Discard, r) + + if !c.Exists(name) { + t.Errorf(name + " should exist") + } + + <-time.After(10 * time.Millisecond) + + err := r.Close() + if err != nil { + t.Error(err) + } + } + + <-time.After(400 * time.Millisecond) + + if c.Exists("stream-0") { + t.Errorf("stream-0 should have been scrubbed") + } + + if c.Exists("stream-1") { + t.Errorf("stream-1 should have been scrubbed") + } + + files, err := ioutil.ReadDir("./cache1") + if err != nil { + t.Error(err.Error()) + return + } + + if len(files) != 3 { + t.Errorf("expected 3 items in directory") + } +} + +func TestLRUHaunterMaxSize(t *testing.T) { + + fs, err := NewFs("./cache1", 0700) + if err != nil { + t.Error(err.Error()) + t.FailNow() + } + + c, err := NewCacheWithHaunter(fs, NewLRUHaunterStrategy(NewLRUHaunter(0, 24, 400*time.Millisecond))) + + if err != nil { + t.Error(err.Error()) + return + } + defer c.Clean() + + for i := 0; i < 5; i++ { + name := fmt.Sprintf("stream-%v", i) + r, w, _ := c.Get(name) + w.Write([]byte("hello")) + w.Close() + io.Copy(ioutil.Discard, r) + + if !c.Exists(name) { + t.Errorf(name + " should exist") + } + + <-time.After(10 * time.Millisecond) + + err := 
r.Close() + if err != nil { + t.Error(err) + } + } + + <-time.After(400 * time.Millisecond) + + if c.Exists("stream-0") { + t.Errorf("stream-0 should have been scrubbed") + } + + files, err := ioutil.ReadDir("./cache1") + if err != nil { + t.Error(err.Error()) + return + } + + if len(files) != 4 { + t.Errorf("expected 4 items in directory") + } +} + +func TestReaper(t *testing.T) { + fs, err := NewFs("./cache1", 0700) + if err != nil { + t.Error(err.Error()) + t.FailNow() + } + + c, err := NewCache(fs, NewReaper(0*time.Second, 100*time.Millisecond)) + if err != nil { + t.Fatal(err) + } + defer c.Clean() + + r, w, err := c.Get("stream") + if err != nil { + t.Fatal(err) + } + w.Write([]byte("hello")) + w.Close() + io.Copy(ioutil.Discard, r) + + if !c.Exists("stream") { + t.Errorf("stream should exist") + } + + <-time.After(200 * time.Millisecond) + + if !c.Exists("stream") { + t.Errorf("a file expired while in use, fail!") + } + r.Close() + + <-time.After(200 * time.Millisecond) + + if c.Exists("stream") { + t.Errorf("stream should have been reaped") + } + + files, err := ioutil.ReadDir("./cache1") + if err != nil { + t.Error(err.Error()) + return + } + + if len(files) > 0 { + t.Errorf("expected empty directory") + } +} + +func TestReaperNoExpire(t *testing.T) { + testCaches(t, func(c Cache) { + defer c.Clean() + r, w, err := c.Get("stream") + if err != nil { + t.Error(err.Error()) + t.FailNow() + } + w.Write([]byte("hello")) + w.Close() + io.Copy(ioutil.Discard, r) + r.Close() + + if !c.Exists("stream") { + t.Errorf("stream should exist") + } + + if lc, ok := c.(*FSCache); ok { + lc.haunt() + if !c.Exists("stream") { + t.Errorf("stream shouldn't have been reaped") + } + } + }) +} + +func TestSanity(t *testing.T) { + atLeastOneCacheReader := false + testCaches(t, func(c Cache) { + defer c.Clean() + + r, w, err := c.Get(longString) + if err != nil { + t.Error(err.Error()) + return + } + defer r.Close() + + want := []byte("hello world\n") + first := want[:5] + 
w.Write(first) + + cr, ok := r.(*CacheReader) + if ok { + atLeastOneCacheReader = true + size, closed, _ := cr.Size() + if closed { + t.Errorf("Expected stream to be open.") + } + if size != int64(len(first)) { + t.Errorf("Expected size to be %v, but got %v", len(first), size) + } + } + + second := want[5:] + w.Write(second) + + if ok { + atLeastOneCacheReader = true + size, closed, _ := cr.Size() + if closed { + t.Errorf("Expected stream to be open.") + } + if size != int64(len(want)) { + t.Errorf("Expected size to be %v, but got %v", len(want), size) + } + } + + w.Close() + + if ok { + atLeastOneCacheReader = true + size, closed, _ := cr.Size() + if !closed { + t.Errorf("Expected stream to be closed.") + } + if size != int64(len(want)) { + t.Errorf("Expected size to be %v, but got %v", len(want), size) + } + } + + buf := bytes.NewBuffer(nil) + _, err = io.Copy(buf, r) + if err != nil { + t.Error(err.Error()) + return + } + if !bytes.Equal(buf.Bytes(), want) { + t.Errorf("unexpected output %s", buf.Bytes()) + } + }) + if !atLeastOneCacheReader { + t.Errorf("None of the cache tests covered CacheReader!") + } +} + +func TestConcurrent(t *testing.T) { + testCaches(t, func(c Cache) { + defer c.Clean() + + r, w, err := c.Get("stream") + r.Close() + if err != nil { + t.Error(err.Error()) + return + } + go func() { + w.Write([]byte("hello")) + <-time.After(100 * time.Millisecond) + w.Write([]byte("world")) + w.Close() + }() + + if c.Exists("stream") { + r, _, err := c.Get("stream") + if err != nil { + t.Error(err.Error()) + return + } + buf := bytes.NewBuffer(nil) + io.Copy(buf, r) + r.Close() + if !bytes.Equal(buf.Bytes(), []byte("helloworld")) { + t.Errorf("unexpected output %s", buf.Bytes()) + } + } + }) +} + +func TestReuse(t *testing.T) { + testCaches(t, func(c Cache) { + for i := 0; i < 10; i++ { + r, w, err := c.Get(longString) + if err != nil { + t.Error(err.Error()) + return + } + + data := fmt.Sprintf("hello %d", i) + + if w != nil { + w.Write([]byte(data)) + 
w.Close() + } + + check(t, r, data) + r.Close() + + c.Clean() + } + }) +} + +func check(t *testing.T, r io.Reader, data string) { + buf := bytes.NewBuffer(nil) + _, err := io.Copy(buf, r) + if err != nil { + t.Error(err.Error()) + return + } + if !bytes.Equal(buf.Bytes(), []byte(data)) { + t.Errorf("unexpected output %q, want %q", buf.String(), data) + } +} diff --git a/libsq/core/ioz/fscache/handler.go b/libsq/core/ioz/fscache/handler.go new file mode 100644 index 000000000..8df85400c --- /dev/null +++ b/libsq/core/ioz/fscache/handler.go @@ -0,0 +1,41 @@ +package fscache + +import ( + "io" + "net/http" +) + +// Handler is a caching middle-ware for http Handlers. +// It responds to http requests via the passed http.Handler, and caches the response +// using the passed cache. The cache key for the request is the req.URL.String(). +// Note: It does not cache http headers. It is more efficient to set them yourself. +func Handler(c Cache, h http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + url := req.URL.String() + r, w, err := c.Get(url) + if err != nil { + h.ServeHTTP(rw, req) + return + } + defer r.Close() + if w != nil { + go func() { + defer w.Close() + h.ServeHTTP(&respWrapper{ + ResponseWriter: rw, + Writer: w, + }, req) + }() + } + io.Copy(rw, r) + }) +} + +type respWrapper struct { + http.ResponseWriter + io.Writer +} + +func (r *respWrapper) Write(p []byte) (int, error) { + return r.Writer.Write(p) +} diff --git a/libsq/core/ioz/fscache/haunter.go b/libsq/core/ioz/fscache/haunter.go new file mode 100644 index 000000000..a8d038ce9 --- /dev/null +++ b/libsq/core/ioz/fscache/haunter.go @@ -0,0 +1,92 @@ +package fscache + +import ( + "time" +) + +// Entry represents a cached item. +type Entry struct { + name string + inUse bool +} + +// InUse returns if this Cache entry is in use. +func (e *Entry) InUse() bool { + return e.inUse +} + +// Name returns the File.Name() of this entry. 
+func (e *Entry) Name() string { + return e.name +} + +// CacheAccessor implementors provide ways to observe and interact with +// the cached entries, mainly used for cache-eviction. +type CacheAccessor interface { + FileSystemStater + EnumerateEntries(enumerator func(key string, e Entry) bool) + RemoveFile(key string) +} + +// Haunter implementors are used to perform cache-eviction (Next is how long to wait +// until next evication, Haunt preforms the eviction). +type Haunter interface { + Haunt(c CacheAccessor) + Next() time.Duration +} + +type reaperHaunterStrategy struct { + reaper Reaper +} + +type lruHaunterStrategy struct { + haunter LRUHaunter +} + +// NewLRUHaunterStrategy returns a simple scheduleHaunt which provides an implementation LRUHaunter strategy +func NewLRUHaunterStrategy(haunter LRUHaunter) Haunter { + return &lruHaunterStrategy{ + haunter: haunter, + } +} + +func (h *lruHaunterStrategy) Haunt(c CacheAccessor) { + for _, key := range h.haunter.Scrub(c) { + c.RemoveFile(key) + } + +} + +func (h *lruHaunterStrategy) Next() time.Duration { + return h.haunter.Next() +} + +// NewReaperHaunterStrategy returns a simple scheduleHaunt which provides an implementation Reaper strategy +func NewReaperHaunterStrategy(reaper Reaper) Haunter { + return &reaperHaunterStrategy{ + reaper: reaper, + } +} + +func (h *reaperHaunterStrategy) Haunt(c CacheAccessor) { + c.EnumerateEntries(func(key string, e Entry) bool { + if e.InUse() { + return true + } + + fileInfo, err := c.Stat(e.Name()) + if err != nil { + return true + } + + if h.reaper.Reap(key, fileInfo.AccessTime(), fileInfo.ModTime()) { + c.RemoveFile(key) + } + + return true + }) +} + +func (h *reaperHaunterStrategy) Next() time.Duration { + return h.reaper.Next() +} diff --git a/libsq/core/ioz/fscache/layers.go b/libsq/core/ioz/fscache/layers.go new file mode 100644 index 000000000..b0b283106 --- /dev/null +++ b/libsq/core/ioz/fscache/layers.go @@ -0,0 +1,128 @@ +package fscache + +import ( + "errors" + 
"io" + "sync" +) + +type layeredCache struct { + layers []Cache +} + +// NewLayered returns a Cache which stores its data in all the passed +// caches, when a key is requested it is loaded into all the caches above the first hit. +func NewLayered(caches ...Cache) Cache { + return &layeredCache{layers: caches} +} + +func (l *layeredCache) Get(key string) (r ReadAtCloser, w io.WriteCloser, err error) { + var last ReadAtCloser + var writers []io.WriteCloser + + for i, layer := range l.layers { + r, w, err = layer.Get(key) + if err != nil { + if len(writers) > 0 { + last.Close() + multiWC(writers...).Close() + } + return nil, nil, err + } + + // hit + if w == nil { + if len(writers) > 0 { + go func(r io.ReadCloser) { + wc := multiWC(writers...) + defer r.Close() + defer wc.Close() + io.Copy(wc, r) + }(r) + return last, nil, nil + } + return r, nil, nil + } + + // miss + writers = append(writers, w) + + if i == len(l.layers)-1 { + if last != nil { + last.Close() + } + return r, multiWC(writers...), nil + } + + if last != nil { + last.Close() + } + last = r + } + + return nil, nil, errors.New("no caches") +} + +func (l *layeredCache) Remove(key string) error { + var grp sync.WaitGroup + // walk upwards so that lower layers don't + // restore upper layers on Get() + for i := len(l.layers) - 1; i >= 0; i-- { + grp.Add(1) + go func(layer Cache) { + defer grp.Done() + layer.Remove(key) + }(l.layers[i]) + } + grp.Wait() + return nil +} + +func (l *layeredCache) Exists(key string) bool { + for _, layer := range l.layers { + if layer.Exists(key) { + return true + } + } + return false +} + +func (l *layeredCache) Clean() (error) { + for _, layer := range l.layers { + if err := layer.Clean(); err != nil { + return err + } + } + return nil +} + +func multiWC(wc ...io.WriteCloser) io.WriteCloser { + if len(wc) == 0 { + return nil + } + + return &multiWriteCloser{ + writers: wc, + } +} + +type multiWriteCloser struct { + writers []io.WriteCloser +} + +func (t *multiWriteCloser) 
Write(p []byte) (n int, err error) { + for _, w := range t.writers { + n, err = w.Write(p) + if err != nil { + return + } + } + return len(p), nil +} + +func (t *multiWriteCloser) Close() error { + for _, w := range t.writers { + w.Close() + } + return nil +} diff --git a/libsq/core/ioz/fscache/lruhaunter.go b/libsq/core/ioz/fscache/lruhaunter.go new file mode 100644 index 000000000..7b90ef3a7 --- /dev/null +++ b/libsq/core/ioz/fscache/lruhaunter.go @@ -0,0 +1,137 @@ +package fscache + +import ( + "sort" + "time" +) + +type lruHaunterKV struct { + Key string + Value Entry +} + +// LRUHaunter is used to control when there are too many streams +// or the size of the streams is too big. +// It is called once right after loading, and then it is run +// again after every Next() period of time. +type LRUHaunter interface { + // Returns the amount of time to wait before the next scheduled Reaping. + Next() time.Duration + + // Given a CacheAccessor, return keys to reap list. + Scrub(c CacheAccessor) []string +} + +// NewLRUHaunter returns a simple haunter which runs every "period" +// and scrubs older files when the total file size is over maxSize or +// total item count is over maxItems. 
+// If maxItems or maxSize are 0, they won't be checked +func NewLRUHaunter(maxItems int, maxSize int64, period time.Duration) LRUHaunter { + return &lruHaunter{ + period: period, + maxItems: maxItems, + maxSize: maxSize, + } +} + +type lruHaunter struct { + period time.Duration + maxItems int + maxSize int64 +} + +func (j *lruHaunter) Next() time.Duration { + return j.period +} + +func (j *lruHaunter) Scrub(c CacheAccessor) (keysToReap []string) { + var count int + var size int64 + var okFiles []lruHaunterKV + + c.EnumerateEntries(func(key string, e Entry) bool { + if e.InUse() { + return true + } + + fileInfo, err := c.Stat(e.Name()) + if err != nil { + return true + } + + count++ + size = size + fileInfo.Size() + okFiles = append(okFiles, lruHaunterKV{ + Key: key, + Value: e, + }) + + return true + }) + + sort.Slice(okFiles, func(i, j int) bool { + iFileInfo, err := c.Stat(okFiles[i].Value.Name()) + if err != nil { + return false + } + + iLastRead := iFileInfo.AccessTime() + + jFileInfo, err := c.Stat(okFiles[j].Value.Name()) + if err != nil { + return false + } + + jLastRead := jFileInfo.AccessTime() + + return iLastRead.Before(jLastRead) + }) + + collectKeysToReapFn := func() bool { + var key *string + var err error + key, count, size, err = j.removeFirst(c, &okFiles, count, size) + if err != nil { + return false + } + if key != nil { + keysToReap = append(keysToReap, *key) + } + + return true + } + + if j.maxItems > 0 { + for count > j.maxItems { + if !collectKeysToReapFn() { + break + } + } + } + + if j.maxSize > 0 { + for size > j.maxSize { + if !collectKeysToReapFn() { + break + } + } + } + + return keysToReap +} + +func (j *lruHaunter) removeFirst(fsStater FileSystemStater, items *[]lruHaunterKV, count int, size int64) (*string, int, int64, error) { + var f lruHaunterKV + + f, *items = (*items)[0], (*items)[1:] + + fileInfo, err := fsStater.Stat(f.Value.Name()) + if err != nil { + return nil, count, size, err + } + + count-- + size = size - 
fileInfo.Size() + + return &f.Key, count, size, nil +} diff --git a/libsq/core/ioz/fscache/memfs.go b/libsq/core/ioz/fscache/memfs.go new file mode 100644 index 000000000..ddfb92cf3 --- /dev/null +++ b/libsq/core/ioz/fscache/memfs.go @@ -0,0 +1,147 @@ +package fscache + +import ( + "bytes" + "errors" + "io" + "os" + "sync" + "time" + + "github.com/djherbis/stream" +) + +type memFS struct { + mu sync.RWMutex + files map[string]*memFile +} + +// NewMemFs creates an in-memory FileSystem. +// It does not support persistence (Reload is a nop). +func NewMemFs() FileSystem { + return &memFS{ + files: make(map[string]*memFile), + } +} + +func (fs *memFS) Stat(name string) (FileInfo, error) { + fs.mu.RLock() + defer fs.mu.RUnlock() + f, ok := fs.files[name] + if !ok { + return FileInfo{}, errors.New("file has not been read") + } + + size := int64(len(f.Bytes())) + + return FileInfo{ + FileInfo: &fileInfo{ + name: name, + size: size, + fileMode: os.ModeIrregular, + isDir: false, + sys: nil, + wt: f.wt, + }, + Atime: f.rt, + }, nil +} + +func (fs *memFS) Reload(add func(key, name string)) error { + return nil +} + +func (fs *memFS) Create(key string) (stream.File, error) { + fs.mu.Lock() + defer fs.mu.Unlock() + if _, ok := fs.files[key]; ok { + return nil, errors.New("file exists") + } + file := &memFile{ + name: key, + r: bytes.NewBuffer(nil), + wt: time.Now(), + } + file.memReader.memFile = file + fs.files[key] = file + return file, nil +} + +func (fs *memFS) Open(name string) (stream.File, error) { + fs.mu.Lock() + defer fs.mu.Unlock() + if f, ok := fs.files[name]; ok { + f.rt = time.Now() + return &memReader{memFile: f}, nil + } + return nil, errors.New("file does not exist") +} + +func (fs *memFS) Remove(key string) error { + fs.mu.Lock() + defer fs.mu.Unlock() + delete(fs.files, key) + return nil +} + +func (fs *memFS) RemoveAll() error { + fs.mu.Lock() + defer fs.mu.Unlock() + fs.files = make(map[string]*memFile) + return nil +} + +type memFile struct { + mu 
sync.RWMutex + name string + r *bytes.Buffer + memReader + rt, wt time.Time +} + +func (f *memFile) Name() string { + return f.name +} + +func (f *memFile) Write(p []byte) (int, error) { + if len(p) > 0 { + f.mu.Lock() + defer f.mu.Unlock() + return f.r.Write(p) + } + return len(p), nil +} + +func (f *memFile) Bytes() []byte { + f.mu.RLock() + defer f.mu.RUnlock() + return f.r.Bytes() +} + +func (f *memFile) Close() error { + return nil +} + +type memReader struct { + *memFile + n int +} + +func (r *memReader) ReadAt(p []byte, off int64) (n int, err error) { + data := r.Bytes() + if int64(len(data)) < off { + return 0, io.EOF + } + n, err = bytes.NewReader(data[off:]).ReadAt(p, 0) + return n, err +} + +func (r *memReader) Read(p []byte) (n int, err error) { + n, err = bytes.NewReader(r.Bytes()[r.n:]).Read(p) + r.n += n + return n, err +} + +func (r *memReader) Close() error { + return nil +} diff --git a/libsq/core/ioz/fscache/reaper.go b/libsq/core/ioz/fscache/reaper.go new file mode 100644 index 000000000..d801202a7 --- /dev/null +++ b/libsq/core/ioz/fscache/reaper.go @@ -0,0 +1,37 @@ +package fscache + +import "time" + +// Reaper is used to control when streams expire from the cache. +// It is called once right after loading, and then it is run +// again after every Next() period of time. +type Reaper interface { + // Returns the amount of time to wait before the next scheduled Reaping. + Next() time.Duration + + // Given a key and the last r/w times of a file, return true + // to remove the file from the cache, false to keep it. + Reap(key string, lastRead, lastWrite time.Time) bool +} + +// NewReaper returns a simple reaper which runs every "Period" +// and reaps files which are older than "expiry". 
+func NewReaper(expiry, period time.Duration) Reaper { + return &reaper{ + expiry: expiry, + period: period, + } +} + +type reaper struct { + period time.Duration + expiry time.Duration +} + +func (g *reaper) Next() time.Duration { + return g.period +} + +func (g *reaper) Reap(key string, lastRead, lastWrite time.Time) bool { + return lastRead.Before(time.Now().Add(-g.expiry)) +} diff --git a/libsq/core/ioz/fscache/server.go b/libsq/core/ioz/fscache/server.go new file mode 100644 index 000000000..dba74aad3 --- /dev/null +++ b/libsq/core/ioz/fscache/server.go @@ -0,0 +1,206 @@ +package fscache + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" +) + +// ListenAndServe hosts a Cache for access via NewRemote +func ListenAndServe(c Cache, addr string) error { + return (&server{c: c}).ListenAndServe(addr) +} + +// NewRemote returns a Cache run via ListenAndServe +func NewRemote(raddr string) Cache { + return &remote{raddr: raddr} +} + +type server struct { + c Cache +} + +func (s *server) ListenAndServe(addr string) error { + l, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + for { + c, err := l.Accept() + if err != nil { + return err + } + + go s.Serve(c) + } +} + +const ( + actionGet = iota + actionRemove = iota + actionExists = iota + actionClean = iota +) + +func getKey(r io.Reader) string { + dec := newDecoder(r) + buf := bytes.NewBufferString("") + io.Copy(buf, dec) + return buf.String() +} + +func sendKey(w io.Writer, key string) { + enc := newEncoder(w) + enc.Write([]byte(key)) + enc.Close() +} + +func (s *server) Serve(c net.Conn) { + var action int + fmt.Fscanf(c, "%d\n", &action) + + switch action { + case actionGet: + s.get(c, getKey(c)) + case actionRemove: + s.c.Remove(getKey(c)) + case actionExists: + s.exists(c, getKey(c)) + case actionClean: + s.c.Clean() + } +} + +func (s *server) exists(c net.Conn, key string) { + if s.c.Exists(key) { + fmt.Fprintf(c, "%d\n", 1) + } else { + fmt.Fprintf(c, "%d\n", 0) + } +} + +func (s 
*server) get(c net.Conn, key string) { + r, w, err := s.c.Get(key) + if err != nil { + return // handle this better + } + defer r.Close() + + if w != nil { + go func() { + fmt.Fprintf(c, "%d\n", 1) + io.Copy(w, newDecoder(c)) + w.Close() + }() + } else { + fmt.Fprintf(c, "%d\n", 0) + } + + enc := newEncoder(c) + io.Copy(enc, r) + enc.Close() +} + +type remote struct { + raddr string +} + +func (rmt *remote) Get(key string) (r ReadAtCloser, w io.WriteCloser, err error) { + c, err := net.Dial("tcp", rmt.raddr) + if err != nil { + return nil, nil, err + } + fmt.Fprintf(c, "%d\n", actionGet) + sendKey(c, key) + + var i int + fmt.Fscanf(c, "%d\n", &i) + + var ch chan struct{} + + switch i { + case 0: + ch = make(chan struct{}) // close net.Conn on reader close + case 1: + ch = make(chan struct{}, 1) // two closes before net.Conn close + + w = &safeCloser{ + c: c, + ch: ch, + w: newEncoder(c), + } + default: + return nil, nil, errors.New("bad bad bad") + } + + r = &safeCloser{ + c: c, + ch: ch, + r: newDecoder(c), + } + + return r, w, nil +} + +type safeCloser struct { + c net.Conn + ch chan<- struct{} + r ReadAtCloser + w io.WriteCloser +} + +func (s *safeCloser) ReadAt(p []byte, off int64) (int, error) { + return s.r.ReadAt(p, off) +} +func (s *safeCloser) Read(p []byte) (int, error) { return s.r.Read(p) } +func (s *safeCloser) Write(p []byte) (int, error) { return s.w.Write(p) } + +// Close only closes the underlying connection when ch is full. 
+func (s *safeCloser) Close() (err error) { + if s.r != nil { + err = s.r.Close() + } else if s.w != nil { + err = s.w.Close() + } + + select { + case s.ch <- struct{}{}: + return err + default: + return s.c.Close() + } +} + +func (rmt *remote) Exists(key string) bool { + c, err := net.Dial("tcp", rmt.raddr) + if err != nil { + return false + } + fmt.Fprintf(c, "%d\n", actionExists) + sendKey(c, key) + var i int + fmt.Fscanf(c, "%d\n", &i) + return i == 1 +} + +func (rmt *remote) Remove(key string) error { + c, err := net.Dial("tcp", rmt.raddr) + if err != nil { + return err + } + fmt.Fprintf(c, "%d\n", actionRemove) + sendKey(c, key) + return nil +} + +func (rmt *remote) Clean() error { + c, err := net.Dial("tcp", rmt.raddr) + if err != nil { + return err + } + fmt.Fprintf(c, "%d\n", actionClean) + return nil +} diff --git a/libsq/core/ioz/fscache/stream.go b/libsq/core/ioz/fscache/stream.go new file mode 100644 index 000000000..9cccb2483 --- /dev/null +++ b/libsq/core/ioz/fscache/stream.go @@ -0,0 +1,72 @@ +package fscache + +import ( + "encoding/json" + "errors" + "io" +) + +type decoder interface { + Decode(interface{}) error +} + +type encoder interface { + Encode(interface{}) error +} + +type pktReader struct { + dec decoder +} + +type pktWriter struct { + enc encoder +} + +type packet struct { + Err int + Data []byte +} + +const eof = 1 + +func (t *pktReader) ReadAt(p []byte, off int64) (n int, err error) { + // TODO not implemented + return 0, errors.New("not implemented") +} + +func (t *pktReader) Read(p []byte) (int, error) { + var pkt packet + err := t.dec.Decode(&pkt) + if err != nil { + return 0, err + } + if pkt.Err == eof { + return 0, io.EOF + } + return copy(p, pkt.Data), nil +} + +func (t *pktReader) Close() error { + return nil +} + +func (t *pktWriter) Write(p []byte) (int, error) { + pkt := packet{Data: p} + err := t.enc.Encode(pkt) + if err != nil { + return 0, err + } + return len(p), nil +} + +func (t *pktWriter) Close() error { + return 
t.enc.Encode(packet{Err: eof}) +} + +func newEncoder(w io.Writer) io.WriteCloser { + return &pktWriter{enc: json.NewEncoder(w)} +} + +func newDecoder(r io.Reader) ReadAtCloser { + return &pktReader{dec: json.NewDecoder(r)} +} From e79bc8b6aa3d0f03cf885837d18a6e677da6b7a0 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 26 Nov 2023 10:35:05 -0700 Subject: [PATCH 009/195] wip: adding fscache --- libsq/core/ioz/fscache/.gitignore | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 libsq/core/ioz/fscache/.gitignore diff --git a/libsq/core/ioz/fscache/.gitignore b/libsq/core/ioz/fscache/.gitignore new file mode 100644 index 000000000..3881fd0a2 --- /dev/null +++ b/libsq/core/ioz/fscache/.gitignore @@ -0,0 +1,2 @@ +cache* +server/ From 7000dea429a88297f69eb8b53c20d07c0abb30c8 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 26 Nov 2023 10:36:30 -0700 Subject: [PATCH 010/195] wip: adding fscache --- libsq/core/ioz/fscache/.gitignore | 2 - libsq/core/ioz/fscache/LICENSE | 22 - libsq/core/ioz/fscache/README.md | 93 ---- libsq/core/ioz/fscache/distrib.go | 85 ---- libsq/core/ioz/fscache/example_test.go | 69 --- libsq/core/ioz/fscache/fileinfo.go | 52 --- libsq/core/ioz/fscache/fs.go | 266 ------------ libsq/core/ioz/fscache/fscache.go | 373 ---------------- libsq/core/ioz/fscache/fscache_test.go | 579 ------------------------- libsq/core/ioz/fscache/handler.go | 41 -- libsq/core/ioz/fscache/haunter.go | 92 ---- libsq/core/ioz/fscache/layers.go | 128 ------ libsq/core/ioz/fscache/lruhaunter.go | 137 ------ libsq/core/ioz/fscache/memfs.go | 147 ------- libsq/core/ioz/fscache/reaper.go | 37 -- libsq/core/ioz/fscache/server.go | 206 --------- libsq/core/ioz/fscache/stream.go | 72 --- 17 files changed, 2401 deletions(-) delete mode 100644 libsq/core/ioz/fscache/.gitignore delete mode 100644 libsq/core/ioz/fscache/LICENSE delete mode 100644 libsq/core/ioz/fscache/README.md delete mode 100644 libsq/core/ioz/fscache/distrib.go delete mode 100644 
libsq/core/ioz/fscache/example_test.go delete mode 100644 libsq/core/ioz/fscache/fileinfo.go delete mode 100644 libsq/core/ioz/fscache/fs.go delete mode 100644 libsq/core/ioz/fscache/fscache.go delete mode 100644 libsq/core/ioz/fscache/fscache_test.go delete mode 100644 libsq/core/ioz/fscache/handler.go delete mode 100644 libsq/core/ioz/fscache/haunter.go delete mode 100644 libsq/core/ioz/fscache/layers.go delete mode 100644 libsq/core/ioz/fscache/lruhaunter.go delete mode 100644 libsq/core/ioz/fscache/memfs.go delete mode 100644 libsq/core/ioz/fscache/reaper.go delete mode 100644 libsq/core/ioz/fscache/server.go delete mode 100644 libsq/core/ioz/fscache/stream.go diff --git a/libsq/core/ioz/fscache/.gitignore b/libsq/core/ioz/fscache/.gitignore deleted file mode 100644 index 3881fd0a2..000000000 --- a/libsq/core/ioz/fscache/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -cache* -server/ diff --git a/libsq/core/ioz/fscache/LICENSE b/libsq/core/ioz/fscache/LICENSE deleted file mode 100644 index 1e7b7cc09..000000000 --- a/libsq/core/ioz/fscache/LICENSE +++ /dev/null @@ -1,22 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2015 Dustin H - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - diff --git a/libsq/core/ioz/fscache/README.md b/libsq/core/ioz/fscache/README.md deleted file mode 100644 index 78b57ef35..000000000 --- a/libsq/core/ioz/fscache/README.md +++ /dev/null @@ -1,93 +0,0 @@ -fscache -========== - -[![GoDoc](https://godoc.org/github.com/djherbis/fscache?status.svg)](https://godoc.org/github.com/djherbis/fscache) -[![Release](https://img.shields.io/github/release/djherbis/fscache.svg)](https://github.com/djherbis/fscache/releases/latest) -[![Software License](https://img.shields.io/badge/license-MIT-brightgreen.svg)](LICENSE.txt) -[![go test](https://github.com/djherbis/fscache/actions/workflows/go-test.yml/badge.svg)](https://github.com/djherbis/fscache/actions/workflows/go-test.yml) -[![Coverage Status](https://coveralls.io/repos/djherbis/fscache/badge.svg?branch=master)](https://coveralls.io/r/djherbis/fscache?branch=master) -[![Go Report Card](https://goreportcard.com/badge/github.com/djherbis/fscache)](https://goreportcard.com/report/github.com/djherbis/fscache) - -Usage ------------- -Streaming File Cache for #golang - -fscache allows multiple readers to read from a cache while its being written to. [blog post](https://djherbis.github.io/post/fscache/) - -Using the Cache directly: - -```go -package main - -import ( - "io" - "log" - "os" - "time" - - "gopkg.in/djherbis/fscache.v0" -) - -func main() { - - // create the cache, keys expire after 1 hour. 
- c, err := fscache.New("./cache", 0755, time.Hour) - if err != nil { - log.Fatal(err.Error()) - } - - // wipe the cache when done - defer c.Clean() - - // Get() and it's streams can be called concurrently but just for example: - for i := 0; i < 3; i++ { - r, w, err := c.Get("stream") - if err != nil { - log.Fatal(err.Error()) - } - - if w != nil { // a new stream, write to it. - go func(){ - w.Write([]byte("hello world\n")) - w.Close() - }() - } - - // the stream has started, read from it - io.Copy(os.Stdout, r) - r.Close() - } -} -``` - -A Caching Middle-ware: - -```go -package main - -import( - "net/http" - "time" - - "gopkg.in/djherbis/fscache.v0" -) - -func main(){ - c, err := fscache.New("./cache", 0700, 0) - if err != nil { - log.Fatal(err.Error()) - } - - handler := func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "%v: %s", time.Now(), "hello world") - } - - http.ListenAndServe(":8080", fscache.Handler(c, http.HandlerFunc(handler))) -} -``` - -Installation ------------- -```sh -go get gopkg.in/djherbis/fscache.v0 -``` diff --git a/libsq/core/ioz/fscache/distrib.go b/libsq/core/ioz/fscache/distrib.go deleted file mode 100644 index 60994cc58..000000000 --- a/libsq/core/ioz/fscache/distrib.go +++ /dev/null @@ -1,85 +0,0 @@ -package fscache - -import ( - "bytes" - "crypto/sha1" - "encoding/binary" - "io" -) - -// Distributor provides a way to partition keys into Caches. -type Distributor interface { - - // GetCache will always return the same Cache for the same key. - GetCache(key string) Cache - - // Clean should wipe all the caches this Distributor manages - Clean() error -} - -// stdDistribution distributes the keyspace evenly. -func stdDistribution(key string, n uint64) uint64 { - h := sha1.New() - io.WriteString(h, key) - buf := bytes.NewBuffer(h.Sum(nil)[:8]) - i, _ := binary.ReadUvarint(buf) - return i % n -} - -// NewDistributor returns a Distributor which evenly distributes the keyspace -// into the passed caches. 
-func NewDistributor(caches ...Cache) Distributor { - if len(caches) == 0 { - return nil - } - return &distrib{ - distribution: stdDistribution, - caches: caches, - size: uint64(len(caches)), - } -} - -type distrib struct { - distribution func(key string, n uint64) uint64 - caches []Cache - size uint64 -} - -func (d *distrib) GetCache(key string) Cache { - return d.caches[d.distribution(key, d.size)] -} - -// BUG(djherbis): Return an error if cleaning fails -func (d *distrib) Clean() error { - for _, c := range d.caches { - c.Clean() - } - return nil -} - -// NewPartition returns a Cache which uses the Caches defined by the passed Distributor. -func NewPartition(d Distributor) Cache { - return &partition{ - distributor: d, - } -} - -type partition struct { - distributor Distributor -} - -func (p *partition) Get(key string) (ReadAtCloser, io.WriteCloser, error) { - return p.distributor.GetCache(key).Get(key) -} - -func (p *partition) Remove(key string) error { - return p.distributor.GetCache(key).Remove(key) -} - -func (p *partition) Exists(key string) bool { - return p.distributor.GetCache(key).Exists(key) -} - -func (p *partition) Clean() error { - return p.distributor.Clean() -} diff --git a/libsq/core/ioz/fscache/example_test.go b/libsq/core/ioz/fscache/example_test.go deleted file mode 100644 index 5aa9e7266..000000000 --- a/libsq/core/ioz/fscache/example_test.go +++ /dev/null @@ -1,69 +0,0 @@ -package fscache - -import ( - "fmt" - "io" - "log" - "net/http" - "net/http/httptest" - "os" - "time" -) - -func Example() { - // create the cache, keys expire after 1 hour. - c, err := New("./cache", 0755, time.Hour) - if err != nil { - log.Fatal(err.Error()) - } - - // wipe the cache when done - defer c.Clean() - - // Get() and it's streams can be called concurrently but just for example: - for i := 0; i < 3; i++ { - r, w, err := c.Get("stream") - if err != nil { - log.Fatal(err.Error()) - } - - if w != nil { // a new stream, write to it. 
- go func() { - w.Write([]byte("hello world\n")) - w.Close() - }() - } - - // the stream has started, read from it - io.Copy(os.Stdout, r) - r.Close() - } - // Output: - // hello world - // hello world - // hello world -} - -func ExampleHandler() { - c, err := New("./server", 0700, 0) - if err != nil { - log.Fatal(err.Error()) - } - defer c.Clean() - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "Hello Client") - }) - - ts := httptest.NewServer(Handler(c, handler)) - defer ts.Close() - - resp, err := http.Get(ts.URL) - if err != nil { - log.Fatal(err.Error()) - } - io.Copy(os.Stdout, resp.Body) - resp.Body.Close() - // Output: - // Hello Client -} diff --git a/libsq/core/ioz/fscache/fileinfo.go b/libsq/core/ioz/fscache/fileinfo.go deleted file mode 100644 index 445fcfd50..000000000 --- a/libsq/core/ioz/fscache/fileinfo.go +++ /dev/null @@ -1,52 +0,0 @@ -package fscache - -import ( - "os" - "time" -) - -// FileInfo is just a wrapper around os.FileInfo which includes atime. -type FileInfo struct { - os.FileInfo - Atime time.Time -} - -type fileInfo struct { - name string - size int64 - fileMode os.FileMode - isDir bool - sys interface{} - wt time.Time -} - -func (f *fileInfo) Name() string { - return f.name -} - -func (f *fileInfo) Size() int64 { - return f.size -} - -func (f *fileInfo) Mode() os.FileMode { - return f.fileMode -} - -func (f *fileInfo) ModTime() time.Time { - return f.wt -} - -func (f *fileInfo) IsDir() bool { - return f.isDir -} - -func (f *fileInfo) Sys() interface{} { - return f.sys -} - -// AccessTime returns the last time the file was read. -// It will be used to check expiry of a file, and must be concurrent safe -// with modifications to the FileSystem (writes, reads etc.) 
-func (f *FileInfo) AccessTime() time.Time { - return f.Atime -} diff --git a/libsq/core/ioz/fscache/fs.go b/libsq/core/ioz/fscache/fs.go deleted file mode 100644 index dad018382..000000000 --- a/libsq/core/ioz/fscache/fs.go +++ /dev/null @@ -1,266 +0,0 @@ -package fscache - -import ( - "bytes" - "crypto/md5" - "encoding/base64" - "fmt" - "io" - "io/ioutil" - "os" - "path/filepath" - "strings" - "time" - - "github.com/djherbis/atime" - "github.com/djherbis/stream" -) - -// FileSystemStater implementers can provide FileInfo data about a named resource. -type FileSystemStater interface { - // Stat takes a File.Name() and returns FileInfo interface - Stat(name string) (FileInfo, error) -} - -// FileSystem is used as the source for a Cache. -type FileSystem interface { - // Stream FileSystem - stream.FileSystem - - FileSystemStater - - // Reload should look through the FileSystem and call the supplied fn - // with the key/filename pairs that are found. - Reload(func(key, name string)) error - - // RemoveAll should empty the FileSystem of all files. - RemoveAll() error -} - -// StandardFS is an implemenation of FileSystem which writes to the os Filesystem. -type StandardFS struct { - root string - init func() error - - // EncodeKey takes a 'name' given to Create and converts it into a - // the Filename that should be used. It should return 'true' if - // DecodeKey can convert the returned string back to the original 'name' - // and false otherwise. - // This must be set before the first call to Create. - EncodeKey func(string) (string, bool) - - // DecodeKey should convert a given Filename into the original 'name' given to - // EncodeKey, and return true if this conversion was possible. Returning false - // will cause it to try and lookup a stored 'encodedName.key' file which holds - // the original name. 
- DecodeKey func(string) (string, bool) -} - -// IdentityCodeKey works as both an EncodeKey and a DecodeKey func, which just returns -// it's given argument and true. This is expected to be used when your FSCache -// uses SetKeyMapper to ensure its internal km(key) value is already a valid filename path. -func IdentityCodeKey(key string) (string, bool) { return key, true } - -// NewFs returns a FileSystem rooted at directory dir. -// Dir is created with perms if it doesn't exist. -// This also uses the default EncodeKey/DecodeKey functions B64ORMD5HashEncodeKey/B64DecodeKey. -func NewFs(dir string, mode os.FileMode) (*StandardFS, error) { - fs := &StandardFS{ - root: dir, - init: func() error { - return os.MkdirAll(dir, mode) - }, - EncodeKey: B64OrMD5HashEncodeKey, - DecodeKey: B64DecodeKey, - } - return fs, fs.init() -} - -// Reload looks through the dir given to NewFs and returns every key, name pair (Create(key) => name = File.Name()) -// that is managed by this FileSystem. -func (fs *StandardFS) Reload(add func(key, name string)) error { - files, err := ioutil.ReadDir(fs.root) - if err != nil { - return err - } - - addfiles := make(map[string]struct { - os.FileInfo - key string - }) - - for _, f := range files { - - if strings.HasSuffix(f.Name(), ".key") { - continue - } - - key, err := fs.getKey(f.Name()) - if err != nil { - fs.Remove(filepath.Join(fs.root, f.Name())) - continue - } - fi, ok := addfiles[key] - - if !ok || fi.ModTime().Before(f.ModTime()) { - if ok { - fs.Remove(fi.Name()) - } - addfiles[key] = struct { - os.FileInfo - key string - }{ - FileInfo: f, - key: key, - } - } else { - fs.Remove(f.Name()) - } - - } - - for _, f := range addfiles { - path, err := filepath.Abs(filepath.Join(fs.root, f.Name())) - if err != nil { - return err - } - add(f.key, path) - } - - return nil -} - -// Create creates a File for the given 'name', it may not use the given name on the -// os filesystem, that depends on the implementation of EncodeKey used. 
-func (fs *StandardFS) Create(name string) (stream.File, error) { - name, err := fs.makeName(name) - if err != nil { - return nil, err - } - return fs.create(name) -} - -func (fs *StandardFS) create(name string) (stream.File, error) { - return os.OpenFile(filepath.Join(fs.root, name), os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) -} - -// Open opens a stream.File for the given File.Name() returned by Create(). -func (fs *StandardFS) Open(name string) (stream.File, error) { - return os.Open(name) -} - -// Remove removes a stream.File for the given File.Name() returned by Create(). -func (fs *StandardFS) Remove(name string) error { - os.Remove(fmt.Sprintf("%s.key", name)) - return os.Remove(name) -} - -// RemoveAll deletes all files in the directory managed by this StandardFS. -// Warning that if you put files in this directory that were not created by -// StandardFS they will also be deleted. -func (fs *StandardFS) RemoveAll() error { - if err := os.RemoveAll(fs.root); err != nil { - return err - } - return fs.init() -} - -// AccessTimes returns atime and mtime for the given File.Name() returned by Create(). -func (fs *StandardFS) AccessTimes(name string) (rt, wt time.Time, err error) { - fi, err := os.Stat(name) - if err != nil { - return rt, wt, err - } - return atime.Get(fi), fi.ModTime(), nil -} - -// Stat returns FileInfo for the given File.Name() returned by Create(). -func (fs *StandardFS) Stat(name string) (FileInfo, error) { - stat, err := os.Stat(name) - if err != nil { - return FileInfo{}, err - } - - return FileInfo{FileInfo: stat, Atime: atime.Get(stat)}, nil -} - -const ( - saltSize = 8 - salt = "xxxxxxxx" // this is only important for sizing now. 
- maxShort = 20 - shortPrefix = "s" - longPrefix = "l" -) - -func tob64(s string) string { - buf := bytes.NewBufferString("") - enc := base64.NewEncoder(base64.URLEncoding, buf) - enc.Write([]byte(s)) - enc.Close() - return buf.String() -} - -func fromb64(s string) string { - buf := bytes.NewBufferString(s) - dec := base64.NewDecoder(base64.URLEncoding, buf) - out := bytes.NewBufferString("") - io.Copy(out, dec) - return out.String() -} - -// B64OrMD5HashEncodeKey converts a given key into a filesystem name-safe string -// and returns true iff it can be reversed with B64DecodeKey. -func B64OrMD5HashEncodeKey(key string) (string, bool) { - b64key := tob64(key) - // short name - if len(b64key) < maxShort { - return fmt.Sprintf("%s%s%s", shortPrefix, salt, b64key), true - } - - // long name - hash := md5.Sum([]byte(key)) - return fmt.Sprintf("%s%s%x", longPrefix, salt, hash[:]), false -} - -func (fs *StandardFS) makeName(key string) (string, error) { - name, decodable := fs.EncodeKey(key) - if decodable { - return name, nil - } - - // Name is not decodeable, store it. - f, err := fs.create(fmt.Sprintf("%s.key", name)) - if err != nil { - return "", err - } - _, err = f.Write([]byte(key)) - f.Close() - return name, err -} - -// B64DecodeKey converts a string y into x st. y, ok = B64OrMD5HashEncodeKey(x), and ok = true. -// Basically it should reverse B64OrMD5HashEncodeKey if B64OrMD5HashEncodeKey returned true. 
-func B64DecodeKey(name string) (string, bool) { - if strings.HasPrefix(name, shortPrefix) { - return fromb64(strings.TrimPrefix(name, shortPrefix)[saltSize:]), true - } - return "", false -} - -func (fs *StandardFS) getKey(name string) (string, error) { - if key, ok := fs.DecodeKey(name); ok { - return key, nil - } - - // long name - f, err := fs.Open(filepath.Join(fs.root, fmt.Sprintf("%s.key", name))) - if err != nil { - return "", err - } - defer f.Close() - key, err := ioutil.ReadAll(f) - if err != nil { - return "", err - } - return string(key), nil -} diff --git a/libsq/core/ioz/fscache/fscache.go b/libsq/core/ioz/fscache/fscache.go deleted file mode 100644 index 6de40a3b8..000000000 --- a/libsq/core/ioz/fscache/fscache.go +++ /dev/null @@ -1,373 +0,0 @@ -package fscache - -import ( - "fmt" - "io" - "os" - "sync" - "sync/atomic" - "time" - - "github.com/djherbis/stream" -) - -// Cache works like a concurrent-safe map for streams. -type Cache interface { - // Get manages access to the streams in the cache. - // If the key does not exist, w != nil and you can start writing to the stream. - // If the key does exist, w == nil. - // r will always be non-nil as long as err == nil and you must close r when you're done reading. - // Get can be called concurrently, and writing and reading is concurrent safe. - Get(key string) (ReadAtCloser, io.WriteCloser, error) - - // Remove deletes the stream from the cache, blocking until the underlying - // file can be deleted (all active streams finish with it). - // It is safe to call Remove concurrently with Get. - Remove(key string) error - - // Exists checks if a key is in the cache. - // It is safe to call Exists concurrently with Get. - Exists(key string) bool - - // Clean will empty the cache and delete the cache folder. - // Clean is not safe to call while streams are being read/written. - Clean() error -} - -// FSCache is a Cache which uses a Filesystem to read/write cached data. 
-type FSCache struct { - mu sync.RWMutex - files map[string]fileStream - km func(string) string - fs FileSystem - haunter Haunter -} - -// SetKeyMapper will use the given function to transform any given Cache key into the result of km(key). -// This means that internally, the cache will only track km(key), and forget the original key. The consequences -// of this are that Enumerate will return km(key) instead of key, and Filesystem will give km(key) to Create -// and expect Reload() to return km(key). -// The purpose of this function is so that the internally managed key can be converted to a string that is -// allowed as a filesystem path. -func (c *FSCache) SetKeyMapper(km func(string) string) *FSCache { - c.mu.Lock() - defer c.mu.Unlock() - c.km = km - return c -} - -func (c *FSCache) mapKey(key string) string { - if c.km == nil { - return key - } - return c.km(key) -} - -// ReadAtCloser is an io.ReadCloser, and an io.ReaderAt. It supports both so that Range -// Requests are possible. -type ReadAtCloser interface { - io.ReadCloser - io.ReaderAt -} - -type fileStream interface { - next() (*CacheReader, error) - InUse() bool - io.WriteCloser - remove() error - Name() string -} - -// New creates a new Cache using NewFs(dir, perms). -// expiry is the duration after which an un-accessed key will be removed from -// the cache, a zero value expiro means never expire. -func New(dir string, perms os.FileMode, expiry time.Duration) (*FSCache, error) { - fs, err := NewFs(dir, perms) - if err != nil { - return nil, err - } - var grim Reaper - if expiry > 0 { - grim = &reaper{ - expiry: expiry, - period: expiry, - } - } - return NewCache(fs, grim) -} - -// NewCache creates a new Cache based on FileSystem fs. -// fs.Files() are loaded using the name they were created with as a key. -// Reaper is used to determine when files expire, nil means never expire. 
-func NewCache(fs FileSystem, grim Reaper) (*FSCache, error) { - if grim != nil { - return NewCacheWithHaunter(fs, NewReaperHaunterStrategy(grim)) - } - - return NewCacheWithHaunter(fs, nil) -} - -// NewCacheWithHaunter create a new Cache based on FileSystem fs. -// fs.Files() are loaded using the name they were created with as a key. -// Haunter is used to determine when files expire, nil means never expire. -func NewCacheWithHaunter(fs FileSystem, haunter Haunter) (*FSCache, error) { - c := &FSCache{ - files: make(map[string]fileStream), - haunter: haunter, - fs: fs, - } - err := c.load() - if err != nil { - return nil, err - } - if haunter != nil { - c.scheduleHaunt() - } - - return c, nil -} - -func (c *FSCache) scheduleHaunt() { - c.haunt() - time.AfterFunc(c.haunter.Next(), c.scheduleHaunt) -} - -func (c *FSCache) haunt() { - c.mu.Lock() - defer c.mu.Unlock() - - c.haunter.Haunt(&accessor{c: c}) -} - -func (c *FSCache) load() error { - c.mu.Lock() - defer c.mu.Unlock() - return c.fs.Reload(func(key, name string) { - c.files[key] = c.oldFile(name) - }) -} - -// Exists returns true iff this key is in the Cache (may not be finished streaming). -func (c *FSCache) Exists(key string) bool { - c.mu.RLock() - defer c.mu.RUnlock() - _, ok := c.files[c.mapKey(key)] - return ok -} - -// Get obtains a ReadAtCloser for the given key, and may return a WriteCloser to write the original cache data -// if this is a cache-miss. 
-func (c *FSCache) Get(key string) (r ReadAtCloser, w io.WriteCloser, err error) { - c.mu.RLock() - key = c.mapKey(key) - f, ok := c.files[key] - if ok { - r, err = f.next() - c.mu.RUnlock() - return r, nil, err - } - c.mu.RUnlock() - - c.mu.Lock() - defer c.mu.Unlock() - - f, ok = c.files[key] - if ok { - r, err = f.next() - return r, nil, err - } - - f, err = c.newFile(key) - if err != nil { - return nil, nil, err - } - - r, err = f.next() - if err != nil { - f.Close() - c.fs.Remove(f.Name()) - return nil, nil, err - } - - c.files[key] = f - - return r, f, err -} - -// Remove removes the specified key from the cache. -func (c *FSCache) Remove(key string) error { - c.mu.Lock() - key = c.mapKey(key) - f, ok := c.files[key] - delete(c.files, key) - c.mu.Unlock() - - if ok { - return f.remove() - } - return nil -} - -// Clean resets the cache removing all keys and data. -func (c *FSCache) Clean() error { - c.mu.Lock() - defer c.mu.Unlock() - c.files = make(map[string]fileStream) - return c.fs.RemoveAll() -} - -type accessor struct { - c *FSCache -} - -func (a *accessor) Stat(name string) (FileInfo, error) { - return a.c.fs.Stat(name) -} - -func (a *accessor) EnumerateEntries(enumerator func(key string, e Entry) bool) { - for k, f := range a.c.files { - if !enumerator(k, Entry{name: f.Name(), inUse: f.InUse()}) { - break - } - } -} - -func (a *accessor) RemoveFile(key string) { - key = a.c.mapKey(key) - f, ok := a.c.files[key] - delete(a.c.files, key) - if ok { - a.c.fs.Remove(f.Name()) - } -} - -type cachedFile struct { - handleCounter - stream *stream.Stream -} - -func (c *FSCache) newFile(name string) (fileStream, error) { - s, err := stream.NewStream(name, c.fs) - if err != nil { - return nil, err - } - cf := &cachedFile{ - stream: s, - } - cf.inc() - return cf, nil -} - -func (c *FSCache) oldFile(name string) fileStream { - return &reloadedFile{ - fs: c.fs, - name: name, - } -} - -type reloadedFile struct { - handleCounter - fs FileSystem - name string - 
io.WriteCloser // nop Write & Close methods. will never be called. -} - -func (f *reloadedFile) Name() string { return f.name } - -func (f *reloadedFile) remove() error { - f.waitUntilFree() - return f.fs.Remove(f.name) -} - -func (f *reloadedFile) next() (*CacheReader, error) { - r, err := f.fs.Open(f.name) - if err == nil { - f.inc() - } - return &CacheReader{ - ReadAtCloser: r, - cnt: &f.handleCounter, - }, err -} - -func (f *cachedFile) Name() string { return f.stream.Name() } - -func (f *cachedFile) remove() error { return f.stream.Remove() } - -func (f *cachedFile) next() (*CacheReader, error) { - reader, err := f.stream.NextReader() - if err != nil { - return nil, err - } - f.inc() - return &CacheReader{ - ReadAtCloser: reader, - cnt: &f.handleCounter, - }, nil -} - -func (f *cachedFile) Write(p []byte) (int, error) { - return f.stream.Write(p) -} - -func (f *cachedFile) Close() error { - defer f.dec() - return f.stream.Close() -} - -// CacheReader is a ReadAtCloser for a Cache key that also tracks open readers. -type CacheReader struct { - ReadAtCloser - cnt *handleCounter -} - -// Close frees the underlying ReadAtCloser and updates the open reader counter. -func (r *CacheReader) Close() error { - defer r.cnt.dec() - return r.ReadAtCloser.Close() -} - -// Size returns the current size of the stream being read, the boolean it -// returns is true iff the stream is done being written (otherwise Size may change). -// An error is returned if the Size fails to be computed or is not supported -// by the underlying filesystem. 
-func (r *CacheReader) Size() (int64, bool, error) { - switch v := r.ReadAtCloser.(type) { - case *stream.Reader: - size, done := v.Size() - return size, done, nil - - case interface{ Stat() (os.FileInfo, error) }: - fi, err := v.Stat() - if err != nil { - return 0, false, err - } - return fi.Size(), true, nil - - default: - return 0, false, fmt.Errorf("reader does not support stat") - } -} - -type handleCounter struct { - cnt int64 - grp sync.WaitGroup -} - -func (h *handleCounter) inc() { - h.grp.Add(1) - atomic.AddInt64(&h.cnt, 1) -} - -func (h *handleCounter) dec() { - atomic.AddInt64(&h.cnt, -1) - h.grp.Done() -} - -func (h *handleCounter) InUse() bool { - return atomic.LoadInt64(&h.cnt) > 0 -} - -func (h *handleCounter) waitUntilFree() { - h.grp.Wait() -} diff --git a/libsq/core/ioz/fscache/fscache_test.go b/libsq/core/ioz/fscache/fscache_test.go deleted file mode 100644 index de125299a..000000000 --- a/libsq/core/ioz/fscache/fscache_test.go +++ /dev/null @@ -1,579 +0,0 @@ -package fscache - -import ( - "bytes" - "crypto/md5" - "fmt" - "io" - "io/ioutil" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "testing" - "time" -) - -func createFile(name string) (*os.File, error) { - return os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) -} - -func init() { - c, _ := NewCache(NewMemFs(), nil) - go ListenAndServe(c, "localhost:10000") -} - -func testCaches(t *testing.T, run func(c Cache)) { - c, err := New("./cache", 0700, 1*time.Hour) - if err != nil { - t.Error(err.Error()) - return - } - run(c) - - c, err = NewCache(NewMemFs(), NewReaper(time.Hour, time.Hour)) - if err != nil { - t.Error(err.Error()) - return - } - run(c) - - c2, _ := NewCache(NewMemFs(), nil) - run(NewPartition(NewDistributor(c, c2))) - - lc := NewLayered(c, c2) - run(lc) - - rc := NewRemote("localhost:10000") - run(rc) - - fs, _ := NewFs("./cachex", 0700) - fs.EncodeKey = IdentityCodeKey - fs.DecodeKey = IdentityCodeKey - ck, _ := NewCache(fs, NewReaper(time.Hour, 
time.Hour)) - ck.SetKeyMapper(func(key string) string { - name, _ := B64OrMD5HashEncodeKey(key) - return name - }) - run(ck) -} - -func TestHandler(t *testing.T) { - testCaches(t, func(c Cache) { - defer c.Clean() - ts := httptest.NewServer(Handler(c, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, "Hello Client") - }))) - defer ts.Close() - - for i := 0; i < 3; i++ { - res, err := http.Get(ts.URL) - if err != nil { - t.Error(err.Error()) - t.FailNow() - } - p, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - if !bytes.Equal([]byte("Hello Client\n"), p) { - t.Errorf("unexpected response %s", string(p)) - } - } - }) -} - -func TestMemFs(t *testing.T) { - fs := NewMemFs() - fs.Reload(func(key, name string) {}) // nop - if _, err := fs.Open("test"); err == nil { - t.Errorf("stream shouldn't exist") - } - fs.Remove("test") - - f, err := fs.Create("test") - if err != nil { - t.Errorf("failed to create test") - } - f.Write([]byte("hello")) - f.Close() - - r, err := fs.Open("test") - if err != nil { - t.Errorf("failed Open: %v", err) - } - p, err := ioutil.ReadAll(r) - if err != nil { - t.Errorf("failed ioutil.ReadAll: %v", err) - } - r.Close() - if !bytes.Equal(p, []byte("hello")) { - t.Errorf("expected hello, got %s", string(p)) - } - fs.RemoveAll() -} - -func TestLoadCleanup1(t *testing.T) { - os.Mkdir("./cache6", 0700) - f, err := createFile(filepath.Join("./cache6", "s11111111"+tob64("test"))) - if err != nil { - t.Error(err.Error()) - } - f.Close() - <-time.After(time.Second) - f, err = createFile(filepath.Join("./cache6", "s22222222"+tob64("test"))) - if err != nil { - t.Error(err.Error()) - } - f.Close() - - c, err := New("./cache6", 0700, 0) - if err != nil { - t.Error(err.Error()) - return - } - defer c.Clean() - - if !c.Exists("test") { - t.Errorf("expected test to exist") - } -} - -const longString = ` - 0123456789 0123456789 - 0123456789 0123456789 - 0123456789 0123456789 - 0123456789 
0123456789 - 0123456789 0123456789 - 0123456789 0123456789 - 0123456789 0123456789 - 0123456789 0123456789 - 0123456789 0123456789 - 0123456789 0123456789 -` - -func TestLoadCleanup2(t *testing.T) { - hash := md5.Sum([]byte(longString)) - name2 := fmt.Sprintf("%s%s%x", longPrefix, "22222222", hash[:]) - name1 := fmt.Sprintf("%s%s%x", longPrefix, "11111111", hash[:]) - - os.Mkdir("./cache7", 0700) - f, err := createFile(filepath.Join("./cache7", name2)) - if err != nil { - t.Error(err.Error()) - } - f.Close() - f, err = createFile(filepath.Join("./cache7", fmt.Sprintf("%s.key", name2))) - if err != nil { - t.Error(err.Error()) - } - f.Write([]byte(longString)) - f.Close() - <-time.After(time.Second) - f, err = createFile(filepath.Join("./cache7", name1)) - if err != nil { - t.Error(err.Error()) - } - f.Close() - f, err = createFile(filepath.Join("./cache7", fmt.Sprintf("%s.key", name1))) - if err != nil { - t.Error(err.Error()) - } - f.Write([]byte(longString)) - f.Close() - - c, err := New("./cache7", 0700, 0) - if err != nil { - t.Error(err.Error()) - return - } - defer c.Clean() - - if !c.Exists(longString) { - t.Errorf("expected test to exist") - } -} - -func TestReload(t *testing.T) { - dir, err := ioutil.TempDir("", "cache5") - if err != nil { - t.Fatalf("Failed to create TempDir: %v", err) - } - c, err := New(dir, 0700, 0) - if err != nil { - t.Error(err.Error()) - return - } - r, w, err := c.Get("stream") - if err != nil { - t.Error(err.Error()) - return - } - r.Close() - data := []byte("hello world\n") - w.Write(data) - w.Close() - - nc, err := New(dir, 0700, 0) - if err != nil { - t.Error(err.Error()) - return - } - defer nc.Clean() - - if !nc.Exists("stream") { - t.Fatalf("expected stream to be reloaded") - } - - r, w, err = nc.Get("stream") - if err != nil { - t.Fatal(err) - } - if w != nil { - t.Fatal("expected reloaded stream to not be writable") - } - - cr, ok := r.(*CacheReader) - if !ok { - t.Fatalf("CacheReader should be supported by a normal FS") 
- } - size, closed, err := cr.Size() - if err != nil { - t.Fatalf("Failed to get Size: %v", err) - } - if !closed { - t.Errorf("Expected stream to be closed.") - } - if size != int64(len(data)) { - t.Errorf("Expected size to be %v, but got %v", len(data), size) - } - - r.Close() - nc.Remove("stream") - if nc.Exists("stream") { - t.Errorf("expected stream to be removed") - } -} - -func TestLRUHaunterMaxItems(t *testing.T) { - - fs, err := NewFs("./cache1", 0700) - if err != nil { - t.Error(err.Error()) - t.FailNow() - } - - c, err := NewCacheWithHaunter(fs, NewLRUHaunterStrategy(NewLRUHaunter(3, 0, 400*time.Millisecond))) - - if err != nil { - t.Error(err.Error()) - return - } - defer c.Clean() - - for i := 0; i < 5; i++ { - name := fmt.Sprintf("stream-%v", i) - r, w, _ := c.Get(name) - w.Write([]byte("hello")) - w.Close() - io.Copy(ioutil.Discard, r) - - if !c.Exists(name) { - t.Errorf(name + " should exist") - } - - <-time.After(10 * time.Millisecond) - - err := r.Close() - if err != nil { - t.Error(err) - } - } - - <-time.After(400 * time.Millisecond) - - if c.Exists("stream-0") { - t.Errorf("stream-0 should have been scrubbed") - } - - if c.Exists("stream-1") { - t.Errorf("stream-1 should have been scrubbed") - } - - files, err := ioutil.ReadDir("./cache1") - if err != nil { - t.Error(err.Error()) - return - } - - if len(files) != 3 { - t.Errorf("expected 3 items in directory") - } -} - -func TestLRUHaunterMaxSize(t *testing.T) { - - fs, err := NewFs("./cache1", 0700) - if err != nil { - t.Error(err.Error()) - t.FailNow() - } - - c, err := NewCacheWithHaunter(fs, NewLRUHaunterStrategy(NewLRUHaunter(0, 24, 400*time.Millisecond))) - - if err != nil { - t.Error(err.Error()) - return - } - defer c.Clean() - - for i := 0; i < 5; i++ { - name := fmt.Sprintf("stream-%v", i) - r, w, _ := c.Get(name) - w.Write([]byte("hello")) - w.Close() - io.Copy(ioutil.Discard, r) - - if !c.Exists(name) { - t.Errorf(name + " should exist") - } - - <-time.After(10 * time.Millisecond) - 
- err := r.Close() - if err != nil { - t.Error(err) - } - } - - <-time.After(400 * time.Millisecond) - - if c.Exists("stream-0") { - t.Errorf("stream-0 should have been scrubbed") - } - - files, err := ioutil.ReadDir("./cache1") - if err != nil { - t.Error(err.Error()) - return - } - - if len(files) != 4 { - t.Errorf("expected 4 items in directory") - } -} - -func TestReaper(t *testing.T) { - fs, err := NewFs("./cache1", 0700) - if err != nil { - t.Error(err.Error()) - t.FailNow() - } - - c, err := NewCache(fs, NewReaper(0*time.Second, 100*time.Millisecond)) - if err != nil { - t.Fatal(err) - } - defer c.Clean() - - r, w, err := c.Get("stream") - if err != nil { - t.Fatal(err) - } - w.Write([]byte("hello")) - w.Close() - io.Copy(ioutil.Discard, r) - - if !c.Exists("stream") { - t.Errorf("stream should exist") - } - - <-time.After(200 * time.Millisecond) - - if !c.Exists("stream") { - t.Errorf("a file expired while in use, fail!") - } - r.Close() - - <-time.After(200 * time.Millisecond) - - if c.Exists("stream") { - t.Errorf("stream should have been reaped") - } - - files, err := ioutil.ReadDir("./cache1") - if err != nil { - t.Error(err.Error()) - return - } - - if len(files) > 0 { - t.Errorf("expected empty directory") - } -} - -func TestReaperNoExpire(t *testing.T) { - testCaches(t, func(c Cache) { - defer c.Clean() - r, w, err := c.Get("stream") - if err != nil { - t.Error(err.Error()) - t.FailNow() - } - w.Write([]byte("hello")) - w.Close() - io.Copy(ioutil.Discard, r) - r.Close() - - if !c.Exists("stream") { - t.Errorf("stream should exist") - } - - if lc, ok := c.(*FSCache); ok { - lc.haunt() - if !c.Exists("stream") { - t.Errorf("stream shouldn't have been reaped") - } - } - }) -} - -func TestSanity(t *testing.T) { - atLeastOneCacheReader := false - testCaches(t, func(c Cache) { - defer c.Clean() - - r, w, err := c.Get(longString) - if err != nil { - t.Error(err.Error()) - return - } - defer r.Close() - - want := []byte("hello world\n") - first := want[:5] - 
w.Write(first) - - cr, ok := r.(*CacheReader) - if ok { - atLeastOneCacheReader = true - size, closed, _ := cr.Size() - if closed { - t.Errorf("Expected stream to be open.") - } - if size != int64(len(first)) { - t.Errorf("Expected size to be %v, but got %v", len(first), size) - } - } - - second := want[5:] - w.Write(second) - - if ok { - atLeastOneCacheReader = true - size, closed, _ := cr.Size() - if closed { - t.Errorf("Expected stream to be open.") - } - if size != int64(len(want)) { - t.Errorf("Expected size to be %v, but got %v", len(want), size) - } - } - - w.Close() - - if ok { - atLeastOneCacheReader = true - size, closed, _ := cr.Size() - if !closed { - t.Errorf("Expected stream to be closed.") - } - if size != int64(len(want)) { - t.Errorf("Expected size to be %v, but got %v", len(want), size) - } - } - - buf := bytes.NewBuffer(nil) - _, err = io.Copy(buf, r) - if err != nil { - t.Error(err.Error()) - return - } - if !bytes.Equal(buf.Bytes(), want) { - t.Errorf("unexpected output %s", buf.Bytes()) - } - }) - if !atLeastOneCacheReader { - t.Errorf("None of the cache tests covered CacheReader!") - } -} - -func TestConcurrent(t *testing.T) { - testCaches(t, func(c Cache) { - defer c.Clean() - - r, w, err := c.Get("stream") - r.Close() - if err != nil { - t.Error(err.Error()) - return - } - go func() { - w.Write([]byte("hello")) - <-time.After(100 * time.Millisecond) - w.Write([]byte("world")) - w.Close() - }() - - if c.Exists("stream") { - r, _, err := c.Get("stream") - if err != nil { - t.Error(err.Error()) - return - } - buf := bytes.NewBuffer(nil) - io.Copy(buf, r) - r.Close() - if !bytes.Equal(buf.Bytes(), []byte("helloworld")) { - t.Errorf("unexpected output %s", buf.Bytes()) - } - } - }) -} - -func TestReuse(t *testing.T) { - testCaches(t, func(c Cache) { - for i := 0; i < 10; i++ { - r, w, err := c.Get(longString) - if err != nil { - t.Error(err.Error()) - return - } - - data := fmt.Sprintf("hello %d", i) - - if w != nil { - w.Write([]byte(data)) - 
w.Close() - } - - check(t, r, data) - r.Close() - - c.Clean() - } - }) -} - -func check(t *testing.T, r io.Reader, data string) { - buf := bytes.NewBuffer(nil) - _, err := io.Copy(buf, r) - if err != nil { - t.Error(err.Error()) - return - } - if !bytes.Equal(buf.Bytes(), []byte(data)) { - t.Errorf("unexpected output %q, want %q", buf.String(), data) - } -} diff --git a/libsq/core/ioz/fscache/handler.go b/libsq/core/ioz/fscache/handler.go deleted file mode 100644 index 8df85400c..000000000 --- a/libsq/core/ioz/fscache/handler.go +++ /dev/null @@ -1,41 +0,0 @@ -package fscache - -import ( - "io" - "net/http" -) - -// Handler is a caching middle-ware for http Handlers. -// It responds to http requests via the passed http.Handler, and caches the response -// using the passed cache. The cache key for the request is the req.URL.String(). -// Note: It does not cache http headers. It is more efficient to set them yourself. -func Handler(c Cache, h http.Handler) http.Handler { - return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - url := req.URL.String() - r, w, err := c.Get(url) - if err != nil { - h.ServeHTTP(rw, req) - return - } - defer r.Close() - if w != nil { - go func() { - defer w.Close() - h.ServeHTTP(&respWrapper{ - ResponseWriter: rw, - Writer: w, - }, req) - }() - } - io.Copy(rw, r) - }) -} - -type respWrapper struct { - http.ResponseWriter - io.Writer -} - -func (r *respWrapper) Write(p []byte) (int, error) { - return r.Writer.Write(p) -} diff --git a/libsq/core/ioz/fscache/haunter.go b/libsq/core/ioz/fscache/haunter.go deleted file mode 100644 index a8d038ce9..000000000 --- a/libsq/core/ioz/fscache/haunter.go +++ /dev/null @@ -1,92 +0,0 @@ -package fscache - -import ( - "time" -) - -// Entry represents a cached item. -type Entry struct { - name string - inUse bool -} - -// InUse returns if this Cache entry is in use. -func (e *Entry) InUse() bool { - return e.inUse -} - -// Name returns the File.Name() of this entry. 
-func (e *Entry) Name() string { - return e.name -} - -// CacheAccessor implementors provide ways to observe and interact with -// the cached entries, mainly used for cache-eviction. -type CacheAccessor interface { - FileSystemStater - EnumerateEntries(enumerator func(key string, e Entry) bool) - RemoveFile(key string) -} - -// Haunter implementors are used to perform cache-eviction (Next is how long to wait -// until next evication, Haunt preforms the eviction). -type Haunter interface { - Haunt(c CacheAccessor) - Next() time.Duration -} - -type reaperHaunterStrategy struct { - reaper Reaper -} - -type lruHaunterStrategy struct { - haunter LRUHaunter -} - -// NewLRUHaunterStrategy returns a simple scheduleHaunt which provides an implementation LRUHaunter strategy -func NewLRUHaunterStrategy(haunter LRUHaunter) Haunter { - return &lruHaunterStrategy{ - haunter: haunter, - } -} - -func (h *lruHaunterStrategy) Haunt(c CacheAccessor) { - for _, key := range h.haunter.Scrub(c) { - c.RemoveFile(key) - } - -} - -func (h *lruHaunterStrategy) Next() time.Duration { - return h.haunter.Next() -} - -// NewReaperHaunterStrategy returns a simple scheduleHaunt which provides an implementation Reaper strategy -func NewReaperHaunterStrategy(reaper Reaper) Haunter { - return &reaperHaunterStrategy{ - reaper: reaper, - } -} - -func (h *reaperHaunterStrategy) Haunt(c CacheAccessor) { - c.EnumerateEntries(func(key string, e Entry) bool { - if e.InUse() { - return true - } - - fileInfo, err := c.Stat(e.Name()) - if err != nil { - return true - } - - if h.reaper.Reap(key, fileInfo.AccessTime(), fileInfo.ModTime()) { - c.RemoveFile(key) - } - - return true - }) -} - -func (h *reaperHaunterStrategy) Next() time.Duration { - return h.reaper.Next() -} diff --git a/libsq/core/ioz/fscache/layers.go b/libsq/core/ioz/fscache/layers.go deleted file mode 100644 index b0b283106..000000000 --- a/libsq/core/ioz/fscache/layers.go +++ /dev/null @@ -1,128 +0,0 @@ -package fscache - -import ( - "errors" 
- "io" - "sync" -) - -type layeredCache struct { - layers []Cache -} - -// NewLayered returns a Cache which stores its data in all the passed -// caches, when a key is requested it is loaded into all the caches above the first hit. -func NewLayered(caches ...Cache) Cache { - return &layeredCache{layers: caches} -} - -func (l *layeredCache) Get(key string) (r ReadAtCloser, w io.WriteCloser, err error) { - var last ReadAtCloser - var writers []io.WriteCloser - - for i, layer := range l.layers { - r, w, err = layer.Get(key) - if err != nil { - if len(writers) > 0 { - last.Close() - multiWC(writers...).Close() - } - return nil, nil, err - } - - // hit - if w == nil { - if len(writers) > 0 { - go func(r io.ReadCloser) { - wc := multiWC(writers...) - defer r.Close() - defer wc.Close() - io.Copy(wc, r) - }(r) - return last, nil, nil - } - return r, nil, nil - } - - // miss - writers = append(writers, w) - - if i == len(l.layers)-1 { - if last != nil { - last.Close() - } - return r, multiWC(writers...), nil - } - - if last != nil { - last.Close() - } - last = r - } - - return nil, nil, errors.New("no caches") -} - -func (l *layeredCache) Remove(key string) error { - var grp sync.WaitGroup - // walk upwards so that lower layers don't - // restore upper layers on Get() - for i := len(l.layers) - 1; i >= 0; i-- { - grp.Add(1) - go func(layer Cache) { - defer grp.Done() - layer.Remove(key) - }(l.layers[i]) - } - grp.Wait() - return nil -} - -func (l *layeredCache) Exists(key string) bool { - for _, layer := range l.layers { - if layer.Exists(key) { - return true - } - } - return false -} - -func (l *layeredCache) Clean() (error) { - for _, layer := range l.layers { - if err := layer.Clean(); err != nil { - return err - } - } - return nil -} - -func multiWC(wc ...io.WriteCloser) io.WriteCloser { - if len(wc) == 0 { - return nil - } - - return &multiWriteCloser{ - writers: wc, - } -} - -type multiWriteCloser struct { - writers []io.WriteCloser -} - -func (t *multiWriteCloser) 
Write(p []byte) (n int, err error) { - for _, w := range t.writers { - n, err = w.Write(p) - if err != nil { - return - } - } - return len(p), nil -} - -func (t *multiWriteCloser) Close() error { - for _, w := range t.writers { - w.Close() - } - return nil -} diff --git a/libsq/core/ioz/fscache/lruhaunter.go b/libsq/core/ioz/fscache/lruhaunter.go deleted file mode 100644 index 7b90ef3a7..000000000 --- a/libsq/core/ioz/fscache/lruhaunter.go +++ /dev/null @@ -1,137 +0,0 @@ -package fscache - -import ( - "sort" - "time" -) - -type lruHaunterKV struct { - Key string - Value Entry -} - -// LRUHaunter is used to control when there are too many streams -// or the size of the streams is too big. -// It is called once right after loading, and then it is run -// again after every Next() period of time. -type LRUHaunter interface { - // Returns the amount of time to wait before the next scheduled Reaping. - Next() time.Duration - - // Given a CacheAccessor, return keys to reap list. - Scrub(c CacheAccessor) []string -} - -// NewLRUHaunter returns a simple haunter which runs every "period" -// and scrubs older files when the total file size is over maxSize or -// total item count is over maxItems. 
-// If maxItems or maxSize are 0, they won't be checked -func NewLRUHaunter(maxItems int, maxSize int64, period time.Duration) LRUHaunter { - return &lruHaunter{ - period: period, - maxItems: maxItems, - maxSize: maxSize, - } -} - -type lruHaunter struct { - period time.Duration - maxItems int - maxSize int64 -} - -func (j *lruHaunter) Next() time.Duration { - return j.period -} - -func (j *lruHaunter) Scrub(c CacheAccessor) (keysToReap []string) { - var count int - var size int64 - var okFiles []lruHaunterKV - - c.EnumerateEntries(func(key string, e Entry) bool { - if e.InUse() { - return true - } - - fileInfo, err := c.Stat(e.Name()) - if err != nil { - return true - } - - count++ - size = size + fileInfo.Size() - okFiles = append(okFiles, lruHaunterKV{ - Key: key, - Value: e, - }) - - return true - }) - - sort.Slice(okFiles, func(i, j int) bool { - iFileInfo, err := c.Stat(okFiles[i].Value.Name()) - if err != nil { - return false - } - - iLastRead := iFileInfo.AccessTime() - - jFileInfo, err := c.Stat(okFiles[j].Value.Name()) - if err != nil { - return false - } - - jLastRead := jFileInfo.AccessTime() - - return iLastRead.Before(jLastRead) - }) - - collectKeysToReapFn := func() bool { - var key *string - var err error - key, count, size, err = j.removeFirst(c, &okFiles, count, size) - if err != nil { - return false - } - if key != nil { - keysToReap = append(keysToReap, *key) - } - - return true - } - - if j.maxItems > 0 { - for count > j.maxItems { - if !collectKeysToReapFn() { - break - } - } - } - - if j.maxSize > 0 { - for size > j.maxSize { - if !collectKeysToReapFn() { - break - } - } - } - - return keysToReap -} - -func (j *lruHaunter) removeFirst(fsStater FileSystemStater, items *[]lruHaunterKV, count int, size int64) (*string, int, int64, error) { - var f lruHaunterKV - - f, *items = (*items)[0], (*items)[1:] - - fileInfo, err := fsStater.Stat(f.Value.Name()) - if err != nil { - return nil, count, size, err - } - - count-- - size = size - 
fileInfo.Size() - - return &f.Key, count, size, nil -} diff --git a/libsq/core/ioz/fscache/memfs.go b/libsq/core/ioz/fscache/memfs.go deleted file mode 100644 index ddfb92cf3..000000000 --- a/libsq/core/ioz/fscache/memfs.go +++ /dev/null @@ -1,147 +0,0 @@ -package fscache - -import ( - "bytes" - "errors" - "io" - "os" - "sync" - "time" - - "github.com/djherbis/stream" -) - -type memFS struct { - mu sync.RWMutex - files map[string]*memFile -} - -// NewMemFs creates an in-memory FileSystem. -// It does not support persistence (Reload is a nop). -func NewMemFs() FileSystem { - return &memFS{ - files: make(map[string]*memFile), - } -} - -func (fs *memFS) Stat(name string) (FileInfo, error) { - fs.mu.RLock() - defer fs.mu.RUnlock() - f, ok := fs.files[name] - if !ok { - return FileInfo{}, errors.New("file has not been read") - } - - size := int64(len(f.Bytes())) - - return FileInfo{ - FileInfo: &fileInfo{ - name: name, - size: size, - fileMode: os.ModeIrregular, - isDir: false, - sys: nil, - wt: f.wt, - }, - Atime: f.rt, - }, nil -} - -func (fs *memFS) Reload(add func(key, name string)) error { - return nil -} - -func (fs *memFS) Create(key string) (stream.File, error) { - fs.mu.Lock() - defer fs.mu.Unlock() - if _, ok := fs.files[key]; ok { - return nil, errors.New("file exists") - } - file := &memFile{ - name: key, - r: bytes.NewBuffer(nil), - wt: time.Now(), - } - file.memReader.memFile = file - fs.files[key] = file - return file, nil -} - -func (fs *memFS) Open(name string) (stream.File, error) { - fs.mu.Lock() - defer fs.mu.Unlock() - if f, ok := fs.files[name]; ok { - f.rt = time.Now() - return &memReader{memFile: f}, nil - } - return nil, errors.New("file does not exist") -} - -func (fs *memFS) Remove(key string) error { - fs.mu.Lock() - defer fs.mu.Unlock() - delete(fs.files, key) - return nil -} - -func (fs *memFS) RemoveAll() error { - fs.mu.Lock() - defer fs.mu.Unlock() - fs.files = make(map[string]*memFile) - return nil -} - -type memFile struct { - mu 
sync.RWMutex - name string - r *bytes.Buffer - memReader - rt, wt time.Time -} - -func (f *memFile) Name() string { - return f.name -} - -func (f *memFile) Write(p []byte) (int, error) { - if len(p) > 0 { - f.mu.Lock() - defer f.mu.Unlock() - return f.r.Write(p) - } - return len(p), nil -} - -func (f *memFile) Bytes() []byte { - f.mu.RLock() - defer f.mu.RUnlock() - return f.r.Bytes() -} - -func (f *memFile) Close() error { - return nil -} - -type memReader struct { - *memFile - n int -} - -func (r *memReader) ReadAt(p []byte, off int64) (n int, err error) { - data := r.Bytes() - if int64(len(data)) < off { - return 0, io.EOF - } - n, err = bytes.NewReader(data[off:]).ReadAt(p, 0) - return n, err -} - -func (r *memReader) Read(p []byte) (n int, err error) { - n, err = bytes.NewReader(r.Bytes()[r.n:]).Read(p) - r.n += n - return n, err -} - -func (r *memReader) Close() error { - return nil -} diff --git a/libsq/core/ioz/fscache/reaper.go b/libsq/core/ioz/fscache/reaper.go deleted file mode 100644 index d801202a7..000000000 --- a/libsq/core/ioz/fscache/reaper.go +++ /dev/null @@ -1,37 +0,0 @@ -package fscache - -import "time" - -// Reaper is used to control when streams expire from the cache. -// It is called once right after loading, and then it is run -// again after every Next() period of time. -type Reaper interface { - // Returns the amount of time to wait before the next scheduled Reaping. - Next() time.Duration - - // Given a key and the last r/w times of a file, return true - // to remove the file from the cache, false to keep it. - Reap(key string, lastRead, lastWrite time.Time) bool -} - -// NewReaper returns a simple reaper which runs every "Period" -// and reaps files which are older than "expiry". 
-func NewReaper(expiry, period time.Duration) Reaper { - return &reaper{ - expiry: expiry, - period: period, - } -} - -type reaper struct { - period time.Duration - expiry time.Duration -} - -func (g *reaper) Next() time.Duration { - return g.period -} - -func (g *reaper) Reap(key string, lastRead, lastWrite time.Time) bool { - return lastRead.Before(time.Now().Add(-g.expiry)) -} diff --git a/libsq/core/ioz/fscache/server.go b/libsq/core/ioz/fscache/server.go deleted file mode 100644 index dba74aad3..000000000 --- a/libsq/core/ioz/fscache/server.go +++ /dev/null @@ -1,206 +0,0 @@ -package fscache - -import ( - "bytes" - "errors" - "fmt" - "io" - "net" -) - -// ListenAndServe hosts a Cache for access via NewRemote -func ListenAndServe(c Cache, addr string) error { - return (&server{c: c}).ListenAndServe(addr) -} - -// NewRemote returns a Cache run via ListenAndServe -func NewRemote(raddr string) Cache { - return &remote{raddr: raddr} -} - -type server struct { - c Cache -} - -func (s *server) ListenAndServe(addr string) error { - l, err := net.Listen("tcp", addr) - if err != nil { - return err - } - - for { - c, err := l.Accept() - if err != nil { - return err - } - - go s.Serve(c) - } -} - -const ( - actionGet = iota - actionRemove = iota - actionExists = iota - actionClean = iota -) - -func getKey(r io.Reader) string { - dec := newDecoder(r) - buf := bytes.NewBufferString("") - io.Copy(buf, dec) - return buf.String() -} - -func sendKey(w io.Writer, key string) { - enc := newEncoder(w) - enc.Write([]byte(key)) - enc.Close() -} - -func (s *server) Serve(c net.Conn) { - var action int - fmt.Fscanf(c, "%d\n", &action) - - switch action { - case actionGet: - s.get(c, getKey(c)) - case actionRemove: - s.c.Remove(getKey(c)) - case actionExists: - s.exists(c, getKey(c)) - case actionClean: - s.c.Clean() - } -} - -func (s *server) exists(c net.Conn, key string) { - if s.c.Exists(key) { - fmt.Fprintf(c, "%d\n", 1) - } else { - fmt.Fprintf(c, "%d\n", 0) - } -} - -func (s 
*server) get(c net.Conn, key string) { - r, w, err := s.c.Get(key) - if err != nil { - return // handle this better - } - defer r.Close() - - if w != nil { - go func() { - fmt.Fprintf(c, "%d\n", 1) - io.Copy(w, newDecoder(c)) - w.Close() - }() - } else { - fmt.Fprintf(c, "%d\n", 0) - } - - enc := newEncoder(c) - io.Copy(enc, r) - enc.Close() -} - -type remote struct { - raddr string -} - -func (rmt *remote) Get(key string) (r ReadAtCloser, w io.WriteCloser, err error) { - c, err := net.Dial("tcp", rmt.raddr) - if err != nil { - return nil, nil, err - } - fmt.Fprintf(c, "%d\n", actionGet) - sendKey(c, key) - - var i int - fmt.Fscanf(c, "%d\n", &i) - - var ch chan struct{} - - switch i { - case 0: - ch = make(chan struct{}) // close net.Conn on reader close - case 1: - ch = make(chan struct{}, 1) // two closes before net.Conn close - - w = &safeCloser{ - c: c, - ch: ch, - w: newEncoder(c), - } - default: - return nil, nil, errors.New("bad bad bad") - } - - r = &safeCloser{ - c: c, - ch: ch, - r: newDecoder(c), - } - - return r, w, nil -} - -type safeCloser struct { - c net.Conn - ch chan<- struct{} - r ReadAtCloser - w io.WriteCloser -} - -func (s *safeCloser) ReadAt(p []byte, off int64) (int, error) { - return s.r.ReadAt(p, off) -} -func (s *safeCloser) Read(p []byte) (int, error) { return s.r.Read(p) } -func (s *safeCloser) Write(p []byte) (int, error) { return s.w.Write(p) } - -// Close only closes the underlying connection when ch is full. 
-func (s *safeCloser) Close() (err error) { - if s.r != nil { - err = s.r.Close() - } else if s.w != nil { - err = s.w.Close() - } - - select { - case s.ch <- struct{}{}: - return err - default: - return s.c.Close() - } -} - -func (rmt *remote) Exists(key string) bool { - c, err := net.Dial("tcp", rmt.raddr) - if err != nil { - return false - } - fmt.Fprintf(c, "%d\n", actionExists) - sendKey(c, key) - var i int - fmt.Fscanf(c, "%d\n", &i) - return i == 1 -} - -func (rmt *remote) Remove(key string) error { - c, err := net.Dial("tcp", rmt.raddr) - if err != nil { - return err - } - fmt.Fprintf(c, "%d\n", actionRemove) - sendKey(c, key) - return nil -} - -func (rmt *remote) Clean() error { - c, err := net.Dial("tcp", rmt.raddr) - if err != nil { - return err - } - fmt.Fprintf(c, "%d\n", actionClean) - return nil -} diff --git a/libsq/core/ioz/fscache/stream.go b/libsq/core/ioz/fscache/stream.go deleted file mode 100644 index 9cccb2483..000000000 --- a/libsq/core/ioz/fscache/stream.go +++ /dev/null @@ -1,72 +0,0 @@ -package fscache - -import ( - "encoding/json" - "errors" - "io" -) - -type decoder interface { - Decode(interface{}) error -} - -type encoder interface { - Encode(interface{}) error -} - -type pktReader struct { - dec decoder -} - -type pktWriter struct { - enc encoder -} - -type packet struct { - Err int - Data []byte -} - -const eof = 1 - -func (t *pktReader) ReadAt(p []byte, off int64) (n int, err error) { - // TODO not implemented - return 0, errors.New("not implemented") -} - -func (t *pktReader) Read(p []byte) (int, error) { - var pkt packet - err := t.dec.Decode(&pkt) - if err != nil { - return 0, err - } - if pkt.Err == eof { - return 0, io.EOF - } - return copy(p, pkt.Data), nil -} - -func (t *pktReader) Close() error { - return nil -} - -func (t *pktWriter) Write(p []byte) (int, error) { - pkt := packet{Data: p} - err := t.enc.Encode(pkt) - if err != nil { - return 0, err - } - return len(p), nil -} - -func (t *pktWriter) Close() error { - return 
t.enc.Encode(packet{Err: eof}) -} - -func newEncoder(w io.Writer) io.WriteCloser { - return &pktWriter{enc: json.NewEncoder(w)} -} - -func newDecoder(r io.Reader) ReadAtCloser { - return &pktReader{dec: json.NewDecoder(r)} -} From 30a55984f0eb8019947406fa7903032944ea981a Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 26 Nov 2023 12:40:54 -0700 Subject: [PATCH 011/195] wip: testing fscache --- .gitignore | 1 + go.mod | 8 +++----- go.sum | 8 ++------ libsq/source/files.go | 2 +- 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 64206612d..d6dca93d5 100644 --- a/.gitignore +++ b/.gitignore @@ -38,6 +38,7 @@ _testmain.go /scratch .envrc **/*.bench +go.work* # Some apps create temp files when editing, e.g. Excel with drivers/xlsx/testdata/~$test_header.xlsx **/testdata/~* diff --git a/go.mod b/go.mod index ebda24ba9..02a062037 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ require ( github.com/alessio/shellescape v1.4.2 github.com/antlr4-go/antlr/v4 v4.13.0 github.com/c2h5oh/datasize v0.0.0-20220606134207-859f65c6625b - github.com/djherbis/fscache v0.10.1 github.com/ecnepsnai/osquery v1.0.1 github.com/emirpasic/gods v1.18.1 github.com/fatih/color v1.16.0 @@ -25,8 +24,10 @@ require ( github.com/muesli/mango-cobra v1.2.0 github.com/muesli/roff v0.1.0 github.com/ncruces/go-strftime v0.1.9 + github.com/neilotoole/fscache v0.0.0-20231126193352-f79ad8b71381 github.com/neilotoole/shelleditor v0.4.1 github.com/neilotoole/slogt v1.1.0 + github.com/nightlyone/lockfile v1.0.0 github.com/otiai10/copy v1.14.0 github.com/ryboe/q v1.0.20 github.com/samber/lo v1.38.1 @@ -45,6 +46,7 @@ require ( golang.org/x/net v0.18.0 golang.org/x/sync v0.5.0 golang.org/x/term v0.14.0 + golang.org/x/text v0.14.0 ) require ( @@ -70,7 +72,6 @@ require ( github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect github.com/muesli/mango v0.2.0 // indirect 
github.com/muesli/mango-pflag v0.1.0 // indirect - github.com/nightlyone/lockfile v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/richardlehane/mscfb v1.0.4 // indirect github.com/richardlehane/msoleps v1.0.3 // indirect @@ -82,9 +83,6 @@ require ( github.com/xuri/nfp v0.0.0-20230919160717-d98342af3f05 // indirect golang.org/x/crypto v0.15.0 // indirect golang.org/x/sys v0.14.0 // indirect - golang.org/x/text v0.14.0 // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect - gopkg.in/djherbis/atime.v1 v1.0.0 // indirect - gopkg.in/djherbis/stream.v1 v1.3.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 753a9ee1a..1b804d1e2 100644 --- a/go.sum +++ b/go.sum @@ -34,8 +34,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/djherbis/atime v1.1.0 h1:rgwVbP/5by8BvvjBNrbh64Qz33idKT3pSnMSJsxhi0g= github.com/djherbis/atime v1.1.0/go.mod h1:28OF6Y8s3NQWwacXc5eZTsEsiMzp7LF8MbXE+XJPdBE= -github.com/djherbis/fscache v0.10.1 h1:hDv+RGyvD+UDKyRYuLoVNbuRTnf2SrA2K3VyR1br9lk= -github.com/djherbis/fscache v0.10.1/go.mod h1:yyPYtkNnnPXsW+81lAcQS6yab3G2CRfnPLotBvtbf0c= github.com/djherbis/stream v1.4.0 h1:aVD46WZUiq5kJk55yxJAyw6Kuera6kmC3i2vEQyW/AE= github.com/djherbis/stream v1.4.0/go.mod h1:cqjC1ZRq3FFwkGmUtHwcldbnW8f0Q4YuVsGW1eAFtOk= github.com/ecnepsnai/osquery v1.0.1 h1:i96n/3uqcafKZtRYmXVNqekKbfrIm66q179mWZ/Y2Aw= @@ -126,6 +124,8 @@ github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8= github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/neilotoole/fscache 
v0.0.0-20231126193352-f79ad8b71381 h1:yq4OXuvSTMzvCm2m9FlpUnL8PbVQGW4qv+s8uRRtxK8= +github.com/neilotoole/fscache v0.0.0-20231126193352-f79ad8b71381/go.mod h1:GdelWtWN0Gbf2uE+rEZ4GinWtfV6PobHgdrQ4IrB504= github.com/neilotoole/shelleditor v0.4.1 h1:74LEw2mVo3jtNw2BjII6RSss9DXgEqAbmCQDDiJvzO0= github.com/neilotoole/shelleditor v0.4.1/go.mod h1:QanOZN4syDMp/L0SKwZb47Mh49mvLWX3ja5YfbYDDjo= github.com/neilotoole/slogt v1.1.0 h1:c7qE92sq+V0yvCuaxph+RQ2jOKL61c4hqS1Bv9W7FZE= @@ -268,10 +268,6 @@ golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6f gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/djherbis/atime.v1 v1.0.0 h1:eMRqB/JrLKocla2PBPKgQYg/p5UG4L6AUAs92aP7F60= -gopkg.in/djherbis/atime.v1 v1.0.0/go.mod h1:hQIUStKmJfvf7xdh/wtK84qe+DsTV5LnA9lzxxtPpJ8= -gopkg.in/djherbis/stream.v1 v1.3.1 h1:uGfmsOY1qqMjQQphhRBSGLyA9qumJ56exkRu9ASTjCw= -gopkg.in/djherbis/stream.v1 v1.3.1/go.mod h1:aEV8CBVRmSpLamVJfM903Npic1IKmb2qS30VAZ+sssg= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/libsq/source/files.go b/libsq/source/files.go index 5557656ce..238e5b5c7 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -2,6 +2,7 @@ package source import ( "context" + "github.com/neilotoole/fscache" "io" "log/slog" "mime" @@ -11,7 +12,6 @@ import ( "sync" "time" - "github.com/djherbis/fscache" "github.com/h2non/filetype" "github.com/h2non/filetype/matchers" "golang.org/x/sync/errgroup" From 57cf440f4335f835ad9f8de5a090bc5ad92c7dae 
Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 26 Nov 2023 12:44:34 -0700 Subject: [PATCH 012/195] formatting --- drivers/userdriver/xmlud/xmlimport_test.go | 5 ++--- libsq/driver/sources.go | 3 ++- libsq/source/files.go | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/drivers/userdriver/xmlud/xmlimport_test.go b/drivers/userdriver/xmlud/xmlimport_test.go index e2bc680ca..607f7d966 100644 --- a/drivers/userdriver/xmlud/xmlimport_test.go +++ b/drivers/userdriver/xmlud/xmlimport_test.go @@ -4,9 +4,6 @@ import ( "bytes" "testing" - "github.com/neilotoole/sq/libsq/source" - "github.com/neilotoole/sq/libsq/source/drivertype" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -14,6 +11,8 @@ import ( "github.com/neilotoole/sq/drivers/userdriver/xmlud" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/stringz" + "github.com/neilotoole/sq/libsq/source" + "github.com/neilotoole/sq/libsq/source/drivertype" "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/testsrc" diff --git a/libsq/driver/sources.go b/libsq/driver/sources.go index b90bbcd0e..2cb6144e7 100644 --- a/libsq/driver/sources.go +++ b/libsq/driver/sources.go @@ -10,6 +10,8 @@ import ( "sync" "time" + "github.com/nightlyone/lockfile" + "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" @@ -19,7 +21,6 @@ import ( "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/core/retry" "github.com/neilotoole/sq/libsq/source" - "github.com/nightlyone/lockfile" ) var ( diff --git a/libsq/source/files.go b/libsq/source/files.go index 238e5b5c7..ef75bc247 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -2,7 +2,6 @@ package source import ( "context" - "github.com/neilotoole/fscache" "io" "log/slog" "mime" @@ -16,6 
+15,8 @@ import ( "github.com/h2non/filetype/matchers" "golang.org/x/sync/errgroup" + "github.com/neilotoole/fscache" + "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" From 901b81f71a47bd69aba8c792361d62d84e497c16 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 26 Nov 2023 14:32:07 -0700 Subject: [PATCH 013/195] wip: broken --- cli/source.go | 2 +- drivers/csv/csv.go | 6 ++-- drivers/csv/csv_test.go | 6 ++-- drivers/csv/detect_type.go | 2 +- drivers/csv/ingest.go | 2 +- drivers/json/import_json.go | 8 ++--- drivers/json/import_jsona.go | 6 ++-- drivers/json/import_jsonl.go | 4 +-- drivers/json/import_test.go | 7 ++-- drivers/json/json.go | 8 ++--- drivers/json/json_test.go | 2 +- drivers/userdriver/userdriver.go | 6 ++-- drivers/xlsx/database.go | 4 +-- drivers/xlsx/detect.go | 2 +- drivers/xlsx/xlsx.go | 2 +- libsq/source/files.go | 59 +++++++++++++++++--------------- libsq/source/files_test.go | 14 ++++---- libsq/source/internal_test.go | 4 +-- testh/testh_test.go | 2 +- 19 files changed, 76 insertions(+), 70 deletions(-) diff --git a/cli/source.go b/cli/source.go index dde97bab7..fcb177d53 100644 --- a/cli/source.go +++ b/cli/source.go @@ -181,7 +181,7 @@ func checkStdinSource(ctx context.Context, ru *run.Run) (*source.Source, error) } } - err = ru.Files.AddStdin(f) + err = ru.Files.AddStdin(ctx, f) if err != nil { return nil, err } diff --git a/drivers/csv/csv.go b/drivers/csv/csv.go index e0eac4dac..ae5c60c51 100644 --- a/drivers/csv/csv.go +++ b/drivers/csv/csv.go @@ -106,11 +106,11 @@ func (d *driveri) ValidateSource(src *source.Source) (*source.Source, error) { } // Ping implements driver.Driver. -func (d *driveri) Ping(_ context.Context, src *source.Source) error { +func (d *driveri) Ping(ctx context.Context, src *source.Source) error { // FIXME: Does Ping calling d.files.Open cause a full read? 
// We probably just want to check that the file exists // or is accessible. - r, err := d.files.Open(src) + r, err := d.files.Open(ctx, src) if err != nil { return err } @@ -174,7 +174,7 @@ func (p *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou return nil, err } - md.Size, err = p.files.Size(p.src) + md.Size, err = p.files.Size(ctx, p.src) if err != nil { return nil, err } diff --git a/drivers/csv/csv_test.go b/drivers/csv/csv_test.go index 41fdff298..cdd34ed86 100644 --- a/drivers/csv/csv_test.go +++ b/drivers/csv/csv_test.go @@ -3,6 +3,7 @@ package csv_test import ( "context" stdcsv "encoding/csv" + "fmt" "math/rand" "os" "path/filepath" @@ -341,7 +342,7 @@ func TestDatetime(t *testing.T) { // TestIngestLargeCSV generates a large CSV file. // At count = 5000000, the generated file is ~500MB. func TestGenerateLargeCSV(t *testing.T) { - t.Skip() + //t.Skip() const count = 5000000 // Generates ~500MB file start := time.Now() header := []string{ @@ -383,7 +384,8 @@ func TestGenerateLargeCSV(t *testing.T) { rec[3] = strconv.Itoa(rand.Intn(10)) // staff_id rec[4] = strconv.Itoa(i + 3) // rental_id, always unique f64 := amount.InexactFloat64() - rec[5] = p.Sprintf("%.2f", f64) // amount + //rec[5] = p.Sprintf("%.2f", f64) // amount + rec[5] = fmt.Sprintf("%.2f", f64) // amount amount = amount.Add(decimal.New(33, -2)) rec[6] = timez.TimestampUTC(paymentUTC) // payment_date paymentUTC = paymentUTC.Add(time.Minute) diff --git a/drivers/csv/detect_type.go b/drivers/csv/detect_type.go index 1acbf730d..7fd7dd29d 100644 --- a/drivers/csv/detect_type.go +++ b/drivers/csv/detect_type.go @@ -38,7 +38,7 @@ func detectType(ctx context.Context, typ drivertype.Type, ) (detected drivertype.Type, score float32, err error) { log := lg.FromContext(ctx) var r io.ReadCloser - r, err = openFn() + r, err = openFn(ctx) if err != nil { return drivertype.None, 0, errz.Err(err) } diff --git a/drivers/csv/ingest.go b/drivers/csv/ingest.go index fc2883e0f..6e8cf99f9 
100644 --- a/drivers/csv/ingest.go +++ b/drivers/csv/ingest.go @@ -57,7 +57,7 @@ func ingestCSV(ctx context.Context, src *source.Source, openFn source.FileOpenFu var err error var r io.ReadCloser - r, err = openFn() + r, err = openFn(ctx) if err != nil { return err } diff --git a/drivers/json/import_json.go b/drivers/json/import_json.go index 1a1422a94..beb9e5a45 100644 --- a/drivers/json/import_json.go +++ b/drivers/json/import_json.go @@ -23,7 +23,7 @@ func DetectJSON(sampleSize int) source.DriverDetectFunc { ) { log := lg.FromContext(ctx) var r1, r2 io.ReadCloser - r1, err = openFn() + r1, err = openFn(ctx) if err != nil { return drivertype.None, 0, errz.Err(err) } @@ -46,7 +46,7 @@ func DetectJSON(sampleSize int) source.DriverDetectFunc { return drivertype.None, 0, nil case leftBrace: // The input is a single JSON object - r2, err = openFn() + r2, err = openFn(ctx) // buf gets a copy of what is read from r2 buf := &buffer{} @@ -92,7 +92,7 @@ func DetectJSON(sampleSize int) source.DriverDetectFunc { // The input is one or more JSON objects inside an array } - r2, err = openFn() + r2, err = openFn(ctx) if err != nil { return drivertype.None, 0, errz.Err(err) } @@ -135,7 +135,7 @@ func DetectJSON(sampleSize int) source.DriverDetectFunc { func importJSON(ctx context.Context, job importJob) error { log := lg.FromContext(ctx) - r, err := job.openFn() + r, err := job.openFn(ctx) if err != nil { return err } diff --git a/drivers/json/import_jsona.go b/drivers/json/import_jsona.go index 529b7b526..91feff792 100644 --- a/drivers/json/import_jsona.go +++ b/drivers/json/import_jsona.go @@ -29,7 +29,7 @@ func DetectJSONA(sampleSize int) source.DriverDetectFunc { ) { log := lg.FromContext(ctx) var r io.ReadCloser - r, err = openFn() + r, err = openFn(ctx) if err != nil { return drivertype.None, 0, errz.Err(err) } @@ -98,7 +98,7 @@ func DetectJSONA(sampleSize int) source.DriverDetectFunc { func importJSONA(ctx context.Context, job importJob) error { log := lg.FromContext(ctx) 
- predictR, err := job.openFn() + predictR, err := job.openFn(ctx) if err != nil { return errz.Err(err) } @@ -136,7 +136,7 @@ func importJSONA(ctx context.Context, job importJob) error { return err } - r, err := job.openFn() + r, err := job.openFn(ctx) if err != nil { return errz.Err(err) } diff --git a/drivers/json/import_jsonl.go b/drivers/json/import_jsonl.go index 803c79d1a..b029de1d7 100644 --- a/drivers/json/import_jsonl.go +++ b/drivers/json/import_jsonl.go @@ -23,7 +23,7 @@ func DetectJSONL(sampleSize int) source.DriverDetectFunc { ) { log := lg.FromContext(ctx) var r io.ReadCloser - r, err = openFn() + r, err = openFn(ctx) if err != nil { return drivertype.None, 0, errz.Err(err) } @@ -86,7 +86,7 @@ func DetectJSONL(sampleSize int) source.DriverDetectFunc { func importJSONL(ctx context.Context, job importJob) error { //nolint:gocognit log := lg.FromContext(ctx) - r, err := job.openFn() + r, err := job.openFn(ctx) if err != nil { return err } diff --git a/drivers/json/import_test.go b/drivers/json/import_test.go index debd5f5a7..20307778f 100644 --- a/drivers/json/import_test.go +++ b/drivers/json/import_test.go @@ -2,6 +2,7 @@ package json_test import ( "bytes" + "context" stdj "encoding/json" "io" "os" @@ -74,12 +75,12 @@ func TestImportJSONL_Flat(t *testing.T) { tc := tc t.Run(tutil.Name(i, tc.fpath, tc.input), func(t *testing.T) { - openFn := func() (io.ReadCloser, error) { + openFn := func(ctx context.Context) (io.ReadCloser, error) { return io.NopCloser(strings.NewReader(tc.input)), nil } if tc.fpath != "" { - openFn = func() (io.ReadCloser, error) { + openFn = func(ctx context.Context) (io.ReadCloser, error) { return os.Open(filepath.Join("testdata", tc.fpath)) } } @@ -105,7 +106,7 @@ func TestImportJSONL_Flat(t *testing.T) { } func TestImportJSON_Flat(t *testing.T) { - openFn := func() (io.ReadCloser, error) { + openFn := func(context.Context) (io.ReadCloser, error) { return os.Open("testdata/actor.json") } diff --git a/drivers/json/json.go 
b/drivers/json/json.go index 8e04c5db1..22c03e53a 100644 --- a/drivers/json/json.go +++ b/drivers/json/json.go @@ -97,7 +97,7 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er p := &pool{log: d.log, src: src, clnup: cleanup.New(), files: d.files} - r, err := d.files.Open(src) + r, err := d.files.Open(ctx, src) if err != nil { return nil, err } @@ -147,10 +147,10 @@ func (d *driveri) ValidateSource(src *source.Source) (*source.Source, error) { } // Ping implements driver.Driver. -func (d *driveri) Ping(_ context.Context, src *source.Source) error { +func (d *driveri) Ping(ctx context.Context, src *source.Source) error { d.log.Debug("Ping source", lga.Src, src) - r, err := d.files.Open(src) + r, err := d.files.Open(ctx, src) if err != nil { return err } @@ -215,7 +215,7 @@ func (p *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou return nil, err } - md.Size, err = p.files.Size(p.src) + md.Size, err = p.files.Size(ctx, p.src) if err != nil { return nil, err } diff --git a/drivers/json/json_test.go b/drivers/json/json_test.go index 5e7e35559..7814bb742 100644 --- a/drivers/json/json_test.go +++ b/drivers/json/json_test.go @@ -93,7 +93,7 @@ func TestDriverDetectorFuncs(t *testing.T) { tc := tc t.Run(tutil.Name(tc.fn, tc.fname), func(t *testing.T) { - openFn := func() (io.ReadCloser, error) { return os.Open(filepath.Join("testdata", tc.fname)) } + openFn := func(ctx context.Context) (io.ReadCloser, error) { return os.Open(filepath.Join("testdata", tc.fname)) } detectFn := detectFns[tc.fn] ctx := lg.NewContext(context.Background(), slogt.New(t)) diff --git a/drivers/userdriver/userdriver.go b/drivers/userdriver/userdriver.go index c9bff94d6..e9c5cafb6 100644 --- a/drivers/userdriver/userdriver.go +++ b/drivers/userdriver/userdriver.go @@ -86,7 +86,7 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er clnup := cleanup.New() - r, err := d.files.Open(src) + r, err := d.files.Open(ctx, 
src) if err != nil { return nil, err } @@ -123,13 +123,13 @@ func (d *driveri) ValidateSource(src *source.Source) (*source.Source, error) { } // Ping implements driver.Driver. -func (d *driveri) Ping(_ context.Context, src *source.Source) error { +func (d *driveri) Ping(ctx context.Context, src *source.Source) error { d.log.Debug("Ping source", lga.Driver, d.typ, lga.Src, src, ) - r, err := d.files.Open(src) + r, err := d.files.Open(ctx, src) if err != nil { return err } diff --git a/drivers/xlsx/database.go b/drivers/xlsx/database.go index a516c2e3e..7d2d12673 100644 --- a/drivers/xlsx/database.go +++ b/drivers/xlsx/database.go @@ -57,7 +57,7 @@ func (p *pool) doIngest(ctx context.Context, includeSheetNames []string) error { // has the source's options on it. ctx = options.NewContext(ctx, options.Merge(options.FromContext(ctx), p.src.Options)) - r, err := p.files.Open(p.src) + r, err := p.files.Open(ctx, p.src) if err != nil { return err } @@ -122,7 +122,7 @@ func (p *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou } md.FQName = md.Name - if md.Size, err = p.files.Size(p.src); err != nil { + if md.Size, err = p.files.Size(ctx, p.src); err != nil { return nil, err } diff --git a/drivers/xlsx/detect.go b/drivers/xlsx/detect.go index 276997f59..b3fcfe03d 100644 --- a/drivers/xlsx/detect.go +++ b/drivers/xlsx/detect.go @@ -25,7 +25,7 @@ func DetectXLSX(ctx context.Context, openFn source.FileOpenFunc) (detected drive ) { log := lg.FromContext(ctx) var r io.ReadCloser - r, err = openFn() + r, err = openFn(ctx) if err != nil { return drivertype.None, 0, errz.Err(err) } diff --git a/drivers/xlsx/xlsx.go b/drivers/xlsx/xlsx.go index d32d14179..da74b6903 100644 --- a/drivers/xlsx/xlsx.go +++ b/drivers/xlsx/xlsx.go @@ -104,7 +104,7 @@ func (d *Driver) ValidateSource(src *source.Source) (*source.Source, error) { func (d *Driver) Ping(ctx context.Context, src *source.Source) (err error) { log := lg.FromContext(ctx) - r, err := d.files.Open(src) + r, err 
:= d.files.Open(ctx, src) if err != nil { return err } diff --git a/libsq/source/files.go b/libsq/source/files.go index ef75bc247..d046cef76 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -74,8 +74,11 @@ func (fs *Files) AddDriverDetectors(detectFns ...DriverDetectFunc) { // Size returns the file size of src.Location. This exists // as a convenience function and something of a replacement // for using os.Stat to get the file size. -func (fs *Files) Size(src *Source) (size int64, err error) { - r, err := fs.Open(src) +// +// FIXME: This is a terrible way to get the size. It currently +// reads all the bytes. Awful. +func (fs *Files) Size(ctx context.Context, src *Source) (size int64, err error) { + r, err := fs.Open(ctx, src) if err != nil { return 0, err } @@ -97,12 +100,12 @@ func (fs *Files) Size(src *Source) (size int64, err error) { // // REVISIT: it's possible we'll ditch AddStdin and TypeStdin // in some future version; this mechanism is a stopgap. -func (fs *Files) AddStdin(f *os.File) error { +func (fs *Files) AddStdin(ctx context.Context, f *os.File) error { fs.mu.Lock() defer fs.mu.Unlock() // We don't need r, but we're responsible for closing it. - r, err := fs.addFile(f, StdinHandle) // f is closed by addFile + r, err := fs.addFile(ctx, f, StdinHandle) // f is closed by addFile if err != nil { return err } @@ -133,8 +136,9 @@ func (fs *Files) TypeStdin(ctx context.Context) (drivertype.Type, error) { // add file copies f to fs's cache, returning a reader which the // caller is responsible for closing. f is closed by this method. 
-func (fs *Files) addFile(f *os.File, key string) (fscache.ReadAtCloser, error) { - fs.log.Debug("Adding file", lga.Key, key, lga.Path, f.Name()) +func (fs *Files) addFile(ctx context.Context, f *os.File, key string) (fscache.ReadAtCloser, error) { + log := lg.FromContext(ctx) + log.Debug("Adding file", lga.Key, key, lga.Path, f.Name()) r, w, err := fs.fcache.Get(key) if err != nil { return nil, errz.Err(err) @@ -150,17 +154,16 @@ func (fs *Files) addFile(f *os.File, key string) (fscache.ReadAtCloser, error) { // everything is held up until f is fully copied. Hopefully we can // do something with fscache so that the readers returned from // fscache can lazily read from f. - _, err = io.Copy(w, f) - if err != nil { + if err = fscache.FillWriterAsync(ctx, log, w, f, true); err != nil { lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) return nil, errz.Err(err) } - err = errz.Combine(w.Close(), f.Close()) - if err != nil { - lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) - return nil, err - } + //err = errz.Combine(w.Close(), f.Close()) + //if err != nil { + // lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) + // return nil, err + //} return r, nil } @@ -241,23 +244,23 @@ func (fs *Files) Filepath(_ context.Context, src *Source) (string, error) { // Open returns a new io.ReadCloser for src.Location. // If src.Handle is StdinHandle, AddStdin must first have // been invoked. The caller must close the reader. -func (fs *Files) Open(src *Source) (io.ReadCloser, error) { +func (fs *Files) Open(ctx context.Context, src *Source) (io.ReadCloser, error) { fs.mu.Lock() defer fs.mu.Unlock() - return fs.newReader(src.Location) + return fs.newReader(ctx, src.Location) } // OpenFunc returns a func that invokes fs.Open for src.Location. 
-func (fs *Files) OpenFunc(src *Source) func() (io.ReadCloser, error) { - return func() (io.ReadCloser, error) { - return fs.Open(src) +func (fs *Files) OpenFunc(src *Source) func(ctx context.Context) (io.ReadCloser, error) { + return func(ctx context.Context) (io.ReadCloser, error) { + return fs.Open(ctx, src) } } // ReadAll is a convenience method to read the bytes of a source. -func (fs *Files) ReadAll(src *Source) ([]byte, error) { - r, err := fs.newReader(src.Location) +func (fs *Files) ReadAll(ctx context.Context, src *Source) ([]byte, error) { + r, err := fs.newReader(ctx, src.Location) if err != nil { return nil, err } @@ -275,7 +278,7 @@ func (fs *Files) ReadAll(src *Source) ([]byte, error) { return data, nil } -func (fs *Files) newReader(loc string) (io.ReadCloser, error) { +func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, error) { if loc == StdinHandle { r, w, err := fs.fcache.Get(StdinHandle) if err != nil { @@ -290,13 +293,13 @@ func (fs *Files) newReader(loc string) (io.ReadCloser, error) { if !fs.fcache.Exists(loc) { // cache miss - f, err := fs.openLocation(loc) + f, err := fs.openLocation(ctx, loc) if err != nil { return nil, err } // Note that addFile closes f - r, err := fs.addFile(f, loc) + r, err := fs.addFile(ctx, f, loc) if err != nil { return nil, err } @@ -313,7 +316,7 @@ func (fs *Files) newReader(loc string) (io.ReadCloser, error) { // openLocation returns a file for loc. It is the caller's // responsibility to close the returned file. 
-func (fs *Files) openLocation(loc string) (*os.File, error) { +func (fs *Files) openLocation(ctx context.Context, loc string) (*os.File, error) { var fpath string var ok bool var err error @@ -453,11 +456,11 @@ func (fs *Files) detectType(ctx context.Context, loc string) (typ drivertype.Typ } resultCh := make(chan result, len(fs.detectFns)) - openFn := func() (io.ReadCloser, error) { + openFn := func(ctx context.Context) (io.ReadCloser, error) { fs.mu.Lock() defer fs.mu.Unlock() - return fs.newReader(loc) + return fs.newReader(ctx, loc) } select { @@ -516,7 +519,7 @@ func (fs *Files) detectType(ctx context.Context, loc string) (typ drivertype.Typ // FileOpenFunc returns a func that opens a ReadCloser. The caller // is responsible for closing the returned ReadCloser. -type FileOpenFunc func() (io.ReadCloser, error) +type FileOpenFunc func(ctx context.Context) (io.ReadCloser, error) // DriverDetectFunc interrogates a byte stream to determine // the source driver type. A score is returned indicating @@ -538,7 +541,7 @@ func DetectMagicNumber(ctx context.Context, openFn FileOpenFunc, ) (detected drivertype.Type, score float32, err error) { log := lg.FromContext(ctx) var r io.ReadCloser - r, err = openFn() + r, err = openFn(ctx) if err != nil { return drivertype.None, 0, errz.Err(err) } diff --git a/libsq/source/files_test.go b/libsq/source/files_test.go index f428eeb0c..fae2e2828 100644 --- a/libsq/source/files_test.go +++ b/libsq/source/files_test.go @@ -131,7 +131,7 @@ func TestDetectMagicNumber(t *testing.T) { tc := tc t.Run(filepath.Base(tc.loc), func(t *testing.T) { - rFn := func() (io.ReadCloser, error) { return os.Open(tc.loc) } + rFn := func(ctx context.Context) (io.ReadCloser, error) { return os.Open(tc.loc) } ctx := lg.NewContext(context.Background(), slogt.New(t)) @@ -166,7 +166,7 @@ func TestFiles_NewReader(t *testing.T) { for i := 0; i < 1000; i++ { g.Go(func() error { - r, gErr := fs.Open(src) + r, gErr := fs.Open(ctx, src) require.NoError(t, gErr) b, 
gErr := io.ReadAll(r) @@ -202,7 +202,7 @@ func TestFiles_Stdin(t *testing.T) { f, err := os.Open(tc.fpath) require.NoError(t, err) - err = fs.AddStdin(f) // f is closed by AddStdin + err = fs.AddStdin(th.Context, f) // f is closed by AddStdin require.NoError(t, err) typ, err := fs.TypeStdin(th.Context) @@ -227,7 +227,7 @@ func TestFiles_Stdin_ErrorWrongOrder(t *testing.T) { f, err := os.Open(proj.Abs(sakila.PathCSVActor)) require.NoError(t, err) - require.NoError(t, fs.AddStdin(f)) // AddStdin closes f + require.NoError(t, fs.AddStdin(th.Context, f)) // AddStdin closes f typ, err = fs.TypeStdin(th.Context) require.NoError(t, err) require.Equal(t, csv.TypeCSV, typ) @@ -245,7 +245,7 @@ func TestFiles_Size(t *testing.T) { th := testh.New(t) fs := th.Files() - gotSize, err := fs.Size(&source.Source{ + gotSize, err := fs.Size(th.Context, &source.Source{ Handle: stringz.UniqSuffix("@h"), Location: f.Name(), }) @@ -255,9 +255,9 @@ func TestFiles_Size(t *testing.T) { f2, err := os.Open(proj.Abs(sakila.PathCSVActor)) require.NoError(t, err) // Verify that this works with @stdin as well - require.NoError(t, fs.AddStdin(f2)) + require.NoError(t, fs.AddStdin(th.Context, f2)) - gotSize2, err := fs.Size(&source.Source{ + gotSize2, err := fs.Size(th.Context, &source.Source{ Handle: "@stdin", Location: "@stdin", }) diff --git a/libsq/source/internal_test.go b/libsq/source/internal_test.go index f32231a37..051db7abd 100644 --- a/libsq/source/internal_test.go +++ b/libsq/source/internal_test.go @@ -38,7 +38,7 @@ func TestFiles_Open(t *testing.T) { Location: proj.Abs(testsrc.PathXLSXTestHeader), } - f, err := fs.openLocation(src1.Location) + f, err := fs.openLocation(ctx, src1.Location) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, f.Close()) }) require.Equal(t, src1.Location, f.Name()) @@ -47,7 +47,7 @@ func TestFiles_Open(t *testing.T) { Location: sakila.URLActorCSV, } - f2, err := fs.openLocation(src2.Location) + f2, err := fs.openLocation(ctx, src2.Location) 
require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, f2.Close()) }) diff --git a/testh/testh_test.go b/testh/testh_test.go index 271a89acb..213606839 100644 --- a/testh/testh_test.go +++ b/testh/testh_test.go @@ -145,7 +145,7 @@ func TestHelper_Files(t *testing.T) { for i := 0; i < 1000; i++ { g.Go(func() error { - r, fErr := fs.Open(src) + r, fErr := fs.Open(th.Context, src) require.NoError(t, fErr) defer func() { require.NoError(t, r.Close()) }() From 3f3f5e911c404dd3362621675b32f1e5e74c136c Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 26 Nov 2023 18:25:01 -0700 Subject: [PATCH 014/195] wip: progress bar --- cli/cli.go | 1 + cli/cmd_xtest.go | 233 +++++++++++++++++++ go.mod | 3 + go.sum | 6 + libsq/core/progress/progress.go | 70 ++++++ libsq/core/progress/progress_test.go | 1 + libsq/core/progress/progressio/progressio.go | 122 ++++++++++ 7 files changed, 436 insertions(+) create mode 100644 cli/cmd_xtest.go create mode 100644 libsq/core/progress/progress.go create mode 100644 libsq/core/progress/progress_test.go create mode 100644 libsq/core/progress/progressio/progressio.go diff --git a/cli/cli.go b/cli/cli.go index 7a794e35b..07211a1c4 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -229,6 +229,7 @@ func newCommandTree(ru *run.Run) (rootCmd *cobra.Command) { addCmd(ru, rootCmd, newCompletionCmd()) addCmd(ru, rootCmd, newVersionCmd()) addCmd(ru, rootCmd, newManCmd()) + addCmd(ru, rootCmd, newXTestCmd()) return rootCmd } diff --git a/cli/cmd_xtest.go b/cli/cmd_xtest.go new file mode 100644 index 000000000..161be3e19 --- /dev/null +++ b/cli/cmd_xtest.go @@ -0,0 +1,233 @@ +package cli + +import ( + "context" + "fmt" + "github.com/neilotoole/sq/cli/buildinfo" + "github.com/neilotoole/sq/cli/flag" + "github.com/neilotoole/sq/cli/hostinfo" + "github.com/neilotoole/sq/cli/output" + "github.com/neilotoole/sq/cli/run" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/spf13/cobra" + "github.com/vbauerster/mpb/v8" + 
"github.com/vbauerster/mpb/v8/decor" + "io" + "math/rand" + "time" +) + +func newXTestCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "xtest", + Short: "Execute some internal tests", + Hidden: true, + RunE: execXTestMbp, + } + + cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) + cmd.Flags().BoolP(flag.Compact, flag.CompactShort, false, flag.CompactUsage) + cmd.Flags().BoolP(flag.YAML, flag.YAMLShort, false, flag.YAMLUsage) + + return cmd +} + +func execXTestMbp(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + ru := run.FromContext(ctx) + //pb := progress.New(ctx, ru.ErrOut, 1*time.Second) + + fmt.Fprintln(ru.Out, "Hello, world!") + + if err := doBigRead2(ctx, ru.Writers.Printing, ru.ErrOut); err != nil { + return err + } + + //_ = pb + + return ru.Writers.Version.Version(buildinfo.Get(), buildinfo.Get().Version, hostinfo.Get()) + +} + +func execXTestMbpIndeterminate(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + ru := run.FromContext(ctx) + + //bar.Abort(true) + // wait for our bar to complete and flush + //p.Wait() + //bar.Abort(true) + //p.Shutdown() + + return ru.Writers.Version.Version(buildinfo.Get(), buildinfo.Get().Version, hostinfo.Get()) +} + +func doBigReadOg(ctx context.Context, errOut io.Writer) error { + pb := progress.New(ctx, errOut, 1*time.Second) + + //total := 50 + //p := mpb.New( + // mpb.WithOutput(ru.Out), + // mpb.WithWidth(64), + // mpb.WithRenderDelay(after2(1*time.Second)), + //) + p := pb.P + + //bar := p.New( + // int64(0), + // mpb.BarStyle(), + // mpb.PrependDecorators(decor.Name("huzzah")), + // mpb.BarRemoveOnComplete(), + //) + + bar := p.AddBar(0, + mpb.PrependDecorators(decor.Counters(decor.SizeB1024(0), "% .1f / % .1f")), + //mpb.AppendDecorators(decor.Percentage()), + mpb.BarRemoveOnComplete(), + ) + + maxSleep := 100 * time.Millisecond + + jr := &junkReader{limit: 1000000} + b := make([]byte, 1024) + +LOOP: + for { + select { + case <-ctx.Done(): + + break LOOP + 
default: + } + + n, err := jr.Read(b) + if err != nil { + //bar.SetTotal(-1, true) + if err == io.EOF { + // triggering complete event now + bar.SetTotal(-1, true) + break + } + break + } + // increment methods won't trigger complete event because bar was constructed with total = 0 + bar.IncrBy(n) + // following call is not required, it's called to show some progress instead of an empty bar + bar.SetTotal(bar.Current()+2048, false) + time.Sleep(time.Duration(rand.Intn(10)+1) * maxSleep / 10) + } + + p.Wait() + return nil +} + +func doBigRead2(ctx context.Context, pr *output.Printing, errOut io.Writer) error { + pb := progress.New(ctx, errOut, 1*time.Second) + + //total := 50 + //p := mpb.New( + // mpb.WithOutput(ru.Out), + // mpb.WithWidth(64), + // mpb.WithRenderDelay(after2(1*time.Second)), + //) + p := pb.P + + //bar := p.New( + // int64(0), + // mpb.BarStyle(), + // mpb.PrependDecorators(decor.Name("huzzah")), + // mpb.BarRemoveOnComplete(), + //) + + s := mpb.SpinnerStyle("∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙") + spinnerStyle := s.Meta(func(s string) string { + return pr.Active.Sprint(s) + //return "\033[31m" + s + "\033[0m" // red + }) + + bar := p.New(0, + spinnerStyle, + //mpb.PrependDecorators(), + //mpb.PrependDecorators(decor.Counters(decor.SizeB1024(0), "% .1f / % .1f")), + //mpb.PrependDecorators(decor.Counters(decor.SizeB1024(0), "% .1f")), + mpb.PrependDecorators( + progress.ColorMeta(decor.Name("Ingesting data..."), pr.Faint), + ), + //mpb.AppendDecorators(decor.Percentage()), + mpb.AppendDecorators( + progress.ColorMeta(decor.Current(decor.SizeB1024(0), "% .1f"), pr.Faint), + ), + //mpb.AppendDecorators( + // // replace ETA decorator with "done" message, OnComplete event + // decor.OnComplete( + // // ETA decorator with ewma age of 30 + // decor.EwmaETA(decor.ET_STYLE_GO, 30), "done", + // ), + //), + mpb.BarRemoveOnComplete(), + ) + + maxSleep := 100 * time.Millisecond + + jr := &junkReader{limit: 100000} + b := make([]byte, 1024) + + //start := 
time.Now() + +LOOP: + for { + //bar.EwmaIncrement(time.Since(start)) + select { + case <-ctx.Done(): + bar.SetTotal(-1, true) + break LOOP + default: + } + + n, err := jr.Read(b) + if err != nil { + bar.SetTotal(-1, true) + if err == io.EOF { + // triggering complete event now + //bar.SetTotal(-1, true) + break + } + break + } + // increment methods won't trigger complete event because bar was constructed with total = 0 + bar.IncrBy(n) + // following call is not required, it's called to show some progress instead of an empty bar + bar.SetTotal(bar.Current()+2048, false) + time.Sleep(time.Duration(rand.Intn(10)+1) * maxSleep / 10) + } + + p.Wait() + return nil +} + +func makeStream(limit int) func() (int, error) { + return func() (int, error) { + if limit <= 0 { + return 0, io.EOF + } + limit-- + return rand.Intn(1024) + 1, nil + } +} + +type junkReader struct { + limit int + count int +} + +func (r *junkReader) Read(p []byte) (n int, err error) { + if r.count >= r.limit { + return 0, io.EOF + } + + amount, err := rand.Read(p) + r.count += amount + return amount, err + + //return rand.Intn(1024) + 1, nil +} diff --git a/go.mod b/go.mod index 02a062037..b67382873 100644 --- a/go.mod +++ b/go.mod @@ -53,6 +53,8 @@ require ( github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.2.1 // indirect + github.com/VividCortex/ewma v1.2.0 // indirect + github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/djherbis/atime v1.1.0 // indirect github.com/djherbis/stream v1.4.0 // indirect @@ -79,6 +81,7 @@ require ( github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/segmentio/asm v1.1.3 // indirect github.com/spf13/cast v1.5.1 // indirect + github.com/vbauerster/mpb/v8 v8.6.2 // indirect github.com/xuri/efp v0.0.0-20231025114914-d1ff6096ae53 // 
indirect github.com/xuri/nfp v0.0.0-20230919160717-d98342af3f05 // indirect golang.org/x/crypto v0.15.0 // indirect diff --git a/go.sum b/go.sum index 1b804d1e2..891c94870 100644 --- a/go.sum +++ b/go.sum @@ -19,6 +19,10 @@ github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0 github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA= github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM= +github.com/VividCortex/ewma v1.2.0 h1:f58SaIzcDXrSy3kWaHNvuJgJ3Nmz59Zji6XoJR/q1ow= +github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAUnGx7j5l4= +github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d h1:licZJFw2RwpHMqeKTCYkitsPqHNxTmd4SNR5r94FGM8= +github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d/go.mod h1:asat636LX7Bqt5lYEZ27JNDcqxfjdBQuJ/MM4CN/Lzo= github.com/alessio/shellescape v1.4.2 h1:MHPfaU+ddJ0/bYWpgIeUnQUqKrlJ1S7BfEYPM4uEoM0= github.com/alessio/shellescape v1.4.2/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= @@ -183,6 +187,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/vbauerster/mpb/v8 v8.6.2 h1:9EhnJGQRtvgDVCychJgR96EDCOqgg2NsMuk5JUcX4DA= +github.com/vbauerster/mpb/v8 v8.6.2/go.mod h1:oVJ7T+dib99kZ/VBjoBaC8aPXiSAihnzuKmotuihyFo= github.com/xo/dburl v0.18.2 h1:9xqcVf+JEV7bcUa1OjCsoax06roohYFdye6xkvBKo50= github.com/xo/dburl v0.18.2/go.mod 
h1:B7/G9FGungw6ighV8xJNwWYQPMfn3gsi2sn5SE8Bzco= github.com/xuri/efp v0.0.0-20230802181842-ad255f2331ca/go.mod h1:ybY/Jr0T0GTCnYjKqmdwxyxn2BQf2RcQIIvex5QldPI= diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go new file mode 100644 index 000000000..fe81a47fe --- /dev/null +++ b/libsq/core/progress/progress.go @@ -0,0 +1,70 @@ +package progress + +import ( + "context" + "github.com/fatih/color" + "github.com/vbauerster/mpb/v8" + "github.com/vbauerster/mpb/v8/decor" + "io" + "time" +) + +type runKey struct{} + +// NewContext returns ctx with prog added as a value. +func NewContext(ctx context.Context, prog *Progress) context.Context { + if ctx == nil { + ctx = context.Background() + } + + return context.WithValue(ctx, runKey{}, prog) +} + +// FromContext extracts the Progress added to ctx via NewContext. +func FromContext(ctx context.Context) *Progress { + return ctx.Value(runKey{}).(*Progress) +} + +type Progress struct { + P *mpb.Progress +} + +func New(ctx context.Context, out io.Writer, delay time.Duration) *Progress { + p := mpb.NewWithContext(ctx, + mpb.WithOutput(out), + mpb.WithWidth(64), + mpb.WithRenderDelay(renderDelay(delay)), + ) + + return &Progress{P: p} +} + +func renderDelay(d time.Duration) <-chan struct{} { + ch := make(chan struct{}) + time.AfterFunc(d, func() { + close(ch) + }) + return ch +} + +func NewProxyBar(p *mpb.Progress) { + bar := p.New(0, + mpb.BarStyle().Rbound("|"), + mpb.PrependDecorators( + decor.Counters(decor.SizeB1024(0), "% .2f / % .2f"), + ), + mpb.AppendDecorators( + decor.EwmaETA(decor.ET_STYLE_GO, 30), + decor.Name(" ] "), + decor.EwmaSpeed(decor.SizeB1024(0), "% .2f", 30), + ), + ) + + _ = bar +} + +func ColorMeta(decorator decor.Decorator, c *color.Color) decor.Decorator { + return decor.Meta(decorator, func(s string) string { + return c.Sprint(s) + }) +} diff --git a/libsq/core/progress/progress_test.go b/libsq/core/progress/progress_test.go new file mode 100644 index 
000000000..9922214f2 --- /dev/null +++ b/libsq/core/progress/progress_test.go @@ -0,0 +1 @@ +package progress_test diff --git a/libsq/core/progress/progressio/progressio.go b/libsq/core/progress/progressio/progressio.go new file mode 100644 index 000000000..ce97b4d59 --- /dev/null +++ b/libsq/core/progress/progressio/progressio.go @@ -0,0 +1,122 @@ +/* +Copyright 2018 Olivier Mengué + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package progressio provides [io.Writer] and [io.Reader] that stop accepting/providing +// data when an attached context is canceled, and know how to interact +// with a progress bar. +// +// The code is lifted from github.com/dolmen-go/contextio. +package progressio + +import ( + "context" + "io" +) + +type writer struct { + ctx context.Context + w io.Writer +} + +type copier struct { + writer +} + +// NewWriter wraps an [io.Writer] to handle context cancellation. +// +// Context state is checked BEFORE every Write. +// +// The returned Writer also implements [io.ReaderFrom] to allow [io.Copy] to select +// the best strategy while still checking the context state before every chunk transfer. +func NewWriter(ctx context.Context, w io.Writer) io.Writer { + if w, ok := w.(*copier); ok && ctx == w.ctx { + return w + } + return &copier{writer{ctx: ctx, w: w}} +} + +// Write implements [io.Writer], but with context awareness. 
+func (w *writer) Write(p []byte) (n int, err error) { + select { + case <-w.ctx.Done(): + return 0, w.ctx.Err() + default: + return w.w.Write(p) + } +} + +type reader struct { + ctx context.Context + r io.Reader +} + +// NewReader wraps an [io.Reader] to handle context cancellation. +// +// Context state is checked BEFORE every Read. +func NewReader(ctx context.Context, r io.Reader) io.Reader { + if r, ok := r.(*reader); ok && ctx == r.ctx { + return r + } + return &reader{ctx: ctx, r: r} +} + +func (r *reader) Read(p []byte) (n int, err error) { + select { + case <-r.ctx.Done(): + return 0, r.ctx.Err() + default: + return r.r.Read(p) + } +} + +// ReadFrom implements interface [io.ReaderFrom], but with context awareness. +// +// This should allow efficient copying allowing writer or reader to define the chunk size. +func (w *copier) ReadFrom(r io.Reader) (n int64, err error) { + if _, ok := w.w.(io.ReaderFrom); ok { + // Let the original Writer decide the chunk size. + return io.Copy(w.writer.w, &reader{ctx: w.ctx, r: r}) + } + select { + case <-w.ctx.Done(): + return 0, w.ctx.Err() + default: + // The original Writer is not a ReaderFrom. + // Let the Reader decide the chunk size. + return io.Copy(&w.writer, r) + } +} + +// NewCloser wraps an [io.Reader] to handle context cancellation. +// +// Context state is checked BEFORE any Close. 
+func NewCloser(ctx context.Context, c io.Closer) io.Closer { + return &closer{ctx: ctx, c: c} +} + +type closer struct { + ctx context.Context + c io.Closer +} + +func (c *closer) Close() error { + select { + case <-c.ctx.Done(): + return c.ctx.Err() + default: + return c.c.Close() + } +} From 55cd16759ce4b91851c6aae7fb8eeced6a04dc1b Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 26 Nov 2023 21:00:21 -0700 Subject: [PATCH 015/195] wip: progress on progress bar --- cli/cmd_xtest.go | 121 +++++++++++++++--- cli/logging.go | 13 +- libsq/core/lg/lga/lga.go | 1 + .../progressio.go => contextio.go} | 9 +- libsq/core/progress/progress.go | 71 +++++++--- libsq/core/progress/progressio.go | 118 +++++++++++++++++ libsq/source/files.go | 19 ++- libsq/source/source.go | 12 +- 8 files changed, 318 insertions(+), 46 deletions(-) rename libsq/core/progress/{progressio/progressio.go => contextio.go} (91%) create mode 100644 libsq/core/progress/progressio.go diff --git a/cli/cmd_xtest.go b/cli/cmd_xtest.go index 161be3e19..663317c7d 100644 --- a/cli/cmd_xtest.go +++ b/cli/cmd_xtest.go @@ -62,8 +62,68 @@ func execXTestMbpIndeterminate(cmd *cobra.Command, _ []string) error { return ru.Writers.Version.Version(buildinfo.Get(), buildinfo.Get().Version, hostinfo.Get()) } -func doBigReadOg(ctx context.Context, errOut io.Writer) error { - pb := progress.New(ctx, errOut, 1*time.Second) +// +//func doBigReadOg(ctx context.Context, errOut io.Writer) error { +// pb := progress.New(ctx, errOut, 1*time.Second, progress.DefaultColors()) +// +// //total := 50 +// //p := mpb.New( +// // mpb.WithOutput(ru.Out), +// // mpb.WithWidth(64), +// // mpb.WithRenderDelay(after2(1*time.Second)), +// //) +// p := pb.P +// +// //bar := p.New( +// // int64(0), +// // mpb.BarStyle(), +// // mpb.PrependDecorators(decor.Name("huzzah")), +// // mpb.BarRemoveOnComplete(), +// //) +// +// bar := p.AddBar(0, +// mpb.PrependDecorators(decor.Counters(decor.SizeB1024(0), "% .1f / % .1f")), +// 
//mpb.AppendDecorators(decor.Percentage()), +// mpb.BarRemoveOnComplete(), +// ) +// +// maxSleep := 100 * time.Millisecond +// +// jr := &junkReader{limit: 1000000} +// b := make([]byte, 1024) +// +//LOOP: +// for { +// select { +// case <-ctx.Done(): +// +// break LOOP +// default: +// } +// +// n, err := jr.Read(b) +// if err != nil { +// //bar.SetTotal(-1, true) +// if err == io.EOF { +// // triggering complete event now +// bar.SetTotal(-1, true) +// break +// } +// break +// } +// // increment methods won't trigger complete event because bar was constructed with total = 0 +// bar.IncrBy(n) +// // following call is not required, it's called to show some progress instead of an empty bar +// bar.SetTotal(bar.Current()+2048, false) +// time.Sleep(time.Duration(rand.Intn(10)+1) * maxSleep / 10) +// } +// +// p.Wait() +// return nil +//} + +func doBigRead2(ctx context.Context, pr *output.Printing, errOut io.Writer) error { + pb := progress.New(ctx, errOut, 1*time.Second, progress.DefaultColors()) //total := 50 //p := mpb.New( @@ -73,6 +133,9 @@ func doBigReadOg(ctx context.Context, errOut io.Writer) error { //) p := pb.P + //bar := pb.NewIOSpinner("ingest data") + spinner := pb.NewIOSpinner("Ingest data test...") + //bar := spinner.Bar //bar := p.New( // int64(0), // mpb.BarStyle(), @@ -80,40 +143,68 @@ func doBigReadOg(ctx context.Context, errOut io.Writer) error { // mpb.BarRemoveOnComplete(), //) - bar := p.AddBar(0, - mpb.PrependDecorators(decor.Counters(decor.SizeB1024(0), "% .1f / % .1f")), - //mpb.AppendDecorators(decor.Percentage()), - mpb.BarRemoveOnComplete(), - ) + //s := mpb.SpinnerStyle("∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙") + //spinnerStyle := s.Meta(func(s string) string { + // return pr.Active.Sprint(s) + // //return "\033[31m" + s + "\033[0m" // red + //}) + // + //bar := p.New(0, + // spinnerStyle, + // //mpb.PrependDecorators(), + // //mpb.PrependDecorators(decor.Counters(decor.SizeB1024(0), "% .1f / % .1f")), + // 
//mpb.PrependDecorators(decor.Counters(decor.SizeB1024(0), "% .1f")), + // mpb.PrependDecorators( + // progress.ColorMeta(decor.Name("Ingesting data..."), pr.Faint), + // ), + // //mpb.AppendDecorators(decor.Percentage()), + // mpb.AppendDecorators( + // progress.ColorMeta(decor.Current(decor.SizeB1024(0), "% .1f"), pr.Faint), + // ), + // //mpb.AppendDecorators( + // // // replace ETA decorator with "done" message, OnComplete event + // // decor.OnComplete( + // // // ETA decorator with ewma age of 30 + // // decor.EwmaETA(decor.ET_STYLE_GO, 30), "done", + // // ), + // //), + // mpb.BarRemoveOnComplete(), + //) maxSleep := 100 * time.Millisecond - jr := &junkReader{limit: 1000000} + jr := &junkReader{limit: 100000} b := make([]byte, 1024) + //start := time.Now() + LOOP: for { + //bar.EwmaIncrement(time.Since(start)) select { case <-ctx.Done(): - + //bar.SetTotal(-1, true) + spinner.Finish() break LOOP default: } n, err := jr.Read(b) if err != nil { + spinner.Finish() //bar.SetTotal(-1, true) if err == io.EOF { // triggering complete event now - bar.SetTotal(-1, true) + //bar.SetTotal(-1, true) break } break } // increment methods won't trigger complete event because bar was constructed with total = 0 - bar.IncrBy(n) + //bar.IncrBy(n) + spinner.IncrBy(n) // following call is not required, it's called to show some progress instead of an empty bar - bar.SetTotal(bar.Current()+2048, false) + //bar.SetTotal(bar.Current()+2048, false) time.Sleep(time.Duration(rand.Intn(10)+1) * maxSleep / 10) } @@ -121,8 +212,8 @@ LOOP: return nil } -func doBigRead2(ctx context.Context, pr *output.Printing, errOut io.Writer) error { - pb := progress.New(ctx, errOut, 1*time.Second) +func doBigReadWorking(ctx context.Context, pr *output.Printing, errOut io.Writer) error { + pb := progress.New(ctx, errOut, 1*time.Second, progress.DefaultColors()) //total := 50 //p := mpb.New( @@ -197,7 +288,7 @@ LOOP: // increment methods won't trigger complete event because bar was constructed with total 
= 0 bar.IncrBy(n) // following call is not required, it's called to show some progress instead of an empty bar - bar.SetTotal(bar.Current()+2048, false) + //bar.SetTotal(bar.Current()+2048, false) time.Sleep(time.Duration(rand.Intn(10)+1) * maxSleep / 10) } diff --git a/cli/logging.go b/cli/logging.go index 127bcaef5..a780d4e34 100644 --- a/cli/logging.go +++ b/cli/logging.go @@ -6,6 +6,7 @@ import ( "log/slog" "os" "path/filepath" + "strconv" "strings" "github.com/spf13/cobra" @@ -119,8 +120,16 @@ func slogReplaceAttrs(groups []string, a slog.Attr) slog.Attr { func slogReplaceSource(_ []string, a slog.Attr) slog.Attr { // We want source to be "pkg/file.go". if a.Key == slog.SourceKey { - fp := a.Value.String() - a.Value = slog.StringValue(filepath.Join(filepath.Base(filepath.Dir(fp)), filepath.Base(fp))) + source := a.Value.Any().(*slog.Source) + //source.File = filepath.Base(source.File) + + val := filepath.Join(filepath.Base(filepath.Dir(source.File)), filepath.Base(source.File)) + val += ":" + strconv.Itoa(source.Line) + a.Value = slog.StringValue(val) + + //src, ok := a.Value. + //fp := a.Value.String() + //a.Value = slog.StringValue(filepath.Join(filepath.Base(filepath.Dir(fp)), filepath.Base(fp))) } return a } diff --git a/libsq/core/lg/lga/lga.go b/libsq/core/lg/lga/lga.go index e916450ae..02d806933 100644 --- a/libsq/core/lg/lga/lga.go +++ b/libsq/core/lg/lga/lga.go @@ -3,6 +3,7 @@ package lga const ( + Action = "action" After = "after" Alt = "alt" Before = "before" diff --git a/libsq/core/progress/progressio/progressio.go b/libsq/core/progress/contextio.go similarity index 91% rename from libsq/core/progress/progressio/progressio.go rename to libsq/core/progress/contextio.go index ce97b4d59..03461ee9e 100644 --- a/libsq/core/progress/progressio/progressio.go +++ b/libsq/core/progress/contextio.go @@ -14,12 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -// Package progressio provides [io.Writer] and [io.Reader] that stop accepting/providing -// data when an attached context is canceled, and know how to interact -// with a progress bar. -// -// The code is lifted from github.com/dolmen-go/contextio. -package progressio +// This code is lifted from github.com/dolmen-go/contextio. + +package progress import ( "context" diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index fe81a47fe..7f031bd4d 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -26,41 +26,76 @@ func FromContext(ctx context.Context) *Progress { } type Progress struct { - P *mpb.Progress + P *mpb.Progress + Colors *Colors } -func New(ctx context.Context, out io.Writer, delay time.Duration) *Progress { +func DefaultColors() *Colors { + return &Colors{ + Message: color.New(color.Faint), + Spinner: color.New(color.FgGreen, color.Bold), + Size: color.New(color.Faint), + } +} + +type Colors struct { + Message *color.Color + Spinner *color.Color + Size *color.Color +} + +func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors) *Progress { p := mpb.NewWithContext(ctx, mpb.WithOutput(out), mpb.WithWidth(64), mpb.WithRenderDelay(renderDelay(delay)), ) - return &Progress{P: p} -} + if colors == nil { + colors = DefaultColors() + } -func renderDelay(d time.Duration) <-chan struct{} { - ch := make(chan struct{}) - time.AfterFunc(d, func() { - close(ch) - }) - return ch + return &Progress{P: p, Colors: colors} } -func NewProxyBar(p *mpb.Progress) { - bar := p.New(0, - mpb.BarStyle().Rbound("|"), +func (p *Progress) NewIOSpinner(msg string) *IOSpinner { + s := mpb.SpinnerStyle("∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙") + s = s.Meta(func(s string) string { + return p.Colors.Spinner.Sprint(s) + }) + bar := p.P.New(0, + s, + mpb.BarWidth(36), mpb.PrependDecorators( - decor.Counters(decor.SizeB1024(0), "% .2f / % .2f"), + ColorMeta(decor.Name(msg), p.Colors.Message), ), 
mpb.AppendDecorators( - decor.EwmaETA(decor.ET_STYLE_GO, 30), - decor.Name(" ] "), - decor.EwmaSpeed(decor.SizeB1024(0), "% .2f", 30), + ColorMeta(decor.Current(decor.SizeB1024(0), "% .1f"), p.Colors.Message), ), + mpb.BarRemoveOnComplete(), ) - _ = bar + return &IOSpinner{bar: bar} +} + +type IOSpinner struct { + bar *mpb.Bar +} + +func (sp *IOSpinner) IncrBy(n int) { + sp.bar.IncrBy(n) +} + +func (sp *IOSpinner) Finish() { + sp.bar.SetTotal(-1, true) +} + +func renderDelay(d time.Duration) <-chan struct{} { + ch := make(chan struct{}) + time.AfterFunc(d, func() { + close(ch) + }) + return ch } func ColorMeta(decorator decor.Decorator, c *color.Color) decor.Decorator { diff --git a/libsq/core/progress/progressio.go b/libsq/core/progress/progressio.go new file mode 100644 index 000000000..d920cd228 --- /dev/null +++ b/libsq/core/progress/progressio.go @@ -0,0 +1,118 @@ +/* +Copyright 2018 Olivier Mengué + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// This code is lifted from github.com/dolmen-go/contextio. + +package progress + +import ( + "context" + "io" +) + +type progWriter struct { + ctx context.Context + w io.Writer +} + +type progCopier struct { + progWriter +} + +// NewProgWriter wraps an [io.Writer] to handle context cancellation. +// +// Context state is checked BEFORE every Write. +// +// The returned Writer also implements [io.ReaderFrom] to allow [io.Copy] to select +// the best strategy while still checking the context state before every chunk transfer. 
+func NewProgWriter(ctx context.Context, msg string, w io.Writer) io.Writer { + if w, ok := w.(*progCopier); ok && ctx == w.ctx { + return w + } + return &progCopier{progWriter{ctx: ctx, w: w}} +} + +// Write implements [io.Writer], but with context awareness. +func (w *progWriter) Write(p []byte) (n int, err error) { + select { + case <-w.ctx.Done(): + return 0, w.ctx.Err() + default: + return w.w.Write(p) + } +} + +func (w *progWriter) Close() error { + // REVISIT: I'm not sure if we should always try + // to close the underlying writer first, even if + // the context is done? Or go straight to the + // select ctx.Done? + + var closeErr error + if wc, ok := w.w.(io.WriteCloser); ok { + closeErr = wc.Close() + } + + select { + case <-w.ctx.Done(): + return w.ctx.Err() + default: + } + + return closeErr +} + +type progReader struct { + ctx context.Context + r io.Reader +} + +// NewProgReader wraps an [io.Reader] to handle context cancellation. +// +// Context state is checked BEFORE every Read. +func NewProgReader(ctx context.Context, msg string, r io.Reader) io.Reader { + if r, ok := r.(*progReader); ok && ctx == r.ctx { + return r + } + return &progReader{ctx: ctx, r: r} +} + +func (r *progReader) Read(p []byte) (n int, err error) { + select { + case <-r.ctx.Done(): + return 0, r.ctx.Err() + default: + return r.r.Read(p) + } +} + +// ReadFrom implements interface [io.ReaderFrom], but with context awareness. +// +// This should allow efficient copying allowing writer or reader to define the chunk size. +func (w *progCopier) ReadFrom(r io.Reader) (n int64, err error) { + if _, ok := w.w.(io.ReaderFrom); ok { + // Let the original Writer decide the chunk size. + return io.Copy(w.progWriter.w, &progReader{ctx: w.ctx, r: r}) + } + select { + case <-w.ctx.Done(): + return 0, w.ctx.Err() + default: + // The original Writer is not a ReaderFrom. + // Let the Reader decide the chunk size. 
+ return io.Copy(&w.progWriter, r) + } +} diff --git a/libsq/source/files.go b/libsq/source/files.go index d046cef76..55c91328c 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -2,6 +2,8 @@ package source import ( "context" + "github.com/dolmen-go/contextio" + "github.com/neilotoole/sq/libsq/core/progress" "io" "log/slog" "mime" @@ -149,12 +151,27 @@ func (fs *Files) addFile(ctx context.Context, f *os.File, key string) (fscache.R return nil, errz.Errorf("failed to add to fscache (possibly previously added): %s", key) } + copier := fscache.Filler{ + Message: "Cache fill", + Log: log.With(lga.Action, "Cache fill"), + NewContextWriterFn: progress.NewProgWriter, + NewContextReaderFn: func(ctx context.Context, msg string, r io.Reader) io.Reader { + return contextio.NewReader(ctx, r) + }, + CloseReader: true, + } + // TODO: Problematically, we copy the entire contents of f into fscache. // If f is a large file (e.g. piped over stdin), this means that // everything is held up until f is fully copied. Hopefully we can // do something with fscache so that the readers returned from // fscache can lazily read from f. 
- if err = fscache.FillWriterAsync(ctx, log, w, f, true); err != nil { + //if err = fscache.FillWriterAsync(ctx, log, w, f, true); err != nil { + // lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) + // return nil, errz.Err(err) + //} + + if err = copier.Copy(ctx, w, f); err != nil { lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) return nil, errz.Err(err) } diff --git a/libsq/source/source.go b/libsq/source/source.go index adc6a522b..4092331f4 100644 --- a/libsq/source/source.go +++ b/libsq/source/source.go @@ -117,7 +117,7 @@ func (s *Source) LogValue() slog.Value { return slog.Value{} } - attrs := make([]slog.Attr, 3, 6) + attrs := make([]slog.Attr, 3, 5) attrs[0] = slog.String(lga.Handle, s.Handle) attrs[1] = slog.String(lga.Driver, string(s.Type)) attrs[2] = slog.String(lga.Loc, s.RedactedLocation()) @@ -127,9 +127,13 @@ func (s *Source) LogValue() slog.Value { if s.Schema != "" { attrs = append(attrs, slog.String(lga.Schema, s.Schema)) } - if s.Options != nil { - attrs = append(attrs, slog.Any(lga.Opts, s.Options)) - } + + // For really intense debugging, we can log the options. + // But it's too much for normal logging. + // + // if s.Options != nil { + // attrs = append(attrs, slog.Any(lga.Opts, s.Options)) + // } return slog.GroupValue(attrs...) 
} From 75d71d624405c16ba54ed3fdc0849ad7c4860e63 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 26 Nov 2023 21:23:07 -0700 Subject: [PATCH 016/195] wip: more progress --- cli/cmd_xtest.go | 232 +++++--------------------------- libsq/core/progress/progress.go | 49 +++++-- 2 files changed, 71 insertions(+), 210 deletions(-) diff --git a/cli/cmd_xtest.go b/cli/cmd_xtest.go index 663317c7d..d2dddb079 100644 --- a/cli/cmd_xtest.go +++ b/cli/cmd_xtest.go @@ -6,12 +6,9 @@ import ( "github.com/neilotoole/sq/cli/buildinfo" "github.com/neilotoole/sq/cli/flag" "github.com/neilotoole/sq/cli/hostinfo" - "github.com/neilotoole/sq/cli/output" "github.com/neilotoole/sq/cli/run" "github.com/neilotoole/sq/libsq/core/progress" "github.com/spf13/cobra" - "github.com/vbauerster/mpb/v8" - "github.com/vbauerster/mpb/v8/decor" "io" "math/rand" "time" @@ -39,7 +36,10 @@ func execXTestMbp(cmd *cobra.Command, _ []string) error { fmt.Fprintln(ru.Out, "Hello, world!") - if err := doBigRead2(ctx, ru.Writers.Printing, ru.ErrOut); err != nil { + pb := progress.New(ctx, ru.ErrOut, 1*time.Second, progress.DefaultColors()) + ctx = progress.NewContext(ctx, pb) + + if err := doBigRead2(ctx); err != nil { return err } @@ -49,17 +49,39 @@ func execXTestMbp(cmd *cobra.Command, _ []string) error { } -func execXTestMbpIndeterminate(cmd *cobra.Command, _ []string) error { - ctx := cmd.Context() - ru := run.FromContext(ctx) +func doBigRead2(ctx context.Context) error { + pb := progress.FromContext(ctx) - //bar.Abort(true) - // wait for our bar to complete and flush - //p.Wait() - //bar.Abort(true) - //p.Shutdown() + spinner := pb.NewIOSpinner("Ingest data test...") + defer spinner.Finish() + maxSleep := 100 * time.Millisecond + jr := &junkReader{limit: 100000} + b := make([]byte, 1024) - return ru.Writers.Version.Version(buildinfo.Get(), buildinfo.Get().Version, hostinfo.Get()) +LOOP: + for { + select { + case <-ctx.Done(): + spinner.Finish() + break LOOP + default: + } + + n, 
err := jr.Read(b) + if err != nil { + spinner.Finish() + if err == io.EOF { + break + } + break + } + + spinner.IncrBy(n) + time.Sleep(time.Duration(rand.Intn(10)+1) * maxSleep / 10) + } + + //pb.Wait() + return nil } // @@ -122,190 +144,6 @@ func execXTestMbpIndeterminate(cmd *cobra.Command, _ []string) error { // return nil //} -func doBigRead2(ctx context.Context, pr *output.Printing, errOut io.Writer) error { - pb := progress.New(ctx, errOut, 1*time.Second, progress.DefaultColors()) - - //total := 50 - //p := mpb.New( - // mpb.WithOutput(ru.Out), - // mpb.WithWidth(64), - // mpb.WithRenderDelay(after2(1*time.Second)), - //) - p := pb.P - - //bar := pb.NewIOSpinner("ingest data") - spinner := pb.NewIOSpinner("Ingest data test...") - //bar := spinner.Bar - //bar := p.New( - // int64(0), - // mpb.BarStyle(), - // mpb.PrependDecorators(decor.Name("huzzah")), - // mpb.BarRemoveOnComplete(), - //) - - //s := mpb.SpinnerStyle("∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙") - //spinnerStyle := s.Meta(func(s string) string { - // return pr.Active.Sprint(s) - // //return "\033[31m" + s + "\033[0m" // red - //}) - // - //bar := p.New(0, - // spinnerStyle, - // //mpb.PrependDecorators(), - // //mpb.PrependDecorators(decor.Counters(decor.SizeB1024(0), "% .1f / % .1f")), - // //mpb.PrependDecorators(decor.Counters(decor.SizeB1024(0), "% .1f")), - // mpb.PrependDecorators( - // progress.ColorMeta(decor.Name("Ingesting data..."), pr.Faint), - // ), - // //mpb.AppendDecorators(decor.Percentage()), - // mpb.AppendDecorators( - // progress.ColorMeta(decor.Current(decor.SizeB1024(0), "% .1f"), pr.Faint), - // ), - // //mpb.AppendDecorators( - // // // replace ETA decorator with "done" message, OnComplete event - // // decor.OnComplete( - // // // ETA decorator with ewma age of 30 - // // decor.EwmaETA(decor.ET_STYLE_GO, 30), "done", - // // ), - // //), - // mpb.BarRemoveOnComplete(), - //) - - maxSleep := 100 * time.Millisecond - - jr := &junkReader{limit: 100000} - b := make([]byte, 1024) - 
- //start := time.Now() - -LOOP: - for { - //bar.EwmaIncrement(time.Since(start)) - select { - case <-ctx.Done(): - //bar.SetTotal(-1, true) - spinner.Finish() - break LOOP - default: - } - - n, err := jr.Read(b) - if err != nil { - spinner.Finish() - //bar.SetTotal(-1, true) - if err == io.EOF { - // triggering complete event now - //bar.SetTotal(-1, true) - break - } - break - } - // increment methods won't trigger complete event because bar was constructed with total = 0 - //bar.IncrBy(n) - spinner.IncrBy(n) - // following call is not required, it's called to show some progress instead of an empty bar - //bar.SetTotal(bar.Current()+2048, false) - time.Sleep(time.Duration(rand.Intn(10)+1) * maxSleep / 10) - } - - p.Wait() - return nil -} - -func doBigReadWorking(ctx context.Context, pr *output.Printing, errOut io.Writer) error { - pb := progress.New(ctx, errOut, 1*time.Second, progress.DefaultColors()) - - //total := 50 - //p := mpb.New( - // mpb.WithOutput(ru.Out), - // mpb.WithWidth(64), - // mpb.WithRenderDelay(after2(1*time.Second)), - //) - p := pb.P - - //bar := p.New( - // int64(0), - // mpb.BarStyle(), - // mpb.PrependDecorators(decor.Name("huzzah")), - // mpb.BarRemoveOnComplete(), - //) - - s := mpb.SpinnerStyle("∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙") - spinnerStyle := s.Meta(func(s string) string { - return pr.Active.Sprint(s) - //return "\033[31m" + s + "\033[0m" // red - }) - - bar := p.New(0, - spinnerStyle, - //mpb.PrependDecorators(), - //mpb.PrependDecorators(decor.Counters(decor.SizeB1024(0), "% .1f / % .1f")), - //mpb.PrependDecorators(decor.Counters(decor.SizeB1024(0), "% .1f")), - mpb.PrependDecorators( - progress.ColorMeta(decor.Name("Ingesting data..."), pr.Faint), - ), - //mpb.AppendDecorators(decor.Percentage()), - mpb.AppendDecorators( - progress.ColorMeta(decor.Current(decor.SizeB1024(0), "% .1f"), pr.Faint), - ), - //mpb.AppendDecorators( - // // replace ETA decorator with "done" message, OnComplete event - // decor.OnComplete( - // // ETA 
decorator with ewma age of 30 - // decor.EwmaETA(decor.ET_STYLE_GO, 30), "done", - // ), - //), - mpb.BarRemoveOnComplete(), - ) - - maxSleep := 100 * time.Millisecond - - jr := &junkReader{limit: 100000} - b := make([]byte, 1024) - - //start := time.Now() - -LOOP: - for { - //bar.EwmaIncrement(time.Since(start)) - select { - case <-ctx.Done(): - bar.SetTotal(-1, true) - break LOOP - default: - } - - n, err := jr.Read(b) - if err != nil { - bar.SetTotal(-1, true) - if err == io.EOF { - // triggering complete event now - //bar.SetTotal(-1, true) - break - } - break - } - // increment methods won't trigger complete event because bar was constructed with total = 0 - bar.IncrBy(n) - // following call is not required, it's called to show some progress instead of an empty bar - //bar.SetTotal(bar.Current()+2048, false) - time.Sleep(time.Duration(rand.Intn(10)+1) * maxSleep / 10) - } - - p.Wait() - return nil -} - -func makeStream(limit int) func() (int, error) { - return func() (int, error) { - if limit <= 0 { - return 0, io.EOF - } - limit-- - return rand.Intn(1024) + 1, nil - } -} - type junkReader struct { limit int count int diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 7f031bd4d..e4ebe01d4 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -25,11 +25,6 @@ func FromContext(ctx context.Context) *Progress { return ctx.Value(runKey{}).(*Progress) } -type Progress struct { - P *mpb.Progress - Colors *Colors -} - func DefaultColors() *Colors { return &Colors{ Message: color.New(color.Faint), @@ -44,6 +39,24 @@ type Colors struct { Size *color.Color } +type Progress struct { + p *mpb.Progress + colors *Colors +} + +// Wait waits for all bars to complete and finally shutdowns container. After +// this method has been called, there is no way to reuse `*Progress` instance. 
+func (p *Progress) Wait() { + p.p.Wait() +} + +//// Shutdown cancels any running bar immediately and then shutdowns `*Progress` +//// instance. Normally this method shouldn't be called unless you know what you +//// are doing. Proper way to shutdown is to call `(*Progress).Wait()` instead. +//func (p *Progress) Shutdown() { +// p.p.Shutdown() +//} + func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors) *Progress { p := mpb.NewWithContext(ctx, mpb.WithOutput(out), @@ -55,22 +68,26 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors colors = DefaultColors() } - return &Progress{P: p, Colors: colors} + return &Progress{p: p, colors: colors} } func (p *Progress) NewIOSpinner(msg string) *IOSpinner { - s := mpb.SpinnerStyle("∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙") - s = s.Meta(func(s string) string { - return p.Colors.Spinner.Sprint(s) + if p == nil { + return nil + } + style := mpb.SpinnerStyle("∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙") + style = style.Meta(func(s string) string { + return p.colors.Spinner.Sprint(s) }) - bar := p.P.New(0, - s, + + bar := p.p.New(0, + style, mpb.BarWidth(36), mpb.PrependDecorators( - ColorMeta(decor.Name(msg), p.Colors.Message), + ColorMeta(decor.Name(msg), p.colors.Message), ), mpb.AppendDecorators( - ColorMeta(decor.Current(decor.SizeB1024(0), "% .1f"), p.Colors.Message), + ColorMeta(decor.Current(decor.SizeB1024(0), "% .1f"), p.colors.Message), ), mpb.BarRemoveOnComplete(), ) @@ -83,10 +100,16 @@ type IOSpinner struct { } func (sp *IOSpinner) IncrBy(n int) { + if sp == nil { + return + } sp.bar.IncrBy(n) } func (sp *IOSpinner) Finish() { + if sp == nil { + return + } sp.bar.SetTotal(-1, true) } From 67c06a9e7cd664b2756a21c34ace2e3c67a79b1a Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 26 Nov 2023 22:17:31 -0700 Subject: [PATCH 017/195] wip: more progress --- cli/logging.go | 13 ++++---- cli/output.go | 14 ++++++++ cli/run.go | 2 ++ cli/term.go | 2 ++ 
libsq/core/progress/progress.go | 34 ++++++++++++++++++- libsq/core/progress/progressio.go | 54 +++++++++++++++++++++++++------ 6 files changed, 102 insertions(+), 17 deletions(-) diff --git a/cli/logging.go b/cli/logging.go index a780d4e34..305e15b6f 100644 --- a/cli/logging.go +++ b/cli/logging.go @@ -118,15 +118,16 @@ func slogReplaceAttrs(groups []string, a slog.Attr) slog.Attr { // slogReplaceSource overrides the default slog.SourceKey attr // to print "pkg/file.go" instead. func slogReplaceSource(_ []string, a slog.Attr) slog.Attr { - // We want source to be "pkg/file.go". + // We want source to be "pkg/file.go:42". if a.Key == slog.SourceKey { - source := a.Value.Any().(*slog.Source) + source, ok := a.Value.Any().(*slog.Source) + if ok && source != nil { + val := filepath.Join(filepath.Base(filepath.Dir(source.File)), filepath.Base(source.File)) + val += ":" + strconv.Itoa(source.Line) + a.Value = slog.StringValue(val) + } //source.File = filepath.Base(source.File) - val := filepath.Join(filepath.Base(filepath.Dir(source.File)), filepath.Base(source.File)) - val += ":" + strconv.Itoa(source.Line) - a.Value = slog.StringValue(val) - //src, ok := a.Value. 
//fp := a.Value.String() //a.Value = slog.StringValue(filepath.Join(filepath.Base(filepath.Dir(fp)), filepath.Base(fp))) diff --git a/cli/output.go b/cli/output.go index c77f5e608..a56ca8b9a 100644 --- a/cli/output.go +++ b/cli/output.go @@ -2,9 +2,11 @@ package cli import ( "fmt" + "github.com/neilotoole/sq/libsq/core/progress" "io" "os" "strings" + "time" "github.com/fatih/color" colorable "github.com/mattn/go-colorable" @@ -406,6 +408,18 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer errOut2 = colorable.NewNonColorable(errOut) } + if isTerminal(errOut) { + // FIXME: need to check option for --progress + progColors := progress.DefaultColors() + progColors.EnableColor(isColorTerminal(errOut)) + + ctx := cmd.Context() + // TODO: need to check for option "progress.delay" + prog := progress.New(ctx, errOut, time.Second*2, progColors) + cmd.SetContext(progress.NewContext(ctx, prog)) + logFrom(cmd).Debug("Initialized progress") + } + logFrom(cmd).Debug("Constructed output.Printing", lga.Val, pr) return pr, out2, errOut2 diff --git a/cli/run.go b/cli/run.go index 6f5c1c5c3..802715831 100644 --- a/cli/run.go +++ b/cli/run.go @@ -124,6 +124,8 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { ru.Cleanup = cleanup.New() } + + cfg, log := ru.Config, lg.FromContext(ctx) var scratchSrcFunc driver.ScratchSrcFunc diff --git a/cli/term.go b/cli/term.go index 4e7f17d3b..91f9cb120 100644 --- a/cli/term.go +++ b/cli/term.go @@ -24,6 +24,8 @@ func isColorTerminal(w io.Writer) bool { return false } + // TODO: Add the improvements from jsoncolor: + // https://github.com/neilotoole/jsoncolor/pull/27 if !isTerminal(w) { return false } diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index e4ebe01d4..b08bf26d6 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -22,7 +22,22 @@ func NewContext(ctx context.Context, prog *Progress) context.Context { // FromContext 
extracts the Progress added to ctx via NewContext. func FromContext(ctx context.Context) *Progress { - return ctx.Value(runKey{}).(*Progress) + if ctx == nil { + return nil + } + + val := ctx.Value(runKey{}) + if val == nil { + return nil + } + + if p, ok := val.(*Progress); ok { + return p + } + + return nil + + //return ctx.Value(runKey{}).(*Progress) } func DefaultColors() *Colors { @@ -39,6 +54,23 @@ type Colors struct { Size *color.Color } +func (c *Colors) EnableColor(enable bool) { + if c == nil { + return + } + + if enable { + c.Message.EnableColor() + c.Spinner.EnableColor() + c.Size.EnableColor() + return + } + + c.Message.DisableColor() + c.Spinner.DisableColor() + c.Size.DisableColor() +} + type Progress struct { p *mpb.Progress colors *Colors diff --git a/libsq/core/progress/progressio.go b/libsq/core/progress/progressio.go index d920cd228..aaec087e1 100644 --- a/libsq/core/progress/progressio.go +++ b/libsq/core/progress/progressio.go @@ -24,8 +24,9 @@ import ( ) type progWriter struct { - ctx context.Context - w io.Writer + ctx context.Context + w io.Writer + spinner *IOSpinner } type progCopier struct { @@ -42,20 +43,36 @@ func NewProgWriter(ctx context.Context, msg string, w io.Writer) io.Writer { if w, ok := w.(*progCopier); ok && ctx == w.ctx { return w } - return &progCopier{progWriter{ctx: ctx, w: w}} + + pb := FromContext(ctx) + spinner := pb.NewIOSpinner(msg) + + return &progCopier{progWriter{ctx: ctx, w: w, spinner: spinner}} } // Write implements [io.Writer], but with context awareness. 
func (w *progWriter) Write(p []byte) (n int, err error) { select { case <-w.ctx.Done(): + w.spinner.Finish() return 0, w.ctx.Err() default: - return w.w.Write(p) + n, err = w.w.Write(p) + w.spinner.IncrBy(n) + if err != nil { + w.spinner.Finish() + } + return n, err } } func (w *progWriter) Close() error { + if w == nil { + return nil + } + + w.spinner.Finish() + // REVISIT: I'm not sure if we should always try // to close the underlying writer first, even if // the context is done? Or go straight to the @@ -76,8 +93,9 @@ func (w *progWriter) Close() error { } type progReader struct { - ctx context.Context - r io.Reader + ctx context.Context + r io.Reader + spinner *IOSpinner } // NewProgReader wraps an [io.Reader] to handle context cancellation. @@ -87,15 +105,23 @@ func NewProgReader(ctx context.Context, msg string, r io.Reader) io.Reader { if r, ok := r.(*progReader); ok && ctx == r.ctx { return r } - return &progReader{ctx: ctx, r: r} + + spinner := FromContext(ctx).NewIOSpinner(msg) + return &progReader{ctx: ctx, r: r, spinner: spinner} } func (r *progReader) Read(p []byte) (n int, err error) { select { case <-r.ctx.Done(): + r.spinner.Finish() return 0, r.ctx.Err() default: - return r.r.Read(p) + n, err = r.r.Read(p) + r.spinner.IncrBy(n) + if err != nil { + r.spinner.Finish() + } + return n, err } } @@ -105,14 +131,22 @@ func (r *progReader) Read(p []byte) (n int, err error) { func (w *progCopier) ReadFrom(r io.Reader) (n int64, err error) { if _, ok := w.w.(io.ReaderFrom); ok { // Let the original Writer decide the chunk size. - return io.Copy(w.progWriter.w, &progReader{ctx: w.ctx, r: r}) + // FIXME: Do we really need to pass the spinner to progReader, if + // the writer already has it? + return io.Copy(w.progWriter.w, &progReader{ctx: w.ctx, r: r, spinner: w.spinner}) } select { case <-w.ctx.Done(): + w.spinner.Finish() return 0, w.ctx.Err() default: // The original Writer is not a ReaderFrom. // Let the Reader decide the chunk size. 
- return io.Copy(&w.progWriter, r) + n, err = io.Copy(&w.progWriter, r) + w.spinner.IncrBy(int(n)) + if err != nil { + w.spinner.Finish() + } + return n, err } } From c7bf3aa50b3ad6a1dfcc6cc5381f71cd78c816ce Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 26 Nov 2023 22:24:50 -0700 Subject: [PATCH 018/195] still broken --- libsq/core/progress/progressio.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/libsq/core/progress/progressio.go b/libsq/core/progress/progressio.go index aaec087e1..1f97ba7f4 100644 --- a/libsq/core/progress/progressio.go +++ b/libsq/core/progress/progressio.go @@ -46,7 +46,9 @@ func NewProgWriter(ctx context.Context, msg string, w io.Writer) io.Writer { pb := FromContext(ctx) spinner := pb.NewIOSpinner(msg) - + return spinner.bar.ProxyWriter(w) + // FIXME: This is not working, ^^, bar stays on screen + //return &progCopier{progWriter{ctx: ctx, w: spinner.bar.ProxyWriter(w), spinner: spinner}} return &progCopier{progWriter{ctx: ctx, w: w, spinner: spinner}} } From c0428a1ca4cad86cdf09f75c761be6d10817c977 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Mon, 27 Nov 2023 07:56:53 -0700 Subject: [PATCH 019/195] wip: almost working --- cli/output.go | 1 + libsq/core/progress/progress.go | 54 ++++++++++++++++++++++++++++--- libsq/core/progress/progressio.go | 2 +- 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/cli/output.go b/cli/output.go index a56ca8b9a..3ad1c8f3f 100644 --- a/cli/output.go +++ b/cli/output.go @@ -416,6 +416,7 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer ctx := cmd.Context() // TODO: need to check for option "progress.delay" prog := progress.New(ctx, errOut, time.Second*2, progColors) + out2 = prog.ShutdownOnWriteTo(out2) cmd.SetContext(progress.NewContext(ctx, prog)) logFrom(cmd).Debug("Initialized progress") } diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index b08bf26d6..da2c7873b 100644 --- 
a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -3,9 +3,11 @@ package progress import ( "context" "github.com/fatih/color" + "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" "io" + "sync" "time" ) @@ -72,13 +74,19 @@ func (c *Colors) EnableColor(enable bool) { } type Progress struct { - p *mpb.Progress - colors *Colors + p *mpb.Progress + mu *sync.Mutex + colors *Colors + cleanup *cleanup.Cleanup } // Wait waits for all bars to complete and finally shutdowns container. After // this method has been called, there is no way to reuse `*Progress` instance. func (p *Progress) Wait() { + p.mu.Lock() + defer p.mu.Unlock() + + _ = p.cleanup.Run() p.p.Wait() } @@ -100,13 +108,48 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors colors = DefaultColors() } - return &Progress{p: p, colors: colors} + return &Progress{p: p, colors: colors, mu: &sync.Mutex{}, cleanup: cleanup.New()} +} + +var _ io.Writer = (*writeNotifier)(nil) + +type writeNotifier struct { + p *Progress + w io.Writer } +func (w *writeNotifier) Write(p []byte) (n int, err error) { + w.p.Wait() + return w.w.Write(p) +} + +// ShutdownOnWriteTo returns a writer that will shut down the +// progress bar when w.WriteTo is called. Typically p writes +// to stderr, and stdout is passed to this method. That is, when +// the program starts writing to stdout, we want to shut down +// and remove the progress bar. +func (p *Progress) ShutdownOnWriteTo(w io.Writer) io.Writer { + if p == nil { + return w + } + return &writeNotifier{ + p: p, + w: w, + } +} + +// NewIOSpinner returns a new spinner bar. The caller is ultimately +// responsible for calling Finish() on the returned IOSpinner. However, +// the returned IOSpinner is added to the Progress's cleanup list, so +// it will be called automatically when the Progress is shut down. 
func (p *Progress) NewIOSpinner(msg string) *IOSpinner { if p == nil { return nil } + + p.mu.Lock() + defer p.mu.Unlock() + style := mpb.SpinnerStyle("∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙") style = style.Meta(func(s string) string { return p.colors.Spinner.Sprint(s) @@ -124,7 +167,9 @@ func (p *Progress) NewIOSpinner(msg string) *IOSpinner { mpb.BarRemoveOnComplete(), ) - return &IOSpinner{bar: bar} + spinner := &IOSpinner{bar: bar} + p.cleanup.Add(spinner.Finish) + return spinner } type IOSpinner struct { @@ -143,6 +188,7 @@ func (sp *IOSpinner) Finish() { return } sp.bar.SetTotal(-1, true) + sp.bar.Wait() } func renderDelay(d time.Duration) <-chan struct{} { diff --git a/libsq/core/progress/progressio.go b/libsq/core/progress/progressio.go index 1f97ba7f4..b9df949e6 100644 --- a/libsq/core/progress/progressio.go +++ b/libsq/core/progress/progressio.go @@ -49,7 +49,7 @@ func NewProgWriter(ctx context.Context, msg string, w io.Writer) io.Writer { return spinner.bar.ProxyWriter(w) // FIXME: This is not working, ^^, bar stays on screen //return &progCopier{progWriter{ctx: ctx, w: spinner.bar.ProxyWriter(w), spinner: spinner}} - return &progCopier{progWriter{ctx: ctx, w: w, spinner: spinner}} + //return &progCopier{progWriter{ctx: ctx, w: w, spinner: spinner}} } // Write implements [io.Writer], but with context awareness. 
From 05804a8b47cf929dba1abd921ed35ca7dd09ad65 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Mon, 27 Nov 2023 17:28:45 -0700 Subject: [PATCH 020/195] wip: close to working --- README.md | 2 + cli/cmd_xtest.go | 110 ++--------- cli/complete.go | 2 +- cli/logging.go | 9 +- cli/options.go | 19 +- cli/output.go | 37 +++- cli/run.go | 2 - drivers/csv/csv_test.go | 4 +- drivers/csv/ingest.go | 1 + drivers/sqlite3/sqlite3.go | 6 +- go.mod | 2 +- libsq/core/errz/errz.go | 13 ++ .../{progress => ioz/contextio}/contextio.go | 2 +- libsq/core/ioz/ioz.go | 63 +++++++ libsq/core/ioz/ioz_test.go | 27 +++ libsq/core/options/opt.go | 23 ++- libsq/core/options/options.go | 1 - libsq/core/options/options_test.go | 2 +- libsq/core/progress/progress.go | 64 +++---- libsq/core/progress/progress_test.go | 60 ++++++ libsq/core/progress/progressio.go | 171 +++++++++++++----- libsq/driver/ingest.go | 4 +- libsq/driver/sources.go | 6 +- libsq/source/files.go | 94 ++-------- libsq/source/source.go | 3 + 25 files changed, 446 insertions(+), 281 deletions(-) rename libsq/core/{progress => ioz/contextio}/contextio.go (99%) diff --git a/README.md b/README.md index f02b10412..bd87be215 100644 --- a/README.md +++ b/README.md @@ -319,6 +319,8 @@ See [CHANGELOG.md](./CHANGELOG.md). from [jOOQ](https://github.com/jooq/jooq), which in turn owe their heritage to earlier work on Sakila. - Date rendering via [`ncruces/go-strftime`](https://github.com/ncruces/go-strftime). +- The [`dolmen-go/contextio`](https://github.com/dolmen-go/contextio) package is + incorporated into the codebase (with modifications). 
## Similar, related, or noteworthy projects diff --git a/cli/cmd_xtest.go b/cli/cmd_xtest.go index d2dddb079..8ca195f66 100644 --- a/cli/cmd_xtest.go +++ b/cli/cmd_xtest.go @@ -3,15 +3,18 @@ package cli import ( "context" "fmt" + "math/rand" + "time" + + "github.com/neilotoole/sq/libsq/core/ioz" + + "github.com/spf13/cobra" + "github.com/neilotoole/sq/cli/buildinfo" "github.com/neilotoole/sq/cli/flag" "github.com/neilotoole/sq/cli/hostinfo" "github.com/neilotoole/sq/cli/run" "github.com/neilotoole/sq/libsq/core/progress" - "github.com/spf13/cobra" - "io" - "math/rand" - "time" ) func newXTestCmd() *cobra.Command { @@ -32,131 +35,48 @@ func newXTestCmd() *cobra.Command { func execXTestMbp(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() ru := run.FromContext(ctx) - //pb := progress.New(ctx, ru.ErrOut, 1*time.Second) fmt.Fprintln(ru.Out, "Hello, world!") - pb := progress.New(ctx, ru.ErrOut, 1*time.Second, progress.DefaultColors()) + pb := progress.New(ctx, ru.ErrOut, 1*time.Millisecond, progress.DefaultColors()) ctx = progress.NewContext(ctx, pb) if err := doBigRead2(ctx); err != nil { return err } - //_ = pb - return ru.Writers.Version.Version(buildinfo.Get(), buildinfo.Get().Version, hostinfo.Get()) - } func doBigRead2(ctx context.Context) error { pb := progress.FromContext(ctx) spinner := pb.NewIOSpinner("Ingest data test...") - defer spinner.Finish() + defer spinner.Stop() maxSleep := 100 * time.Millisecond - jr := &junkReader{limit: 100000} + + jr := ioz.LimitRandReader(100000) b := make([]byte, 1024) LOOP: for { select { case <-ctx.Done(): - spinner.Finish() + spinner.Stop() break LOOP default: } n, err := jr.Read(b) if err != nil { - spinner.Finish() - if err == io.EOF { - break - } + spinner.Stop() break } spinner.IncrBy(n) - time.Sleep(time.Duration(rand.Intn(10)+1) * maxSleep / 10) + time.Sleep(time.Duration(rand.Intn(10)+1) * maxSleep / 10) //nolint:gosec } - //pb.Wait() + pb.Wait() return nil } - -// -//func 
doBigReadOg(ctx context.Context, errOut io.Writer) error { -// pb := progress.New(ctx, errOut, 1*time.Second, progress.DefaultColors()) -// -// //total := 50 -// //p := mpb.New( -// // mpb.WithOutput(ru.Out), -// // mpb.WithWidth(64), -// // mpb.WithRenderDelay(after2(1*time.Second)), -// //) -// p := pb.P -// -// //bar := p.New( -// // int64(0), -// // mpb.BarStyle(), -// // mpb.PrependDecorators(decor.Name("huzzah")), -// // mpb.BarRemoveOnComplete(), -// //) -// -// bar := p.AddBar(0, -// mpb.PrependDecorators(decor.Counters(decor.SizeB1024(0), "% .1f / % .1f")), -// //mpb.AppendDecorators(decor.Percentage()), -// mpb.BarRemoveOnComplete(), -// ) -// -// maxSleep := 100 * time.Millisecond -// -// jr := &junkReader{limit: 1000000} -// b := make([]byte, 1024) -// -//LOOP: -// for { -// select { -// case <-ctx.Done(): -// -// break LOOP -// default: -// } -// -// n, err := jr.Read(b) -// if err != nil { -// //bar.SetTotal(-1, true) -// if err == io.EOF { -// // triggering complete event now -// bar.SetTotal(-1, true) -// break -// } -// break -// } -// // increment methods won't trigger complete event because bar was constructed with total = 0 -// bar.IncrBy(n) -// // following call is not required, it's called to show some progress instead of an empty bar -// bar.SetTotal(bar.Current()+2048, false) -// time.Sleep(time.Duration(rand.Intn(10)+1) * maxSleep / 10) -// } -// -// p.Wait() -// return nil -//} - -type junkReader struct { - limit int - count int -} - -func (r *junkReader) Read(p []byte) (n int, err error) { - if r.count >= r.limit { - return 0, io.EOF - } - - amount, err := rand.Read(p) - r.count += amount - return amount, err - - //return rand.Intn(1024) + 1, nil -} diff --git a/cli/complete.go b/cli/complete.go index d5b5a8e7e..2ab4656de 100644 --- a/cli/complete.go +++ b/cli/complete.go @@ -29,7 +29,7 @@ var OptShellCompletionTimeout = options.NewDuration( "", 0, time.Millisecond*500, - "shell completion timeout", + "Shell completion timeout", `How long 
shell completion should wait before giving up. This can become relevant when shell completion inspects a source's metadata, e.g. to offer a list of tables in a source.`, diff --git a/cli/logging.go b/cli/logging.go index 305e15b6f..27d75fcbb 100644 --- a/cli/logging.go +++ b/cli/logging.go @@ -25,6 +25,7 @@ var ( OptLogEnabled = options.NewBool( "log", "", + false, 0, false, "Enable logging", @@ -126,11 +127,11 @@ func slogReplaceSource(_ []string, a slog.Attr) slog.Attr { val += ":" + strconv.Itoa(source.Line) a.Value = slog.StringValue(val) } - //source.File = filepath.Base(source.File) + // source.File = filepath.Base(source.File) - //src, ok := a.Value. - //fp := a.Value.String() - //a.Value = slog.StringValue(filepath.Join(filepath.Base(filepath.Dir(fp)), filepath.Base(fp))) + // src, ok := a.Value. + // fp := a.Value.String() + // a.Value = slog.StringValue(filepath.Join(filepath.Base(filepath.Dir(fp)), filepath.Base(fp))) } return a } diff --git a/cli/options.go b/cli/options.go index 9ab769955..038f87614 100644 --- a/cli/options.go +++ b/cli/options.go @@ -41,7 +41,22 @@ func getOptionsFromFlags(flags *pflag.FlagSet, reg *options.Registry) (options.O return nil } - o[opt.Key()] = f.Value.String() + if bOpt, ok := opt.(options.Bool); ok { + // Special handling for bool, because + // the flag value could be inverted. 
+ val, err := flags.GetBool(bOpt.Flag()) + if err != nil { + return errz.Err(err) + } + + if bOpt.FlagInverted() { + val = !val + } + o[bOpt.Key()] = val + } else { + o[opt.Key()] = f.Value.String() + } + return nil }) if err != nil { @@ -148,6 +163,8 @@ func RegisterDefaultOpts(reg *options.Registry) { OptVerbose, OptPrintHeader, OptMonochrome, + OptProgress, + OptProgressDelay, OptCompact, OptPingCmdTimeout, OptShellCompletionTimeout, diff --git a/cli/output.go b/cli/output.go index 3ad1c8f3f..a96b43e51 100644 --- a/cli/output.go +++ b/cli/output.go @@ -2,7 +2,6 @@ package cli import ( "fmt" - "github.com/neilotoole/sq/libsq/core/progress" "io" "os" "strings" @@ -28,6 +27,7 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/core/timez" ) @@ -36,6 +36,7 @@ var ( OptPrintHeader = options.NewBool( "header", "", + false, 0, true, "Print header row", @@ -79,6 +80,7 @@ command, sq falls back to "text". Available formats: OptVerbose = options.NewBool( "verbose", "", + false, 'v', false, "Print verbose output", @@ -89,6 +91,7 @@ command, sq falls back to "text". Available formats: OptMonochrome = options.NewBool( "monochrome", "", + false, 'M', false, "Don't print color output", @@ -96,9 +99,31 @@ command, sq falls back to "text". 
Available formats: options.TagOutput, ) + OptProgress = options.NewBool( + "progress", + "no-progress", + true, + 0, + true, + "Specify whether a progress bar is shown for long-running operations", + `Specify whether a progress bar is shown for long-running operations.`, + options.TagOutput, + ) + + OptProgressDelay = options.NewDuration( + "progress.delay", + "", + 0, + time.Second*2, + "Progress bar render delay", + `How long to wait after a long-running operation begins +before showing a progress bar.`, + ) + OptCompact = options.NewBool( "compact", "", + false, 'c', false, "Compact instead of pretty-printed output", @@ -138,6 +163,7 @@ as "RFC3339" or "Unix", or a strftime format such as "%Y-%m-%d %H:%M:%S". OptDatetimeFormatAsNumber = options.NewBool( "format.datetime.number", "", + false, 0, true, "Render numeric datetime value as number instead of string", @@ -175,6 +201,7 @@ from datetime values. In that situation, use format.datetime instead. OptDateFormatAsNumber = options.NewBool( "format.date.number", "", + false, 0, true, "Render numeric date value as number instead of string", @@ -211,6 +238,7 @@ from datetime values. In that situation, use format.datetime instead. 
OptTimeFormatAsNumber = options.NewBool( "format.time.number", "", + false, 0, true, "Render numeric time value as number instead of string", @@ -408,14 +436,13 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer errOut2 = colorable.NewNonColorable(errOut) } - if isTerminal(errOut) { - // FIXME: need to check option for --progress + if OptProgress.Get(opts) && isTerminal(errOut) { progColors := progress.DefaultColors() progColors.EnableColor(isColorTerminal(errOut)) ctx := cmd.Context() - // TODO: need to check for option "progress.delay" - prog := progress.New(ctx, errOut, time.Second*2, progColors) + renderDelay := OptProgressDelay.Get(opts) + prog := progress.New(ctx, errOut, renderDelay, progColors) out2 = prog.ShutdownOnWriteTo(out2) cmd.SetContext(progress.NewContext(ctx, prog)) logFrom(cmd).Debug("Initialized progress") diff --git a/cli/run.go b/cli/run.go index 802715831..6f5c1c5c3 100644 --- a/cli/run.go +++ b/cli/run.go @@ -124,8 +124,6 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { ru.Cleanup = cleanup.New() } - - cfg, log := ru.Config, lg.FromContext(ctx) var scratchSrcFunc driver.ScratchSrcFunc diff --git a/drivers/csv/csv_test.go b/drivers/csv/csv_test.go index cdd34ed86..61c3b9dc2 100644 --- a/drivers/csv/csv_test.go +++ b/drivers/csv/csv_test.go @@ -342,7 +342,7 @@ func TestDatetime(t *testing.T) { // TestIngestLargeCSV generates a large CSV file. // At count = 5000000, the generated file is ~500MB. 
func TestGenerateLargeCSV(t *testing.T) { - //t.Skip() + // t.Skip() const count = 5000000 // Generates ~500MB file start := time.Now() header := []string{ @@ -384,7 +384,7 @@ func TestGenerateLargeCSV(t *testing.T) { rec[3] = strconv.Itoa(rand.Intn(10)) // staff_id rec[4] = strconv.Itoa(i + 3) // rental_id, always unique f64 := amount.InexactFloat64() - //rec[5] = p.Sprintf("%.2f", f64) // amount + // rec[5] = p.Sprintf("%.2f", f64) // amount rec[5] = fmt.Sprintf("%.2f", f64) // amount amount = amount.Add(decimal.New(33, -2)) rec[6] = timez.TimestampUTC(paymentUTC) // payment_date diff --git a/drivers/csv/ingest.go b/drivers/csv/ingest.go index 6e8cf99f9..176ee3d7a 100644 --- a/drivers/csv/ingest.go +++ b/drivers/csv/ingest.go @@ -26,6 +26,7 @@ import ( var OptEmptyAsNull = options.NewBool( "driver.csv.empty-as-null", "", + false, 0, true, "Treat ingest empty CSV fields as NULL", diff --git a/drivers/sqlite3/sqlite3.go b/drivers/sqlite3/sqlite3.go index b222211ca..efd169315 100644 --- a/drivers/sqlite3/sqlite3.go +++ b/drivers/sqlite3/sqlite3.go @@ -146,10 +146,8 @@ func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, erro return nil, err } - if strings.Contains(fp, "checksum") { - x := true - _ = x - + if strings.Contains(fp, "checksum") { // FIXME: delete + lg.FromContext(ctx).Warn("This is bad") } db, err := sql.Open(dbDrvr, fp) diff --git a/go.mod b/go.mod index b67382873..952695027 100644 --- a/go.mod +++ b/go.mod @@ -37,6 +37,7 @@ require ( github.com/spf13/cobra v1.8.0 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.4 + github.com/vbauerster/mpb/v8 v8.6.2 github.com/xo/dburl v0.18.2 github.com/xuri/excelize/v2 v2.8.0 go.uber.org/atomic v1.11.0 @@ -81,7 +82,6 @@ require ( github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/segmentio/asm v1.1.3 // indirect github.com/spf13/cast v1.5.1 // indirect - github.com/vbauerster/mpb/v8 v8.6.2 // indirect github.com/xuri/efp 
v0.0.0-20231025114914-d1ff6096ae53 // indirect github.com/xuri/nfp v0.0.0-20230919160717-d98342af3f05 // indirect golang.org/x/crypto v0.15.0 // indirect diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go index a5bcfcd99..1f610071c 100644 --- a/libsq/core/errz/errz.go +++ b/libsq/core/errz/errz.go @@ -78,6 +78,19 @@ func logValue(err error) slog.Value { return slog.GroupValue(msgAttr, causeAttr, typeAttr) } +// IsErrContext returns true if err is context.Canceled or context.DeadlineExceeded. +func IsErrContext(err error) bool { + if err == nil { + return false + } + + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true + } + + return false +} + // IsErrContextDeadlineExceeded returns true if err is context.DeadlineExceeded. func IsErrContextDeadlineExceeded(err error) bool { return errors.Is(err, context.DeadlineExceeded) diff --git a/libsq/core/progress/contextio.go b/libsq/core/ioz/contextio/contextio.go similarity index 99% rename from libsq/core/progress/contextio.go rename to libsq/core/ioz/contextio/contextio.go index 03461ee9e..b0ae729f5 100644 --- a/libsq/core/progress/contextio.go +++ b/libsq/core/ioz/contextio/contextio.go @@ -16,7 +16,7 @@ limitations under the License. // This code is lifted from github.com/dolmen-go/contextio. -package progress +package contextio import ( "context" diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 28952a317..cb97e208f 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -4,10 +4,13 @@ package ioz import ( "bytes" "context" + crand "crypto/rand" "io" + mrand "math/rand" "os" "path/filepath" "strings" + "time" yaml "github.com/goccy/go-yaml" @@ -179,3 +182,63 @@ func FileAccessible(path string) bool { _, err := os.Stat(path) return err == nil } + +var _ io.Reader = (*delayReader)(nil) + +// DelayReader returns an io.Reader that delays on each read from r. +// This is primarily intended for testing. 
+// If jitter is true, a randomized jitter factor is added to the delay. +// If r implements io.Closer, the returned reader will also +// implement io.Closer; if r doesn't implement io.Closer, +// the returned reader will not implement io.Closer. +// If r is nil, nil is returned. +func DelayReader(r io.Reader, delay time.Duration, jitter bool) io.Reader { + if r == nil { + return nil + } + + dr := delayReader{r: r, delay: delay, jitter: jitter} + if _, ok := r.(io.Closer); ok { + return delayReadCloser{dr} + } + return dr +} + +var _ io.Reader = (*delayReader)(nil) + +type delayReader struct { + r io.Reader + delay time.Duration + jitter bool +} + +// Read implements io.Reader. +func (d delayReader) Read(p []byte) (n int, err error) { + delay := d.delay + if d.jitter { + delay += time.Duration(mrand.Int63n(int64(d.delay))) / 3 //nolint:gosec + } + + time.Sleep(delay) + return d.r.Read(p) +} + +var _ io.ReadCloser = (*delayReadCloser)(nil) + +type delayReadCloser struct { + delayReader +} + +// Close implements io.Closer. +func (d delayReadCloser) Close() error { + if c, ok := d.r.(io.Closer); ok { + return c.Close() + } + return nil +} + +// LimitRandReader returns an io.Reader that reads up to limit bytes +// from crypto/rand.Reader. 
+func LimitRandReader(limit int64) io.Reader { + return io.LimitReader(crand.Reader, limit) +} diff --git a/libsq/core/ioz/ioz_test.go b/libsq/core/ioz/ioz_test.go index 1aa1b84e2..6b1f5b24f 100644 --- a/libsq/core/ioz/ioz_test.go +++ b/libsq/core/ioz/ioz_test.go @@ -4,7 +4,9 @@ import ( "bytes" "io" "os" + "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -53,3 +55,28 @@ func TestChecksums(t *testing.T) { require.NoError(t, ioz.WriteChecksum(buf, gotSum1, f.Name())) require.NotEqual(t, gotSum1, gotSum2) } + +func TestDelayReader(t *testing.T) { + t.Parallel() + const ( + limit = 100000 + count = 15 + ) + + wg := &sync.WaitGroup{} + wg.Add(count) + for i := 0; i < count; i++ { + go func(i int) { + defer wg.Done() + randRdr := ioz.LimitRandReader(limit) + r := ioz.DelayReader(randRdr, 150*time.Millisecond, true) + start := time.Now() + _, err := io.ReadAll(r) + elapsed := time.Since(start) + t.Logf("%2d: Elapsed: %s", i, elapsed) + require.NoError(t, err) + }(i) + } + + wg.Wait() +} diff --git a/libsq/core/options/opt.go b/libsq/core/options/opt.go index 93e14bbc4..b4a9770fd 100644 --- a/libsq/core/options/opt.go +++ b/libsq/core/options/opt.go @@ -397,18 +397,31 @@ func (op Int) Process(o Options) (Options, error) { var _ Opt = Bool{} // NewBool returns an options.Bool instance. If flag is empty, the value -// of key is used. -func NewBool(key, flag string, short rune, defaultVal bool, usage, help string, tags ...string) Bool { +// of key is used. If invertFlag is true, the flag's boolean value +// is inverted to set the option. For example, if the Opt is "progress", +// and the flag is "--no-progress", then invertFlag should be true. 
+func NewBool(key, flag string, invertFlag bool, short rune, + defaultVal bool, usage, help string, tags ...string, +) Bool { return Bool{ - BaseOpt: NewBaseOpt(key, flag, short, usage, help, tags...), - defaultVal: defaultVal, + BaseOpt: NewBaseOpt(key, flag, short, usage, help, tags...), + defaultVal: defaultVal, + flagInverted: invertFlag, } } // Bool is an options.Opt for type bool. type Bool struct { BaseOpt - defaultVal bool + defaultVal bool + flagInverted bool +} + +// FlagInverted returns true Opt value is the inverse of the flag value. +// For example, if the Opt is "progress", and the flag is "--no-progress", +// then FlagInverted will return true. +func (op Bool) FlagInverted() bool { + return op.flagInverted } // GetAny implements options.Opt. diff --git a/libsq/core/options/options.go b/libsq/core/options/options.go index 597c95576..3a8d41877 100644 --- a/libsq/core/options/options.go +++ b/libsq/core/options/options.go @@ -192,7 +192,6 @@ func (o Options) Hash() string { buf.WriteString(k) v := o[k] buf.WriteString(fmt.Sprintf("%v", v)) - } sum := sha256.Sum256(buf.Bytes()) return fmt.Sprintf("%x", sum) diff --git a/libsq/core/options/options_test.go b/libsq/core/options/options_test.go index 34e77d499..9dc3b3c2f 100644 --- a/libsq/core/options/options_test.go +++ b/libsq/core/options/options_test.go @@ -116,7 +116,7 @@ func TestBool(t *testing.T) { t.Run(tutil.Name(i, tc.key), func(t *testing.T) { reg := &options.Registry{} - opt := options.NewBool(tc.key, "", 0, tc.defaultVal, "", "") + opt := options.NewBool(tc.key, "", false, 0, tc.defaultVal, "", "") reg.Add(opt) o := options.Options{tc.key: tc.input} diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index da2c7873b..e8af00514 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -2,13 +2,15 @@ package progress import ( "context" - "github.com/fatih/color" - "github.com/neilotoole/sq/libsq/core/cleanup" - 
"github.com/vbauerster/mpb/v8" - "github.com/vbauerster/mpb/v8/decor" "io" "sync" "time" + + "github.com/fatih/color" + mpb "github.com/vbauerster/mpb/v8" + "github.com/vbauerster/mpb/v8/decor" + + "github.com/neilotoole/sq/libsq/core/cleanup" ) type runKey struct{} @@ -39,7 +41,7 @@ func FromContext(ctx context.Context) *Progress { return nil - //return ctx.Value(runKey{}).(*Progress) + // return ctx.Value(runKey{}).(*Progress) } func DefaultColors() *Colors { @@ -90,13 +92,6 @@ func (p *Progress) Wait() { p.p.Wait() } -//// Shutdown cancels any running bar immediately and then shutdowns `*Progress` -//// instance. Normally this method shouldn't be called unless you know what you -//// are doing. Proper way to shutdown is to call `(*Progress).Wait()` instead. -//func (p *Progress) Shutdown() { -// p.p.Shutdown() -//} - func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors) *Progress { p := mpb.NewWithContext(ctx, mpb.WithOutput(out), @@ -111,24 +106,14 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors return &Progress{p: p, colors: colors, mu: &sync.Mutex{}, cleanup: cleanup.New()} } -var _ io.Writer = (*writeNotifier)(nil) - -type writeNotifier struct { - p *Progress - w io.Writer -} - -func (w *writeNotifier) Write(p []byte) (n int, err error) { - w.p.Wait() - return w.w.Write(p) -} - -// ShutdownOnWriteTo returns a writer that will shut down the -// progress bar when w.WriteTo is called. Typically p writes +// ShutdownOnWriteTo returns a writer that will stop the +// progress.Progress when w.Write is called. Typically p writes // to stderr, and stdout is passed to this method. That is, when // the program starts writing to stdout, we want to shut down // and remove the progress bar. func (p *Progress) ShutdownOnWriteTo(w io.Writer) io.Writer { + // REVISIT: Should we check if w implements other io interfaces, + // such as io.WriterAt etc? Or do we really only care about io.Writer? 
if p == nil { return w } @@ -138,10 +123,24 @@ func (p *Progress) ShutdownOnWriteTo(w io.Writer) io.Writer { } } -// NewIOSpinner returns a new spinner bar. The caller is ultimately -// responsible for calling Finish() on the returned IOSpinner. However, -// the returned IOSpinner is added to the Progress's cleanup list, so -// it will be called automatically when the Progress is shut down. +var _ io.Writer = (*writeNotifier)(nil) + +type writeNotifier struct { + p *Progress + w io.Writer +} + +func (w *writeNotifier) Write(p []byte) (n int, err error) { + w.p.Wait() + return w.w.Write(p) +} + +// NewIOSpinner returns a new indeterminate spinner bar whose metric is +// the count of bytes processed. The caller is ultimately +// responsible for calling [IOSpinner.Stop] on the returned IOSpinner. However, +// the returned IOSpinner is also added to the Progress's cleanup list, so +// it will be called automatically when the Progress is shut down, but that +// may be later than the actual conclusion of the spinner's work. 
func (p *Progress) NewIOSpinner(msg string) *IOSpinner { if p == nil { return nil @@ -168,7 +167,7 @@ func (p *Progress) NewIOSpinner(msg string) *IOSpinner { ) spinner := &IOSpinner{bar: bar} - p.cleanup.Add(spinner.Finish) + p.cleanup.Add(spinner.Stop) return spinner } @@ -183,10 +182,11 @@ func (sp *IOSpinner) IncrBy(n int) { sp.bar.IncrBy(n) } -func (sp *IOSpinner) Finish() { +func (sp *IOSpinner) Stop() { if sp == nil { return } + sp.bar.SetTotal(-1, true) sp.bar.Wait() } diff --git a/libsq/core/progress/progress_test.go b/libsq/core/progress/progress_test.go index 9922214f2..633ca8ba6 100644 --- a/libsq/core/progress/progress_test.go +++ b/libsq/core/progress/progress_test.go @@ -1 +1,61 @@ package progress_test + +import ( + "bytes" + "context" + "io" + "os" + "testing" + "time" + + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewWriter(t *testing.T) { + const limit = 1000000 + + ctx := context.Background() + pb := progress.New(ctx, os.Stdout, time.Millisecond, progress.DefaultColors()) + ctx = progress.NewContext(ctx, pb) + + src := ioz.LimitRandReader(limit) + src = ioz.DelayReader(src, 10*time.Millisecond, true) + + dest := io.Discard + w := progress.NewWriter(ctx, "write test", dest) + + written, err := io.Copy(w, src) + require.NoError(t, err) + require.Equal(t, int64(limit), written) + pb.Wait() +} + +// TestNewWriter_Closer_type tests that the returned writer +// implements io.ReadCloser, or not, depending upon the type of +// the underlying writer. 
+func TestNewReader_Closer_type(t *testing.T) { + ctx := context.Background() + pb := progress.New(ctx, os.Stdout, time.Millisecond, progress.DefaultColors()) + ctx = progress.NewContext(ctx, pb) + defer pb.Wait() + + // bytes.Buffer doesn't implement io.Closer + buf := &bytes.Buffer{} + gotReader := progress.NewReader(ctx, "no closer", buf) + require.NotNil(t, gotReader) + _, isCloser := gotReader.(io.ReadCloser) + + assert.False(t, isCloser, "expected reader NOT to be io.ReadCloser but was %T", + gotReader) + + bufCloser := io.NopCloser(buf) + gotReader = progress.NewReader(ctx, "closer", bufCloser) + require.NotNil(t, gotReader) + _, isCloser = gotReader.(io.ReadCloser) + + assert.True(t, isCloser, "expected reader to be io.ReadCloser but was %T", + gotReader) +} diff --git a/libsq/core/progress/progressio.go b/libsq/core/progress/progressio.go index b9df949e6..9c194a152 100644 --- a/libsq/core/progress/progressio.go +++ b/libsq/core/progress/progressio.go @@ -14,140 +14,215 @@ See the License for the specific language governing permissions and limitations under the License. */ -// This code is lifted from github.com/dolmen-go/contextio. +// This code is derived from github.com/dolmen-go/contextio. package progress import ( "context" + "errors" "io" -) - -type progWriter struct { - ctx context.Context - w io.Writer - spinner *IOSpinner -} -type progCopier struct { - progWriter -} + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" +) -// NewProgWriter wraps an [io.Writer] to handle context cancellation. +// NewWriter returns an [io.Writer] that wraps w, is context-aware, and +// generates a progress bar as bytes are written to w. It is expected that ctx +// contains a *progress.Progress, as returned by progress.FromContext. If not, +// this function delegates to contextio.NewWriter: the returned writer will +// still be context-ware. See the contextio package for more details. 
// // Context state is checked BEFORE every Write. // -// The returned Writer also implements [io.ReaderFrom] to allow [io.Copy] to select +// The returned [io.Writer] implements [io.ReaderFrom] to allow [io.Copy] to select // the best strategy while still checking the context state before every chunk transfer. -func NewProgWriter(ctx context.Context, msg string, w io.Writer) io.Writer { +// +// The returned [io.Writer] also implements [io.Closer], even if the underlying +// writer does not. This is necessary because we need a means of stopping the +// progress bar when writing is complete. If the underlying writer does +// implement [io.Closer], it will be closed when the returned writer is closed. +func NewWriter(ctx context.Context, msg string, w io.Writer) io.Writer { if w, ok := w.(*progCopier); ok && ctx == w.ctx { return w } pb := FromContext(ctx) + if pb == nil { + return contextio.NewWriter(ctx, w) + } + spinner := pb.NewIOSpinner(msg) - return spinner.bar.ProxyWriter(w) - // FIXME: This is not working, ^^, bar stays on screen - //return &progCopier{progWriter{ctx: ctx, w: spinner.bar.ProxyWriter(w), spinner: spinner}} - //return &progCopier{progWriter{ctx: ctx, w: w, spinner: spinner}} + return &progCopier{progWriter{ + ctx: ctx, + w: spinner.bar.ProxyWriter(w), + spinner: spinner, + }} +} + +var _ io.WriteCloser = (*progWriter)(nil) + +type progWriter struct { + ctx context.Context + w io.Writer + spinner *IOSpinner } // Write implements [io.Writer], but with context awareness. func (w *progWriter) Write(p []byte) (n int, err error) { select { case <-w.ctx.Done(): - w.spinner.Finish() + w.spinner.Stop() return 0, w.ctx.Err() default: n, err = w.w.Write(p) - w.spinner.IncrBy(n) if err != nil { - w.spinner.Finish() + w.spinner.Stop() } return n, err } } +// Close implements [io.WriteCloser], but with context awareness. 
func (w *progWriter) Close() error { if w == nil { return nil } - w.spinner.Finish() - - // REVISIT: I'm not sure if we should always try - // to close the underlying writer first, even if - // the context is done? Or go straight to the - // select ctx.Done? + w.spinner.Stop() var closeErr error - if wc, ok := w.w.(io.WriteCloser); ok { - closeErr = wc.Close() + if c, ok := w.w.(io.Closer); ok { + closeErr = errz.Err(c.Close()) } select { case <-w.ctx.Done(): - return w.ctx.Err() + ctxErr := w.ctx.Err() + switch { + case closeErr == nil, + errz.IsErrContext(closeErr): + return ctxErr + default: + return errors.Join(ctxErr, closeErr) + } default: + return closeErr + } +} + +// NewReader returns an [io.Reader] that wraps r, is context-aware, and +// generates a progress bar as bytes are read from r. It is expected that ctx +// contains a *progress.Progress, as returned by progress.FromContext. If not, +// this function delegates to contextio.NewReader: the returned reader will +// still be context-ware. See the contextio package for more details. +// +// Context state is checked BEFORE every Read. +// +// The returned [io.Reader] also implements [io.Closer], even if the underlying +// reader does not. This is necessary because we need a means of stopping the +// progress bar when writing is complete. If the underlying reader does +// implement [io.Closer], it will be closed when the returned reader is closed. 
+func NewReader(ctx context.Context, msg string, r io.Reader) io.Reader { + if r, ok := r.(*progReader); ok && ctx == r.ctx { + return r + } + + pb := FromContext(ctx) + if pb == nil { + return contextio.NewReader(ctx, r) } - return closeErr + spinner := pb.NewIOSpinner(msg) + pr := &progReader{ + ctx: ctx, + r: spinner.bar.ProxyReader(r), + spinner: spinner, + } + return pr } +var _ io.ReadCloser = (*progReader)(nil) + type progReader struct { ctx context.Context r io.Reader spinner *IOSpinner } -// NewProgReader wraps an [io.Reader] to handle context cancellation. -// -// Context state is checked BEFORE every Read. -func NewProgReader(ctx context.Context, msg string, r io.Reader) io.Reader { - if r, ok := r.(*progReader); ok && ctx == r.ctx { - return r +// Close implements [io.ReadCloser], but with context awareness. +func (r *progReader) Close() error { + if r == nil { + return nil } - spinner := FromContext(ctx).NewIOSpinner(msg) - return &progReader{ctx: ctx, r: r, spinner: spinner} + r.spinner.Stop() + + var closeErr error + if c, ok := r.r.(io.ReadCloser); ok { + closeErr = errz.Err(c.Close()) + } + + select { + case <-r.ctx.Done(): + ctxErr := r.ctx.Err() + switch { + case closeErr == nil, + errz.IsErrContext(closeErr): + return ctxErr + default: + return errors.Join(ctxErr, closeErr) + } + default: + return closeErr + } } +// Read implements [io.Reader], but with context awareness. func (r *progReader) Read(p []byte) (n int, err error) { select { case <-r.ctx.Done(): - r.spinner.Finish() + r.spinner.Stop() return 0, r.ctx.Err() default: n, err = r.r.Read(p) - r.spinner.IncrBy(n) if err != nil { - r.spinner.Finish() + r.spinner.Stop() } return n, err } } +var _ io.ReaderFrom = (*progCopier)(nil) + +type progCopier struct { + progWriter +} + // ReadFrom implements interface [io.ReaderFrom], but with context awareness. // // This should allow efficient copying allowing writer or reader to define the chunk size. 
func (w *progCopier) ReadFrom(r io.Reader) (n int64, err error) { if _, ok := w.w.(io.ReaderFrom); ok { // Let the original Writer decide the chunk size. - // FIXME: Do we really need to pass the spinner to progReader, if - // the writer already has it? - return io.Copy(w.progWriter.w, &progReader{ctx: w.ctx, r: r, spinner: w.spinner}) + rdr := &progReader{ + ctx: w.ctx, + r: w.spinner.bar.ProxyReader(r), + spinner: w.spinner, + } + + return io.Copy(w.progWriter.w, rdr) } select { case <-w.ctx.Done(): - w.spinner.Finish() + w.spinner.Stop() return 0, w.ctx.Err() default: // The original Writer is not a ReaderFrom. // Let the Reader decide the chunk size. n, err = io.Copy(&w.progWriter, r) - w.spinner.IncrBy(int(n)) if err != nil { - w.spinner.Finish() + w.spinner.Stop() } return n, err } diff --git a/libsq/driver/ingest.go b/libsq/driver/ingest.go index 765d841b9..ce6249785 100644 --- a/libsq/driver/ingest.go +++ b/libsq/driver/ingest.go @@ -13,6 +13,7 @@ import ( var OptIngestHeader = options.NewBool( "ingest.header", "", + false, 0, false, "Ingest data has a header row", @@ -25,8 +26,9 @@ to detect the header.`, // OptIngestCache specifies whether ingested data is cached or not. 
var OptIngestCache = options.NewBool( - "ingest.bool", + "ingest.cache", "", + false, 0, true, "Ingest data is cached", diff --git a/libsq/driver/sources.go b/libsq/driver/sources.go index 2cb6144e7..eab56f302 100644 --- a/libsq/driver/sources.go +++ b/libsq/driver/sources.go @@ -207,7 +207,7 @@ func (d *Pools) openIngestCache(ctx context.Context, src *source.Source, } defer func() { log.Debug("About to release cache lock...", "lock", lock) - if err := lock.Unlock(); err != nil { + if err = lock.Unlock(); err != nil { log.Warn("Failed to release cache lock", "lock", lock, lga.Err, err) } else { log.Debug("Released cache lock", "lock", lock) @@ -266,7 +266,7 @@ func (d *Pools) openIngestCache(ctx context.Context, src *source.Source, if sum, err = ioz.FileChecksum(ingestFilePath); err != nil { log.Warn("Failed to compute checksum for source file; caching not in effect", lga.Src, src, lga.Dest, impl.Source(), lga.Path, ingestFilePath, lga.Err, err) - return impl, nil + return impl, nil //nolint:nilerr } if err = ioz.WriteChecksumFile(checksumsPath, sum, ingestFilePath); err != nil { @@ -280,7 +280,7 @@ func (d *Pools) openIngestCache(ctx context.Context, src *source.Source, // getCachePaths returns the paths to the cache files for src. // There is no guarantee that these files exist, or are accessible. // It's just the paths. 
-func (d *Pools) getCachePaths(src *source.Source) (srcCacheDir, cacheDB, checksums string, err error) { //nolint:unparam +func (d *Pools) getCachePaths(src *source.Source) (srcCacheDir, cacheDB, checksums string, err error) { if srcCacheDir, err = source.CacheDirFor(src); err != nil { return "", "", "", err } diff --git a/libsq/source/files.go b/libsq/source/files.go index 55c91328c..1e82e0d6a 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -2,8 +2,6 @@ package source import ( "context" - "github.com/dolmen-go/contextio" - "github.com/neilotoole/sq/libsq/core/progress" "io" "log/slog" "mime" @@ -13,6 +11,8 @@ import ( "sync" "time" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/h2non/filetype" "github.com/h2non/filetype/matchers" "golang.org/x/sync/errgroup" @@ -21,9 +21,11 @@ import ( "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/source/drivertype" "github.com/neilotoole/sq/libsq/source/fetcher" @@ -150,38 +152,31 @@ func (fs *Files) addFile(ctx context.Context, f *os.File, key string) (fscache.R lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) return nil, errz.Errorf("failed to add to fscache (possibly previously added): %s", key) } - - copier := fscache.Filler{ + // TODO: Problematically, we copy the entire contents of f into fscache. + // This is probably necessary for piped data on stdin, but for files + // that already exist on the file system, it would be nice if the cacheFile + // could be mapped directly to the filesystem file. This might require + // hacking on the fscache impl. 
+ copier := fscache.AsyncFiller{ Message: "Cache fill", Log: log.With(lga.Action, "Cache fill"), - NewContextWriterFn: progress.NewProgWriter, + NewContextWriterFn: progress.NewWriter, + // We don't use progress.NewReader here, because that + // would result in double counting of bytes transferred. NewContextReaderFn: func(ctx context.Context, msg string, r io.Reader) io.Reader { return contextio.NewReader(ctx, r) }, CloseReader: true, } - // TODO: Problematically, we copy the entire contents of f into fscache. - // If f is a large file (e.g. piped over stdin), this means that - // everything is held up until f is fully copied. Hopefully we can - // do something with fscache so that the readers returned from - // fscache can lazily read from f. - //if err = fscache.FillWriterAsync(ctx, log, w, f, true); err != nil { - // lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) - // return nil, errz.Err(err) - //} - - if err = copier.Copy(ctx, w, f); err != nil { + // FIXME: Added a delay for testing. Remove this before release. + df := ioz.DelayReader(f, time.Millisecond, true) + // if err = copier.Copy(ctx, w, f); err != nil { + if err = copier.Copy(ctx, w, df); err != nil { lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) return nil, errz.Err(err) } - //err = errz.Combine(w.Close(), f.Close()) - //if err != nil { - // lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) - // return nil, err - //} - return r, nil } @@ -205,57 +200,7 @@ func (fs *Files) Filepath(_ context.Context, src *Source) (string, error) { _ = u // It's a remote file. We really should download it here. // FIXME: implement downloading. - return "", errz.Errorf("Filepath not implemented for remote files: %s", loc) - // - //if ; !ok { - // // It's not a filepath, and it's not a http URL, - // // so we need to download it. 
- // - // - //} - // - //return "", - // - // - // - //typ, err := fs.DriverType(ctx, src.Location) - //if err != nil { - // return "", err - //} - // - //if !fs.fcache.Exists(loc) { - // // cache miss - // f, err := fs.openLocation(loc) - // if err != nil { - // return "", err - // } - // - // // Note that addFile closes f - // _, err = fs.addFile(f, loc) - // if err != nil { - // return "", err - // } - // return f.Name(), nil - //} - // - //return loc, nil - //r, _, err := fs.fcache.Get(loc) - //if err != nil { - // return "", err - //} - // - //return r, nil - - // // cache miss - // f, err := fs.openLocation(src.Location) - // if err != nil { - // return "", err - // } - // - // if err = f.Close(); err != nil { - // return "", errz.Err(err) - // } - // return f.Name(), nil + return "", errz.Errorf("Files.Filepath not implemented for remote files: %s", loc) } // Open returns a new io.ReadCloser for src.Location. @@ -333,7 +278,7 @@ func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, erro // openLocation returns a file for loc. It is the caller's // responsibility to close the returned file. -func (fs *Files) openLocation(ctx context.Context, loc string) (*os.File, error) { +func (fs *Files) openLocation(_ context.Context, loc string) (*os.File, error) { var fpath string var ok bool var err error @@ -349,6 +294,7 @@ func (fs *Files) openLocation(ctx context.Context, loc string) (*os.File, error) } // It's a remote file + // TODO: fetch should take ctx to allow for cancellation fpath, err = fs.fetch(u.String()) if err != nil { return nil, err diff --git a/libsq/source/source.go b/libsq/source/source.go index 4092331f4..32f3cfc55 100644 --- a/libsq/source/source.go +++ b/libsq/source/source.go @@ -212,6 +212,9 @@ func RedactLocation(loc string) string { case loc == "", strings.HasPrefix(loc, "/"), strings.HasPrefix(loc, "sqlite3://"): + + // REVISIT: If it's a sqlite URI, could it have auth details in there? + // e.g. 
"?_auth_pass=foo" return loc case strings.HasPrefix(loc, "http://"), strings.HasPrefix(loc, "https://"): u, err := url.ParseRequestURI(loc) From b281bb21f3c618775ef15fa9897fc88e432fc329 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Mon, 27 Nov 2023 21:50:41 -0700 Subject: [PATCH 021/195] wip: close to working --- cli/cmd_slq.go | 1 + cli/cmd_sql.go | 1 + cli/cmd_xtest.go | 2 +- drivers/csv/ingest.go | 1 + drivers/json/import_jsona.go | 1 + drivers/xlsx/ingest.go | 2 +- libsq/core/progress/progress.go | 28 ++++++++++++++++++++++++---- libsq/dbwriter.go | 9 +++++---- libsq/driver/driver_test.go | 2 +- libsq/driver/record.go | 18 ++++++++++++------ libsq/pipeline.go | 1 + testh/testh.go | 2 +- 12 files changed, 50 insertions(+), 18 deletions(-) diff --git a/cli/cmd_slq.go b/cli/cmd_slq.go index 16c2bc1be..ea2477b46 100644 --- a/cli/cmd_slq.go +++ b/cli/cmd_slq.go @@ -146,6 +146,7 @@ func execSLQInsert(ctx context.Context, ru *run.Run, mArgs map[string]string, // stack. inserter := libsq.NewDBWriter( + "Insert records", destPool, destTbl, driver.OptTuningRecChanSize.Get(destSrc.Options), diff --git a/cli/cmd_sql.go b/cli/cmd_sql.go index 56a448f18..8801eb634 100644 --- a/cli/cmd_sql.go +++ b/cli/cmd_sql.go @@ -159,6 +159,7 @@ func execSQLInsert(ctx context.Context, ru *run.Run, // is invoked by ru.Close, and ru is closed further up the // stack. 
inserter := libsq.NewDBWriter( + "Insert records", destPool, destTbl, driver.OptTuningRecChanSize.Get(destSrc.Options), diff --git a/cli/cmd_xtest.go b/cli/cmd_xtest.go index 8ca195f66..8ed54c258 100644 --- a/cli/cmd_xtest.go +++ b/cli/cmd_xtest.go @@ -51,7 +51,7 @@ func execXTestMbp(cmd *cobra.Command, _ []string) error { func doBigRead2(ctx context.Context) error { pb := progress.FromContext(ctx) - spinner := pb.NewIOSpinner("Ingest data test...") + spinner := pb.NewIOSpinner("Ingest data test") defer spinner.Stop() maxSleep := 100 * time.Millisecond diff --git a/drivers/csv/ingest.go b/drivers/csv/ingest.go index 176ee3d7a..87d149337 100644 --- a/drivers/csv/ingest.go +++ b/drivers/csv/ingest.go @@ -127,6 +127,7 @@ func ingestCSV(ctx context.Context, src *source.Source, openFn source.FileOpenFu } insertWriter := libsq.NewDBWriter( + "Ingest records", scratchPool, tblDef.Name, driver.OptTuningRecChanSize.Get(scratchPool.Source().Options), diff --git a/drivers/json/import_jsona.go b/drivers/json/import_jsona.go index 91feff792..498b8b89b 100644 --- a/drivers/json/import_jsona.go +++ b/drivers/json/import_jsona.go @@ -143,6 +143,7 @@ func importJSONA(ctx context.Context, job importJob) error { defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) insertWriter := libsq.NewDBWriter( + "Ingest records", job.destPool, tblDef.Name, driver.OptTuningRecChanSize.Get(job.destPool.Source().Options), diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index ca8223512..f273cd01b 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -189,7 +189,7 @@ func ingestSheetToTable(ctx context.Context, scratchPool driver.Pool, sheetTbl * drvr := scratchPool.SQLDriver() batchSize := driver.MaxBatchRows(drvr, len(destColKinds)) - bi, err := driver.NewBatchInsert(ctx, drvr, conn, tblDef.Name, tblDef.ColNames(), batchSize) + bi, err := driver.NewBatchInsert(ctx, "Ingest records", drvr, conn, tblDef.Name, tblDef.ColNames(), batchSize) if err != nil { return err } diff 
--git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index e8af00514..63084979b 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -2,6 +2,8 @@ package progress import ( "context" + "fmt" + "github.com/neilotoole/sq/libsq/core/stringz" "io" "sync" "time" @@ -126,15 +128,26 @@ func (p *Progress) ShutdownOnWriteTo(w io.Writer) io.Writer { var _ io.Writer = (*writeNotifier)(nil) type writeNotifier struct { - p *Progress - w io.Writer + p *Progress + w io.Writer + notifyOnce sync.Once } +// Write implements [io.Writer]. func (w *writeNotifier) Write(p []byte) (n int, err error) { - w.p.Wait() + w.notifyOnce.Do(w.p.Wait) + return w.w.Write(p) } +func normalizeMsgLength(msg string, length int) string { + if len(msg) > length { + msg = stringz.TrimLenMiddle(msg, length) + } + + return fmt.Sprintf("%-*s", length, msg) +} + // NewIOSpinner returns a new indeterminate spinner bar whose metric is // the count of bytes processed. The caller is ultimately // responsible for calling [IOSpinner.Stop] on the returned IOSpinner. However, @@ -149,6 +162,13 @@ func (p *Progress) NewIOSpinner(msg string) *IOSpinner { p.mu.Lock() defer p.mu.Unlock() + const ( + msgLength = 18 + barWidth = 28 + ) + + msg = normalizeMsgLength(msg, msgLength) + style := mpb.SpinnerStyle("∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙") style = style.Meta(func(s string) string { return p.colors.Spinner.Sprint(s) @@ -156,7 +176,7 @@ func (p *Progress) NewIOSpinner(msg string) *IOSpinner { bar := p.p.New(0, style, - mpb.BarWidth(36), + mpb.BarWidth(barWidth), mpb.PrependDecorators( ColorMeta(decor.Name(msg), p.colors.Message), ), diff --git a/libsq/dbwriter.go b/libsq/dbwriter.go index 5f2adbcf7..9d1ca3781 100644 --- a/libsq/dbwriter.go +++ b/libsq/dbwriter.go @@ -19,6 +19,7 @@ import ( // DBWriter implements RecordWriter, writing // records to a database table. 
type DBWriter struct { + msg string wg *sync.WaitGroup cancelFn context.CancelFunc destPool driver.Pool @@ -73,10 +74,10 @@ func DBWriterCreateTableIfNotExistsHook(destTblName string) DBWriterPreWriteHook // The writer writes records from recordCh to destTbl // in destPool. The recChSize param controls the size of recordCh // returned by the writer's Open method. -func NewDBWriter(destPool driver.Pool, destTbl string, recChSize int, - preWriteHooks ...DBWriterPreWriteHook, -) *DBWriter { +func NewDBWriter(msg string, destPool driver.Pool, destTbl string, recChSize int, + preWriteHooks ...DBWriterPreWriteHook) *DBWriter { return &DBWriter{ + msg: msg, destPool: destPool, destTbl: destTbl, recordCh: make(chan record.Record, recChSize), @@ -117,7 +118,7 @@ func (w *DBWriter) Open(ctx context.Context, cancelFn context.CancelFunc, recMet } batchSize := driver.MaxBatchRows(w.destPool.SQLDriver(), len(recMeta.Names())) - w.bi, err = driver.NewBatchInsert(ctx, w.destPool.SQLDriver(), tx, w.destTbl, recMeta.Names(), batchSize) + w.bi, err = driver.NewBatchInsert(ctx, w.msg, w.destPool.SQLDriver(), tx, w.destTbl, recMeta.Names(), batchSize) if err != nil { w.rollback(ctx, tx, err) return nil, nil, err diff --git a/libsq/driver/driver_test.go b/libsq/driver/driver_test.go index ddbe68cd2..14a9cc315 100644 --- a/libsq/driver/driver_test.go +++ b/libsq/driver/driver_test.go @@ -294,7 +294,7 @@ func TestNewBatchInsert(t *testing.T) { // Get records from TblActor that we'll write to the new tbl recMeta, recs := testh.RecordsFromTbl(t, handle, sakila.TblActor) - bi, err := driver.NewBatchInsert(th.Context, drvr, conn, tblName, recMeta.Names(), batchSize) + bi, err := driver.NewBatchInsert(th.Context, "Insert records", drvr, conn, tblName, recMeta.Names(), batchSize) require.NoError(t, err) for _, rec := range recs { diff --git a/libsq/driver/record.go b/libsq/driver/record.go index 89da6659f..5d3f0e538 100644 --- a/libsq/driver/record.go +++ b/libsq/driver/record.go @@ -5,6 +5,7 
@@ import ( "context" "database/sql" "fmt" + "github.com/neilotoole/sq/libsq/core/progress" "math" "reflect" "strings" @@ -370,7 +371,7 @@ func (bi *BatchInsert) Written() int64 { // Munge should be invoked on every record before sending // on RecordCh. -func (bi BatchInsert) Munge(rec []any) error { +func (bi *BatchInsert) Munge(rec []any) error { return bi.mungeFn(rec) } @@ -381,16 +382,16 @@ func (bi BatchInsert) Munge(rec []any) error { // it must be a sql.Conn or sql.Tx. // //nolint:gocognit -func NewBatchInsert(ctx context.Context, drvr SQLDriver, db sqlz.DB, - destTbl string, destColNames []string, batchSize int, -) (*BatchInsert, error) { +func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, + destTbl string, destColNames []string, batchSize int) (*BatchInsert, error) { log := lg.FromContext(ctx) - err := sqlz.RequireSingleConn(db) - if err != nil { + if err := sqlz.RequireSingleConn(db); err != nil { return nil, err } + spinner := progress.FromContext(ctx).NewIOSpinner(msg) + recCh := make(chan []any, batchSize*8) errCh := make(chan error, 1) rowLen := len(destColNames) @@ -412,6 +413,8 @@ func NewBatchInsert(ctx context.Context, drvr SQLDriver, db sqlz.DB, var affected int64 defer func() { + spinner.Stop() + if inserter != nil { if err == nil { // If no pre-existing error, any inserter.Close error @@ -423,6 +426,7 @@ func NewBatchInsert(ctx context.Context, drvr SQLDriver, db sqlz.DB, // is the primary concern. 
lg.WarnIfError(log, lgm.CloseDBStmt, errz.Err(inserter.Close())) } + } if err != nil { @@ -463,6 +467,7 @@ func NewBatchInsert(ctx context.Context, drvr SQLDriver, db sqlz.DB, } bi.written.Add(affected) + spinner.IncrBy(int(affected)) if rec == nil { // recCh is closed (coincidentally exactly on the @@ -505,6 +510,7 @@ func NewBatchInsert(ctx context.Context, drvr SQLDriver, db sqlz.DB, } bi.written.Add(affected) + spinner.IncrBy(int(affected)) // We're done return diff --git a/libsq/pipeline.go b/libsq/pipeline.go index 9a724c6a4..126298900 100644 --- a/libsq/pipeline.go +++ b/libsq/pipeline.go @@ -464,6 +464,7 @@ func execCopyTable(ctx context.Context, fromDB driver.Pool, fromTbl tablefq.T, } inserter := NewDBWriter( + "Copy records", destPool, destTbl.Table, driver.OptTuningRecChanSize.Get(destPool.Source().Options), diff --git a/testh/testh.go b/testh/testh.go index 9e18d9c33..f5b1d4b15 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -485,7 +485,7 @@ func (h *Helper) Insert(src *source.Source, tbl string, cols []string, records . 
defer lg.WarnIfCloseError(h.Log, lgm.CloseDB, conn) batchSize := driver.MaxBatchRows(drvr, len(cols)) - bi, err := driver.NewBatchInsert(h.Context, drvr, conn, tbl, cols, batchSize) + bi, err := driver.NewBatchInsert(h.Context, "Insert records", drvr, conn, tbl, cols, batchSize) require.NoError(h.T, err) for _, rec := range records { From 9ada0a09e049e571c30510cd0a7f0e169ecf79d8 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Mon, 27 Nov 2023 23:00:24 -0700 Subject: [PATCH 022/195] wip: close to working --- cli/cmd_xtest.go | 2 +- cli/output.go | 2 +- go.mod | 1 + go.sum | 2 + libsq/core/progress/progress.go | 79 ++++++++++++++++++++++--------- libsq/core/progress/progressio.go | 8 ++-- libsq/driver/record.go | 2 +- 7 files changed, 67 insertions(+), 29 deletions(-) diff --git a/cli/cmd_xtest.go b/cli/cmd_xtest.go index 8ed54c258..c9525c158 100644 --- a/cli/cmd_xtest.go +++ b/cli/cmd_xtest.go @@ -51,7 +51,7 @@ func execXTestMbp(cmd *cobra.Command, _ []string) error { func doBigRead2(ctx context.Context) error { pb := progress.FromContext(ctx) - spinner := pb.NewIOSpinner("Ingest data test") + spinner := pb.NewByteCounterSpinner("Ingest data test") defer spinner.Stop() maxSleep := 100 * time.Millisecond diff --git a/cli/output.go b/cli/output.go index a96b43e51..810dbfd9b 100644 --- a/cli/output.go +++ b/cli/output.go @@ -443,7 +443,7 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer ctx := cmd.Context() renderDelay := OptProgressDelay.Get(opts) prog := progress.New(ctx, errOut, renderDelay, progColors) - out2 = prog.ShutdownOnWriteTo(out2) + out2 = progress.ShutdownOnWriteTo(prog, out2) cmd.SetContext(progress.NewContext(ctx, prog)) logFrom(cmd).Debug("Initialized progress") } diff --git a/go.mod b/go.mod index 952695027..ccc4a9ab8 100644 --- a/go.mod +++ b/go.mod @@ -59,6 +59,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/djherbis/atime v1.1.0 // indirect github.com/djherbis/stream v1.4.0 // 
indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/huandu/xstrings v1.4.0 // indirect diff --git a/go.sum b/go.sum index 891c94870..ecd3e4b48 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,8 @@ github.com/djherbis/atime v1.1.0 h1:rgwVbP/5by8BvvjBNrbh64Qz33idKT3pSnMSJsxhi0g= github.com/djherbis/atime v1.1.0/go.mod h1:28OF6Y8s3NQWwacXc5eZTsEsiMzp7LF8MbXE+XJPdBE= github.com/djherbis/stream v1.4.0 h1:aVD46WZUiq5kJk55yxJAyw6Kuera6kmC3i2vEQyW/AE= github.com/djherbis/stream v1.4.0/go.mod h1:cqjC1ZRq3FFwkGmUtHwcldbnW8f0Q4YuVsGW1eAFtOk= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/ecnepsnai/osquery v1.0.1 h1:i96n/3uqcafKZtRYmXVNqekKbfrIm66q179mWZ/Y2Aw= github.com/ecnepsnai/osquery v1.0.1/go.mod h1:vxsezNRznmkLa8UjVh88tlJiRbgW7iwinkjyg/Xc2RU= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 63084979b..2a56cf378 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -3,6 +3,8 @@ package progress import ( "context" "fmt" + "github.com/dustin/go-humanize" + "github.com/dustin/go-humanize/english" "github.com/neilotoole/sq/libsq/core/stringz" "io" "sync" "time" @@ -108,12 +110,14 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors return &Progress{p: p, colors: colors, mu: &sync.Mutex{}, cleanup: cleanup.New()} } -// ShutdownOnWriteTo returns a writer that will stop the +// ShutdownOnWriteTo returns a writer decorator that stops the // progress.Progress when w.Write is called. Typically p writes -// to stderr, and stdout is passed to this method.
That is, when +// to stderr, but stdout is passed to this method. That is, when // the program starts writing to stdout, we want to shut down // and remove the progress bar. -func (p *Progress) ShutdownOnWriteTo(w io.Writer) io.Writer { +// +// REVISIT: ShutdownOnWriteTo is not a great name. +func ShutdownOnWriteTo(p *Progress, w io.Writer) io.Writer { // REVISIT: Should we check if w implements other io interfaces, // such as io.WriterAt etc? Or do we really only care about io.Writer? if p == nil { @@ -133,7 +137,8 @@ type writeNotifier struct { notifyOnce sync.Once } -// Write implements [io.Writer]. +// Write implements [io.Writer]. On first invocation, the referenced +// progress.Progress is stopped via its [Progress.Wait] method. func (w *writeNotifier) Write(p []byte) (n int, err error) { w.notifyOnce.Do(w.p.Wait) @@ -148,13 +153,13 @@ func normalizeMsgLength(msg string, length int) string { return fmt.Sprintf("%-*s", length, msg) } -// NewIOSpinner returns a new indeterminate spinner bar whose metric is -// the count of bytes processed. The caller is ultimately -// responsible for calling [IOSpinner.Stop] on the returned IOSpinner. However, -// the returned IOSpinner is also added to the Progress's cleanup list, so +// NewCountSpinner returns a new indeterminate spinner bar whose label +// metric is the provided unit. The caller is ultimately +// responsible for calling [Spinner.Stop] on the returned Spinner. However, +// the returned Spinner is also added to the Progress's cleanup list, so // it will be called automatically when the Progress is shut down, but that // may be later than the actual conclusion of the spinner's work. 
-func (p *Progress) NewIOSpinner(msg string) *IOSpinner { +func (p *Progress) NewCountSpinner(msg, unit string) *Spinner { if p == nil { return nil } @@ -162,12 +167,44 @@ func (p *Progress) NewIOSpinner(msg string) *IOSpinner { p.mu.Lock() defer p.mu.Unlock() - const ( - msgLength = 18 - barWidth = 28 - ) + counter := decor.Any(func(statistics decor.Statistics) string { + s := humanize.Comma(statistics.Current) + if unit != "" { + s += " " + english.PluralWord(int(statistics.Current), unit, "") + } + return s + }) - msg = normalizeMsgLength(msg, msgLength) + decorators := []decor.Decorator{ColorMeta(counter, p.colors.Size)} + return p.newSpinner(msg, decorators...) +} + +// NewByteCounterSpinner returns a new indeterminate spinner bar whose +// metric is the count of bytes processed. The caller is ultimately +// responsible for calling [Spinner.Stop] on the returned Spinner. However, +// the returned Spinner is also added to the Progress's cleanup list, so +// it will be called automatically when the Progress is shut down, but that +// may be later than the actual conclusion of the spinner's work. +func (p *Progress) NewByteCounterSpinner(msg string) *Spinner { + if p == nil { + return nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + decorators := []decor.Decorator{ + ColorMeta(decor.Current(decor.SizeB1024(0), "% .1f"), p.colors.Size), + } + return p.newSpinner(msg, decorators...) 
+} + +func (p *Progress) newSpinner(msg string, decorators ...decor.Decorator) *Spinner { + if p == nil { + return nil + } + + const barWidth = 28 style := mpb.SpinnerStyle("∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙") style = style.Meta(func(s string) string { @@ -178,31 +215,29 @@ func (p *Progress) NewIOSpinner(msg string) *IOSpinner { style, mpb.BarWidth(barWidth), mpb.PrependDecorators( - ColorMeta(decor.Name(msg), p.colors.Message), - ), - mpb.AppendDecorators( - ColorMeta(decor.Current(decor.SizeB1024(0), "% .1f"), p.colors.Message), + ColorMeta(decor.Name(msg, decor.WCSyncWidthR), p.colors.Message), ), + mpb.AppendDecorators(decorators...), mpb.BarRemoveOnComplete(), ) - spinner := &IOSpinner{bar: bar} + spinner := &Spinner{bar: bar} p.cleanup.Add(spinner.Stop) return spinner } -type IOSpinner struct { +type Spinner struct { bar *mpb.Bar } -func (sp *IOSpinner) IncrBy(n int) { +func (sp *Spinner) IncrBy(n int) { if sp == nil { return } sp.bar.IncrBy(n) } -func (sp *IOSpinner) Stop() { +func (sp *Spinner) Stop() { if sp == nil { return } diff --git a/libsq/core/progress/progressio.go b/libsq/core/progress/progressio.go index 9c194a152..143002c9c 100644 --- a/libsq/core/progress/progressio.go +++ b/libsq/core/progress/progressio.go @@ -52,7 +52,7 @@ func NewWriter(ctx context.Context, msg string, w io.Writer) io.Writer { return contextio.NewWriter(ctx, w) } - spinner := pb.NewIOSpinner(msg) + spinner := pb.NewByteCounterSpinner(msg) return &progCopier{progWriter{ ctx: ctx, w: spinner.bar.ProxyWriter(w), @@ -65,7 +65,7 @@ var _ io.WriteCloser = (*progWriter)(nil) type progWriter struct { ctx context.Context w io.Writer - spinner *IOSpinner + spinner *Spinner } // Write implements [io.Writer], but with context awareness. 
@@ -133,7 +133,7 @@ func NewReader(ctx context.Context, msg string, r io.Reader) io.Reader { return contextio.NewReader(ctx, r) } - spinner := pb.NewIOSpinner(msg) + spinner := pb.NewByteCounterSpinner(msg) pr := &progReader{ ctx: ctx, r: spinner.bar.ProxyReader(r), @@ -147,7 +147,7 @@ var _ io.ReadCloser = (*progReader)(nil) type progReader struct { ctx context.Context r io.Reader - spinner *IOSpinner + spinner *Spinner } // Close implements [io.ReadCloser], but with context awareness. diff --git a/libsq/driver/record.go b/libsq/driver/record.go index 5d3f0e538..01c7929b9 100644 --- a/libsq/driver/record.go +++ b/libsq/driver/record.go @@ -390,7 +390,7 @@ func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, return nil, err } - spinner := progress.FromContext(ctx).NewIOSpinner(msg) + spinner := progress.FromContext(ctx).NewCountSpinner(msg, "rec") recCh := make(chan []any, batchSize*8) errCh := make(chan error, 1) From e133d4bb211d519bf696776a439300c34ee6455a Mon Sep 17 00:00:00 2001 From: neilotoole Date: Mon, 27 Nov 2023 23:32:46 -0700 Subject: [PATCH 023/195] Largely there --- cli/cmd_xtest.go | 5 ++- drivers/csv/ingest.go | 2 +- drivers/json/import_jsona.go | 2 +- drivers/xlsx/ingest.go | 11 ++++++- libsq/core/progress/progress.go | 46 +++++++++++++++++----------- libsq/core/progress/progress_test.go | 11 ++++--- libsq/core/progress/progressio.go | 10 +++--- libsq/dbwriter.go | 17 ++++++++-- libsq/driver/driver_test.go | 10 +++++- libsq/driver/record.go | 6 ++-- libsq/source/files.go | 21 +++++++++---- testh/testh.go | 10 +++++- 12 files changed, 105 insertions(+), 46 deletions(-) diff --git a/cli/cmd_xtest.go b/cli/cmd_xtest.go index c9525c158..625711535 100644 --- a/cli/cmd_xtest.go +++ b/cli/cmd_xtest.go @@ -6,14 +6,13 @@ import ( "math/rand" "time" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/buildinfo" "github.com/neilotoole/sq/cli/flag" 
"github.com/neilotoole/sq/cli/hostinfo" "github.com/neilotoole/sq/cli/run" + "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/progress" ) @@ -51,7 +50,7 @@ func execXTestMbp(cmd *cobra.Command, _ []string) error { func doBigRead2(ctx context.Context) error { pb := progress.FromContext(ctx) - spinner := pb.NewByteCounterSpinner("Ingest data test") + spinner := pb.NewByteCounterSpinner("Ingest data test", -1) defer spinner.Stop() maxSleep := 100 * time.Millisecond diff --git a/drivers/csv/ingest.go b/drivers/csv/ingest.go index 87d149337..2fe15db08 100644 --- a/drivers/csv/ingest.go +++ b/drivers/csv/ingest.go @@ -127,7 +127,7 @@ func ingestCSV(ctx context.Context, src *source.Source, openFn source.FileOpenFu } insertWriter := libsq.NewDBWriter( - "Ingest records", + libsq.MsgIngestRecords, scratchPool, tblDef.Name, driver.OptTuningRecChanSize.Get(scratchPool.Source().Options), diff --git a/drivers/json/import_jsona.go b/drivers/json/import_jsona.go index 498b8b89b..01ae0401f 100644 --- a/drivers/json/import_jsona.go +++ b/drivers/json/import_jsona.go @@ -143,7 +143,7 @@ func importJSONA(ctx context.Context, job importJob) error { defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) insertWriter := libsq.NewDBWriter( - "Ingest records", + libsq.MsgIngestRecords, job.destPool, tblDef.Name, driver.OptTuningRecChanSize.Get(job.destPool.Source().Options), diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index f273cd01b..1e60acebf 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -12,6 +12,7 @@ import ( excelize "github.com/xuri/excelize/v2" "golang.org/x/sync/errgroup" + "github.com/neilotoole/sq/libsq" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/kind" "github.com/neilotoole/sq/libsq/core/lg" @@ -189,7 +190,15 @@ func ingestSheetToTable(ctx context.Context, scratchPool driver.Pool, sheetTbl * drvr := scratchPool.SQLDriver() batchSize := 
driver.MaxBatchRows(drvr, len(destColKinds)) - bi, err := driver.NewBatchInsert(ctx, "Ingest records", drvr, conn, tblDef.Name, tblDef.ColNames(), batchSize) + bi, err := driver.NewBatchInsert( + ctx, + libsq.MsgIngestRecords, + drvr, + conn, + tblDef.Name, + tblDef.ColNames(), + batchSize, + ) if err != nil { return err } diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 2a56cf378..0c86f4b15 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -3,18 +3,18 @@ package progress import ( "context" "fmt" - "github.com/dustin/go-humanize" - "github.com/dustin/go-humanize/english" - "github.com/neilotoole/sq/libsq/core/stringz" "io" "sync" "time" + humanize "github.com/dustin/go-humanize" + "github.com/dustin/go-humanize/english" "github.com/fatih/color" mpb "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" "github.com/neilotoole/sq/libsq/core/cleanup" + "github.com/neilotoole/sq/libsq/core/stringz" ) type runKey struct{} @@ -81,7 +81,7 @@ func (c *Colors) EnableColor(enable bool) { type Progress struct { p *mpb.Progress - mu *sync.Mutex + mu *sync.Mutex // FIXME: we don't need mu? colors *Colors cleanup *cleanup.Cleanup } @@ -167,7 +167,7 @@ func (p *Progress) NewCountSpinner(msg, unit string) *Spinner { p.mu.Lock() defer p.mu.Unlock() - counter := decor.Any(func(statistics decor.Statistics) string { + decorator := decor.Any(func(statistics decor.Statistics) string { s := humanize.Comma(statistics.Current) if unit != "" { s += " " + english.PluralWord(int(statistics.Current), unit, "") @@ -175,17 +175,17 @@ func (p *Progress) NewCountSpinner(msg, unit string) *Spinner { return s }) - decorators := []decor.Decorator{ColorMeta(counter, p.colors.Size)} - return p.newSpinner(msg, decorators...) 
+ decorator = ColorMeta(decorator, p.colors.Size) + return p.newSpinner(msg, -1, decorator) } -// NewByteCounterSpinner returns a new indeterminate spinner bar whose -// metric is the count of bytes processed. The caller is ultimately -// responsible for calling [Spinner.Stop] on the returned Spinner. However, -// the returned Spinner is also added to the Progress's cleanup list, so -// it will be called automatically when the Progress is shut down, but that +// NewByteCounterSpinner returns a new spinner bar whose metric is the count +// of bytes processed. If the size is unknown, set arg size to -1. The caller +// is ultimately responsible for calling [Spinner.Stop] on the returned Spinner. +// However, the returned Spinner is also added to the Progress's cleanup list, +// so it will be called automatically when the Progress is shut down, but that // may be later than the actual conclusion of the spinner's work. -func (p *Progress) NewByteCounterSpinner(msg string) *Spinner { +func (p *Progress) NewByteCounterSpinner(msg string, size int64) *Spinner { if p == nil { return nil } @@ -193,25 +193,35 @@ func (p *Progress) NewByteCounterSpinner(msg string) *Spinner { p.mu.Lock() defer p.mu.Unlock() - decorators := []decor.Decorator{ - ColorMeta(decor.Current(decor.SizeB1024(0), "% .1f"), p.colors.Size), + var decorator decor.Decorator + if size < 0 { + decorator = decor.CountersNoUnit("% .2f / ?") + } else { + // decorator = decor.Current(decor.SizeB1024(0), "% .1f") + decorator = decor.Counters(decor.SizeB1024(0), "% .1f / % .1f") } - return p.newSpinner(msg, decorators...) 
+ decorator = ColorMeta(decorator, p.colors.Size) + + return p.newSpinner(msg, size, decorator) } -func (p *Progress) newSpinner(msg string, decorators ...decor.Decorator) *Spinner { +func (p *Progress) newSpinner(msg string, total int64, decorators ...decor.Decorator) *Spinner { if p == nil { return nil } const barWidth = 28 + if total < 0 { + total = 0 + } + style := mpb.SpinnerStyle("∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙") style = style.Meta(func(s string) string { return p.colors.Spinner.Sprint(s) }) - bar := p.p.New(0, + bar := p.p.New(total, style, mpb.BarWidth(barWidth), mpb.PrependDecorators( diff --git a/libsq/core/progress/progress_test.go b/libsq/core/progress/progress_test.go index 633ca8ba6..4be4bca57 100644 --- a/libsq/core/progress/progress_test.go +++ b/libsq/core/progress/progress_test.go @@ -8,10 +8,11 @@ import ( "testing" "time" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/neilotoole/sq/libsq/core/progress" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/progress" ) func TestNewWriter(t *testing.T) { @@ -25,7 +26,7 @@ func TestNewWriter(t *testing.T) { src = ioz.DelayReader(src, 10*time.Millisecond, true) dest := io.Discard - w := progress.NewWriter(ctx, "write test", dest) + w := progress.NewWriter(ctx, "write test", -1, dest) written, err := io.Copy(w, src) require.NoError(t, err) @@ -44,7 +45,7 @@ func TestNewReader_Closer_type(t *testing.T) { // bytes.Buffer doesn't implement io.Closer buf := &bytes.Buffer{} - gotReader := progress.NewReader(ctx, "no closer", buf) + gotReader := progress.NewReader(ctx, "no closer", -1, buf) require.NotNil(t, gotReader) _, isCloser := gotReader.(io.ReadCloser) @@ -52,7 +53,7 @@ func TestNewReader_Closer_type(t *testing.T) { gotReader) bufCloser := io.NopCloser(buf) - gotReader = progress.NewReader(ctx, "closer", bufCloser) + gotReader = progress.NewReader(ctx, "closer", -1, 
bufCloser) require.NotNil(t, gotReader) _, isCloser = gotReader.(io.ReadCloser) diff --git a/libsq/core/progress/progressio.go b/libsq/core/progress/progressio.go index 143002c9c..b87616dba 100644 --- a/libsq/core/progress/progressio.go +++ b/libsq/core/progress/progressio.go @@ -42,7 +42,9 @@ import ( // writer does not. This is necessary because we need a means of stopping the // progress bar when writing is complete. If the underlying writer does // implement [io.Closer], it will be closed when the returned writer is closed. -func NewWriter(ctx context.Context, msg string, w io.Writer) io.Writer { +// +// If size is unknown, set to -1. +func NewWriter(ctx context.Context, msg string, size int64, w io.Writer) io.Writer { if w, ok := w.(*progCopier); ok && ctx == w.ctx { return w } @@ -52,7 +54,7 @@ func NewWriter(ctx context.Context, msg string, w io.Writer) io.Writer { return contextio.NewWriter(ctx, w) } - spinner := pb.NewByteCounterSpinner(msg) + spinner := pb.NewByteCounterSpinner(msg, size) return &progCopier{progWriter{ ctx: ctx, w: spinner.bar.ProxyWriter(w), @@ -123,7 +125,7 @@ func (w *progWriter) Close() error { // reader does not. This is necessary because we need a means of stopping the // progress bar when writing is complete. If the underlying reader does // implement [io.Closer], it will be closed when the returned reader is closed. 
-func NewReader(ctx context.Context, msg string, r io.Reader) io.Reader { +func NewReader(ctx context.Context, msg string, size int64, r io.Reader) io.Reader { if r, ok := r.(*progReader); ok && ctx == r.ctx { return r } @@ -133,7 +135,7 @@ func NewReader(ctx context.Context, msg string, r io.Reader) io.Reader { return contextio.NewReader(ctx, r) } - spinner := pb.NewByteCounterSpinner(msg) + spinner := pb.NewByteCounterSpinner(msg, size) pr := &progReader{ ctx: ctx, r: spinner.bar.ProxyReader(r), diff --git a/libsq/dbwriter.go b/libsq/dbwriter.go index 9d1ca3781..c01f9c473 100644 --- a/libsq/dbwriter.go +++ b/libsq/dbwriter.go @@ -16,6 +16,10 @@ import ( "github.com/neilotoole/sq/libsq/source" ) +// MsgIngestRecords is the typical message used with [libsq.NewDBWriter] +// to indicate that records are being ingested. +const MsgIngestRecords = "Ingesting records" + // DBWriter implements RecordWriter, writing // records to a database table. type DBWriter struct { @@ -75,7 +79,8 @@ func DBWriterCreateTableIfNotExistsHook(destTblName string) DBWriterPreWriteHook // in destPool. The recChSize param controls the size of recordCh // returned by the writer's Open method. 
func NewDBWriter(msg string, destPool driver.Pool, destTbl string, recChSize int, - preWriteHooks ...DBWriterPreWriteHook) *DBWriter { + preWriteHooks ...DBWriterPreWriteHook, +) *DBWriter { return &DBWriter{ msg: msg, destPool: destPool, @@ -118,7 +123,15 @@ func (w *DBWriter) Open(ctx context.Context, cancelFn context.CancelFunc, recMet } batchSize := driver.MaxBatchRows(w.destPool.SQLDriver(), len(recMeta.Names())) - w.bi, err = driver.NewBatchInsert(ctx, w.msg, w.destPool.SQLDriver(), tx, w.destTbl, recMeta.Names(), batchSize) + w.bi, err = driver.NewBatchInsert( + ctx, + w.msg, + w.destPool.SQLDriver(), + tx, + w.destTbl, + recMeta.Names(), + batchSize, + ) if err != nil { w.rollback(ctx, tx, err) return nil, nil, err diff --git a/libsq/driver/driver_test.go b/libsq/driver/driver_test.go index 14a9cc315..49511277c 100644 --- a/libsq/driver/driver_test.go +++ b/libsq/driver/driver_test.go @@ -294,7 +294,15 @@ func TestNewBatchInsert(t *testing.T) { // Get records from TblActor that we'll write to the new tbl recMeta, recs := testh.RecordsFromTbl(t, handle, sakila.TblActor) - bi, err := driver.NewBatchInsert(th.Context, "Insert records", drvr, conn, tblName, recMeta.Names(), batchSize) + bi, err := driver.NewBatchInsert( + th.Context, + "Insert records", + drvr, + conn, + tblName, + recMeta.Names(), + batchSize, + ) require.NoError(t, err) for _, rec := range recs { diff --git a/libsq/driver/record.go b/libsq/driver/record.go index 01c7929b9..6e1dafeba 100644 --- a/libsq/driver/record.go +++ b/libsq/driver/record.go @@ -5,7 +5,6 @@ import ( "context" "database/sql" "fmt" - "github.com/neilotoole/sq/libsq/core/progress" "math" "reflect" "strings" @@ -21,6 +20,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/loz" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/record" "github.com/neilotoole/sq/libsq/core/sqlz" 
"github.com/neilotoole/sq/libsq/core/stringz" @@ -383,7 +383,8 @@ func (bi *BatchInsert) Munge(rec []any) error { // //nolint:gocognit func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, - destTbl string, destColNames []string, batchSize int) (*BatchInsert, error) { + destTbl string, destColNames []string, batchSize int, +) (*BatchInsert, error) { log := lg.FromContext(ctx) if err := sqlz.RequireSingleConn(db); err != nil { @@ -426,7 +427,6 @@ func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, // is the primary concern. lg.WarnIfError(log, lgm.CloseDBStmt, errz.Err(inserter.Close())) } - } if err != nil { diff --git a/libsq/source/files.go b/libsq/source/files.go index 1e82e0d6a..4683c8fd4 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -11,8 +11,6 @@ import ( "sync" "time" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/h2non/filetype" "github.com/h2non/filetype/matchers" "golang.org/x/sync/errgroup" @@ -21,6 +19,7 @@ import ( "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" @@ -143,27 +142,37 @@ func (fs *Files) TypeStdin(ctx context.Context) (drivertype.Type, error) { func (fs *Files) addFile(ctx context.Context, f *os.File, key string) (fscache.ReadAtCloser, error) { log := lg.FromContext(ctx) log.Debug("Adding file", lga.Key, key, lga.Path, f.Name()) + r, w, err := fs.fcache.Get(key) if err != nil { return nil, errz.Err(err) } - if w == nil { lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) return nil, errz.Errorf("failed to add to fscache (possibly previously added): %s", key) } + + fi, err := f.Stat() + if err != nil { + return nil, errz.Err(err) + } + size := fi.Size() + if size == 0 { + size = -1 + } + // TODO: 
Problematically, we copy the entire contents of f into fscache. // This is probably necessary for piped data on stdin, but for files // that already exist on the file system, it would be nice if the cacheFile // could be mapped directly to the filesystem file. This might require // hacking on the fscache impl. copier := fscache.AsyncFiller{ - Message: "Cache fill", + Message: "Reading file", Log: log.With(lga.Action, "Cache fill"), NewContextWriterFn: progress.NewWriter, // We don't use progress.NewReader here, because that // would result in double counting of bytes transferred. - NewContextReaderFn: func(ctx context.Context, msg string, r io.Reader) io.Reader { + NewContextReaderFn: func(ctx context.Context, msg string, size int64, r io.Reader) io.Reader { return contextio.NewReader(ctx, r) }, CloseReader: true, @@ -172,7 +181,7 @@ func (fs *Files) addFile(ctx context.Context, f *os.File, key string) (fscache.R // FIXME: Added a delay for testing. Remove this before release. df := ioz.DelayReader(f, time.Millisecond, true) // if err = copier.Copy(ctx, w, f); err != nil { - if err = copier.Copy(ctx, w, df); err != nil { + if err = copier.Copy(ctx, size, w, df); err != nil { lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) return nil, errz.Err(err) } diff --git a/testh/testh.go b/testh/testh.go index f5b1d4b15..4522708f9 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -485,7 +485,15 @@ func (h *Helper) Insert(src *source.Source, tbl string, cols []string, records . 
defer lg.WarnIfCloseError(h.Log, lgm.CloseDB, conn) batchSize := driver.MaxBatchRows(drvr, len(cols)) - bi, err := driver.NewBatchInsert(h.Context, "Insert records", drvr, conn, tbl, cols, batchSize) + bi, err := driver.NewBatchInsert( + h.Context, + libsq.MsgIngestRecords, + drvr, + conn, + tbl, + cols, + batchSize, + ) require.NoError(h.T, err) for _, rec := range records { From 356fa6c8087bfdaa09808507a2a6b0d8d7773a4f Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 28 Nov 2023 00:09:56 -0700 Subject: [PATCH 024/195] wip --- cli/cmd_root.go | 5 +++++ cli/flag/flag.go | 3 +++ cli/options.go | 8 ++++++-- cli/output.go | 21 ++++++++++++++++++--- libsq/core/progress/progress.go | 16 ++++++++-------- libsq/source/files.go | 2 +- 6 files changed, 41 insertions(+), 14 deletions(-) diff --git a/cli/cmd_root.go b/cli/cmd_root.go index 7d5e48a38..49480b060 100644 --- a/cli/cmd_root.go +++ b/cli/cmd_root.go @@ -98,6 +98,11 @@ See docs and more: https://sq.io`, cmd.Flags().Bool(flag.Version, false, flag.VersionUsage) cmd.PersistentFlags().BoolP(flag.Monochrome, flag.MonochromeShort, false, flag.MonochromeUsage) + + addOptionFlag(cmd.PersistentFlags(), OptProgress) + + // TODO: Move the rest of the option flags over to addOptionFlag + //cmd.PersistentFlags().Bool(flag.NoProgress, false, flag.NoProgressUsage) cmd.PersistentFlags().BoolP(flag.Verbose, flag.VerboseShort, false, flag.VerboseUsage) cmd.PersistentFlags().String(flag.Config, "", flag.ConfigUsage) diff --git a/cli/flag/flag.go b/cli/flag/flag.go index 89f96b06d..e0a895021 100644 --- a/cli/flag/flag.go +++ b/cli/flag/flag.go @@ -67,6 +67,9 @@ const ( MonochromeShort = "M" MonochromeUsage = "Don't colorize output" + NoProgress = "no-progress" + NoProgressUsage = "Don't show progress bar" + Output = "output" OutputShort = "o" OutputUsage = "Write output to instead of stdout" diff --git a/cli/options.go b/cli/options.go index 038f87614..90f0a60dc 100644 --- a/cli/options.go +++ b/cli/options.go @@ -237,12 +237,16 @@ 
func addOptionFlag(flags *pflag.FlagSet, opt options.Opt) (key string) { flags.IntP(key, string(opt.Short()), opt.Default(), opt.Usage()) return key case options.Bool: + defVal := opt.Default() + if opt.FlagInverted() { + defVal = !defVal + } if opt.Short() == 0 { - flags.Bool(key, opt.Default(), opt.Usage()) + flags.Bool(key, defVal, opt.Usage()) return key } - flags.BoolP(key, string(opt.Short()), opt.Default(), opt.Usage()) + flags.BoolP(key, string(opt.Short()), defVal, opt.Usage()) return key case options.Duration: if opt.Short() == 0 { diff --git a/cli/output.go b/cli/output.go index 810dbfd9b..79d220ea4 100644 --- a/cli/output.go +++ b/cli/output.go @@ -369,6 +369,11 @@ func getRecordWriterFunc(f format.Format) output.NewRecordWriterFunc { // may be decorated for dealing with color, etc. // The supplied opts must already have flags merged into it // via getOptionsFromCmd. +// +// Be cautious making changes to getPrinting. This function must +// be absolutely bulletproof, as it's called by all commands, as well +// as by the error handling mechanism. So, be sure to always check +// for nil cmd, nil cmd.Context, etc. 
func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer, ) (pr *output.Printing, out2, errOut2 io.Writer) { pr = output.NewPrinting() @@ -410,6 +415,17 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer pr.EnableColor(false) out2 = out errOut2 = errOut + + if cmd != nil && cmd.Context() != nil && OptProgress.Get(opts) && isTerminal(errOut) { + progColors := progress.DefaultColors() + progColors.EnableColor(false) + ctx := cmd.Context() + renderDelay := OptProgressDelay.Get(opts) + prog := progress.New(ctx, errOut, renderDelay, progColors) + out2 = progress.ShutdownOnWriteTo(prog, out2) + cmd.SetContext(progress.NewContext(ctx, prog)) + } + return pr, out2, errOut2 } @@ -436,16 +452,15 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer errOut2 = colorable.NewNonColorable(errOut) } - if OptProgress.Get(opts) && isTerminal(errOut) { + if cmd != nil && cmd.Context() != nil && OptProgress.Get(opts) && isTerminal(errOut) { progColors := progress.DefaultColors() progColors.EnableColor(isColorTerminal(errOut)) ctx := cmd.Context() renderDelay := OptProgressDelay.Get(opts) - prog := progress.New(ctx, errOut, renderDelay, progColors) + prog := progress.New(ctx, errOut2, renderDelay, progColors) out2 = progress.ShutdownOnWriteTo(prog, out2) cmd.SetContext(progress.NewContext(ctx, prog)) - logFrom(cmd).Debug("Initialized progress") } logFrom(cmd).Debug("Constructed output.Printing", lga.Val, pr) diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 0c86f4b15..e537db33b 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -123,24 +123,24 @@ func ShutdownOnWriteTo(p *Progress, w io.Writer) io.Writer { if p == nil { return w } - return &writeNotifier{ - p: p, - w: w, + return &WriteNotifyOnce{ + fn: p.Wait, + w: w, } } -var _ io.Writer = (*writeNotifier)(nil) +var _ io.Writer = (*WriteNotifyOnce)(nil) -type writeNotifier struct 
{ - p *Progress +type WriteNotifyOnce struct { + fn func() w io.Writer notifyOnce sync.Once } // Write implements [io.Writer]. On first invocation, the referenced // progress.Progress is stopped via its [Progress.Wait] method. -func (w *writeNotifier) Write(p []byte) (n int, err error) { - w.notifyOnce.Do(w.p.Wait) +func (w *WriteNotifyOnce) Write(p []byte) (n int, err error) { + w.notifyOnce.Do(w.fn) return w.w.Write(p) } diff --git a/libsq/source/files.go b/libsq/source/files.go index 4683c8fd4..347eb0758 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -167,7 +167,7 @@ func (fs *Files) addFile(ctx context.Context, f *os.File, key string) (fscache.R // could be mapped directly to the filesystem file. This might require // hacking on the fscache impl. copier := fscache.AsyncFiller{ - Message: "Reading file", + Message: "Reading source data", Log: log.With(lga.Action, "Cache fill"), NewContextWriterFn: progress.NewWriter, // We don't use progress.NewReader here, because that From 248aba1c50d89dfe592d7330b17dc2021e8814fa Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 28 Nov 2023 00:38:44 -0700 Subject: [PATCH 025/195] wip --- cli/cmd_root.go | 2 -- cli/output.go | 9 +++-- libsq/core/ioz/ioz.go | 30 ++++++++++++++++ libsq/core/progress/progress.go | 63 +++------------------------------ libsq/source/files.go | 6 +--- 5 files changed, 42 insertions(+), 68 deletions(-) diff --git a/cli/cmd_root.go b/cli/cmd_root.go index 49480b060..90a4de350 100644 --- a/cli/cmd_root.go +++ b/cli/cmd_root.go @@ -100,9 +100,7 @@ See docs and more: https://sq.io`, cmd.PersistentFlags().BoolP(flag.Monochrome, flag.MonochromeShort, false, flag.MonochromeUsage) addOptionFlag(cmd.PersistentFlags(), OptProgress) - // TODO: Move the rest of the option flags over to addOptionFlag - //cmd.PersistentFlags().Bool(flag.NoProgress, false, flag.NoProgressUsage) cmd.PersistentFlags().BoolP(flag.Verbose, flag.VerboseShort, false, flag.VerboseUsage) 
cmd.PersistentFlags().String(flag.Config, "", flag.ConfigUsage) diff --git a/cli/output.go b/cli/output.go index 79d220ea4..f96f940d8 100644 --- a/cli/output.go +++ b/cli/output.go @@ -7,6 +7,8 @@ import ( "strings" "time" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/fatih/color" colorable "github.com/mattn/go-colorable" wordwrap "github.com/mitchellh/go-wordwrap" @@ -422,7 +424,8 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer ctx := cmd.Context() renderDelay := OptProgressDelay.Get(opts) prog := progress.New(ctx, errOut, renderDelay, progColors) - out2 = progress.ShutdownOnWriteTo(prog, out2) + // On first write to stdout, we remove the progress widget. + out2 = ioz.NotifyOnceWriter(out2, prog.Wait) cmd.SetContext(progress.NewContext(ctx, prog)) } @@ -459,7 +462,9 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer ctx := cmd.Context() renderDelay := OptProgressDelay.Get(opts) prog := progress.New(ctx, errOut2, renderDelay, progColors) - out2 = progress.ShutdownOnWriteTo(prog, out2) + + // On first write to stdout, we remove the progress widget. + out2 = ioz.NotifyOnceWriter(out2, prog.Wait) cmd.SetContext(progress.NewContext(ctx, prog)) } diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index cb97e208f..d407d8c34 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -10,6 +10,7 @@ import ( "os" "path/filepath" "strings" + "sync" "time" yaml "github.com/goccy/go-yaml" @@ -242,3 +243,32 @@ func (d delayReadCloser) Close() error { func LimitRandReader(limit int64) io.Reader { return io.LimitReader(crand.Reader, limit) } + +// NotifyOnceWriter returns an io.Writer that invokes fn on the first +// invocation of Write. If w or fn is nil, w is returned. 
+func NotifyOnceWriter(w io.Writer, fn func()) io.Writer { + if w == nil || fn == nil { + return w + } + + return ¬ifyOnceWriter{ + fn: fn, + w: w, + } +} + +var _ io.Writer = (*notifyOnceWriter)(nil) + +type notifyOnceWriter struct { + fn func() + w io.Writer + notifyOnce sync.Once +} + +// Write implements [io.Writer]. On the first invocation of this +// method, fn is invoked. +func (w *notifyOnceWriter) Write(p []byte) (n int, err error) { + w.notifyOnce.Do(w.fn) + + return w.w.Write(p) +} diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index e537db33b..1a39e6320 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -2,9 +2,7 @@ package progress import ( "context" - "fmt" "io" - "sync" "time" humanize "github.com/dustin/go-humanize" @@ -14,7 +12,6 @@ import ( "github.com/vbauerster/mpb/v8/decor" "github.com/neilotoole/sq/libsq/core/cleanup" - "github.com/neilotoole/sq/libsq/core/stringz" ) type runKey struct{} @@ -79,9 +76,9 @@ func (c *Colors) EnableColor(enable bool) { c.Size.DisableColor() } +// Progress represents a container that renders one or more progress bars. type Progress struct { p *mpb.Progress - mu *sync.Mutex // FIXME: we don't need mu? colors *Colors cleanup *cleanup.Cleanup } @@ -89,9 +86,7 @@ type Progress struct { // Wait waits for all bars to complete and finally shutdowns container. After // this method has been called, there is no way to reuse `*Progress` instance. func (p *Progress) Wait() { - p.mu.Lock() - defer p.mu.Unlock() - + // Invoking cleanup will call Stop on all the bars. _ = p.cleanup.Run() p.p.Wait() } @@ -107,50 +102,7 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors colors = DefaultColors() } - return &Progress{p: p, colors: colors, mu: &sync.Mutex{}, cleanup: cleanup.New()} -} - -// ShutdownOnWriteTo returns a writer decorator that stop the -// progress.Progress when w.Write is called. 
Typically p writes -// to stderr, but stdout is passed to this method. That is, when -// the program starts writing to stdout, we want to shut down -// and remove the progress bar. -// -// REVISIT: ShutdownOnWriteTo is not a great name. -func ShutdownOnWriteTo(p *Progress, w io.Writer) io.Writer { - // REVISIT: Should we check if w implements other io interfaces, - // such as io.WriterAt etc? Or do we really only care about io.Writer? - if p == nil { - return w - } - return &WriteNotifyOnce{ - fn: p.Wait, - w: w, - } -} - -var _ io.Writer = (*WriteNotifyOnce)(nil) - -type WriteNotifyOnce struct { - fn func() - w io.Writer - notifyOnce sync.Once -} - -// Write implements [io.Writer]. On first invocation, the referenced -// progress.Progress is stopped via its [Progress.Wait] method. -func (w *WriteNotifyOnce) Write(p []byte) (n int, err error) { - w.notifyOnce.Do(w.fn) - - return w.w.Write(p) -} - -func normalizeMsgLength(msg string, length int) string { - if len(msg) > length { - msg = stringz.TrimLenMiddle(msg, length) - } - - return fmt.Sprintf("%-*s", length, msg) + return &Progress{p: p, colors: colors, cleanup: cleanup.New()} } // NewCountSpinner returns a new indeterminate spinner bar whose label @@ -164,9 +116,6 @@ func (p *Progress) NewCountSpinner(msg, unit string) *Spinner { return nil } - p.mu.Lock() - defer p.mu.Unlock() - decorator := decor.Any(func(statistics decor.Statistics) string { s := humanize.Comma(statistics.Current) if unit != "" { @@ -190,14 +139,10 @@ func (p *Progress) NewByteCounterSpinner(msg string, size int64) *Spinner { return nil } - p.mu.Lock() - defer p.mu.Unlock() - var decorator decor.Decorator if size < 0 { - decorator = decor.CountersNoUnit("% .2f / ?") + decorator = decor.Current(decor.SizeB1024(0), "% .1f") } else { - // decorator = decor.Current(decor.SizeB1024(0), "% .1f") decorator = decor.Counters(decor.SizeB1024(0), "% .1f / % .1f") } decorator = ColorMeta(decorator, p.colors.Size) diff --git a/libsq/source/files.go 
b/libsq/source/files.go index 347eb0758..50626cc14 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -19,7 +19,6 @@ import ( "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" @@ -178,10 +177,7 @@ func (fs *Files) addFile(ctx context.Context, f *os.File, key string) (fscache.R CloseReader: true, } - // FIXME: Added a delay for testing. Remove this before release. - df := ioz.DelayReader(f, time.Millisecond, true) - // if err = copier.Copy(ctx, w, f); err != nil { - if err = copier.Copy(ctx, size, w, df); err != nil { + if err = copier.Copy(ctx, size, w, f); err != nil { lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) return nil, errz.Err(err) } From 4de09309d12a69ad2aa2d63c95f88c664df9c9c8 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 28 Nov 2023 09:09:34 -0700 Subject: [PATCH 026/195] wip --- libsq/core/ioz/contextio/contextio.go | 117 ++++++++++++++++++- libsq/core/ioz/contextio/contextio_test.go | 65 +++++++++++ libsq/core/ioz/ioz.go | 36 ++++++ libsq/core/progress/progress.go | 125 ++++++++++++++------- libsq/core/progress/progress_test.go | 45 ++++++-- libsq/core/progress/progressio.go | 4 +- libsq/driver/record.go | 8 +- libsq/source/files.go | 14 ++- 8 files changed, 346 insertions(+), 68 deletions(-) create mode 100644 libsq/core/ioz/contextio/contextio_test.go diff --git a/libsq/core/ioz/contextio/contextio.go b/libsq/core/ioz/contextio/contextio.go index b0ae729f5..535db2489 100644 --- a/libsq/core/ioz/contextio/contextio.go +++ b/libsq/core/ioz/contextio/contextio.go @@ -20,29 +20,61 @@ package contextio import ( "context" + "errors" "io" + + "github.com/neilotoole/sq/libsq/core/errz" ) +var _ io.Writer = (*writer)(nil) + type writer struct { ctx context.Context w io.Writer } +var _ 
io.WriteCloser = (*writeCloser)(nil) + +type writeCloser struct { + writer +} + +var _ io.ReaderFrom = (*copier)(nil) + type copier struct { writer } +var _ io.WriteCloser = (*copyCloser)(nil) + +type copyCloser struct { + writeCloser +} + // NewWriter wraps an [io.Writer] to handle context cancellation. // // Context state is checked BEFORE every Write. // // The returned Writer also implements [io.ReaderFrom] to allow [io.Copy] to select // the best strategy while still checking the context state before every chunk transfer. +// +// If w implements io.WriteCloser, the returned Writer will +// also implement io.WriteCloser. func NewWriter(ctx context.Context, w io.Writer) io.Writer { if w, ok := w.(*copier); ok && ctx == w.ctx { return w } - return &copier{writer{ctx: ctx, w: w}} + + if w, ok := w.(*copyCloser); ok && ctx == w.ctx { + return w + } + + wr := writer{ctx: ctx, w: w} + if _, ok := w.(io.Closer); ok { + return ©Closer{writeCloser: writeCloser{writer: wr}} + } + + return &copier{writer: wr} } // Write implements [io.Writer], but with context awareness. @@ -55,6 +87,30 @@ func (w *writer) Write(p []byte) (n int, err error) { } } +// Close implements [io.Closer], but with context awareness. +func (w *writeCloser) Close() error { + var closeErr error + if c, ok := w.w.(io.Closer); ok { + closeErr = c.Close() + } + + select { + case <-w.ctx.Done(): + ctxErr := w.ctx.Err() + switch { + case closeErr == nil, + errz.IsErrContext(closeErr): + return ctxErr + default: + return errors.Join(ctxErr, closeErr) + } + default: + return closeErr + } +} + +var _ io.Reader = (*reader)(nil) + type reader struct { ctx context.Context r io.Reader @@ -63,13 +119,27 @@ type reader struct { // NewReader wraps an [io.Reader] to handle context cancellation. // // Context state is checked BEFORE every Read. +// +// If r implements io.ReadCloser, the returned reader will +// also implement io.ReadCloser. 
func NewReader(ctx context.Context, r io.Reader) io.Reader { if r, ok := r.(*reader); ok && ctx == r.ctx { return r } - return &reader{ctx: ctx, r: r} + + if r, ok := r.(*readCloser); ok && ctx == r.ctx { + return r + } + + rdr := reader{ctx: ctx, r: r} + if _, ok := r.(io.ReadCloser); ok { + return &readCloser{rdr} + } + + return &rdr } +// Read implements [io.Reader], but with context awareness. func (r *reader) Read(p []byte) (n int, err error) { select { case <-r.ctx.Done(): @@ -79,6 +149,34 @@ func (r *reader) Read(p []byte) (n int, err error) { } } +var _ io.ReadCloser = (*readCloser)(nil) + +type readCloser struct { + reader +} + +// Close implements [io.Closer], but with context awareness. +func (rc *readCloser) Close() error { + var closeErr error + if c, ok := rc.r.(io.Closer); ok { + closeErr = c.Close() + } + + select { + case <-rc.ctx.Done(): + ctxErr := rc.ctx.Err() + switch { + case closeErr == nil, + errz.IsErrContext(closeErr): + return ctxErr + default: + return errors.Join(ctxErr, closeErr) + } + default: + return closeErr + } +} + // ReadFrom implements interface [io.ReaderFrom], but with context awareness. // // This should allow efficient copying allowing writer or reader to define the chunk size. @@ -99,7 +197,7 @@ func (w *copier) ReadFrom(r io.Reader) (n int64, err error) { // NewCloser wraps an [io.Reader] to handle context cancellation. // -// Context state is checked BEFORE any Close. +// The underlying io.Closer is closed even if the context is done. 
func NewCloser(ctx context.Context, c io.Closer) io.Closer { return &closer{ctx: ctx, c: c} } @@ -110,10 +208,19 @@ type closer struct { } func (c *closer) Close() error { + closeErr := c.c.Close() + select { case <-c.ctx.Done(): - return c.ctx.Err() + ctxErr := c.ctx.Err() + switch { + case closeErr == nil, + errz.IsErrContext(closeErr): + return ctxErr + default: + return errors.Join(ctxErr, closeErr) + } default: - return c.c.Close() + return closeErr } } diff --git a/libsq/core/ioz/contextio/contextio_test.go b/libsq/core/ioz/contextio/contextio_test.go new file mode 100644 index 000000000..88d25801e --- /dev/null +++ b/libsq/core/ioz/contextio/contextio_test.go @@ -0,0 +1,65 @@ +package contextio_test + +import ( + "bytes" + "context" + "io" + "testing" + + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestNewWriter_Closer tests that the returned writer +// implements io.WriteCloser, or not, depending upon the type of +// the underlying writer. +func TestNewWriter_Closer(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + // bytes.Buffer doesn't implement io.Closer + buf := &bytes.Buffer{} + gotWriter := contextio.NewWriter(ctx, buf) + require.NotNil(t, gotWriter) + _, isCloser := gotWriter.(io.WriteCloser) + + assert.False(t, isCloser, "expected reader NOT to be io.WriteCloser, but was %T", + gotWriter) + + bufCloser := ioz.ToWriteCloser(buf) + gotWriter = contextio.NewWriter(ctx, bufCloser) + require.NotNil(t, gotWriter) + _, isCloser = gotWriter.(io.WriteCloser) + + assert.True(t, isCloser, "expected reader to implement io.WriteCloser, but was %T", + gotWriter) +} + +// TestNewReader_Closer tests that the returned reader +// implements io.ReadCloser, or not, depending upon the type of +// the underlying writer. 
+func TestNewReader_Closer(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + // bytes.Buffer doesn't implement io.Closer + buf := &bytes.Buffer{} + gotReader := contextio.NewReader(ctx, buf) + require.NotNil(t, gotReader) + _, isCloser := gotReader.(io.ReadCloser) + + assert.False(t, isCloser, "expected reader NOT to be io.ReadCloser but was %T", + gotReader) + + bufCloser := io.NopCloser(buf) + gotReader = contextio.NewReader(ctx, bufCloser) + require.NotNil(t, gotReader) + _, isCloser = gotReader.(io.ReadCloser) + + assert.True(t, isCloser, "expected reader to be io.ReadCloser but was %T", + gotReader) +} diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index d407d8c34..663a7b3dc 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -272,3 +272,39 @@ func (w *notifyOnceWriter) Write(p []byte) (n int, err error) { return w.w.Write(p) } + +// ToWriteCloser returns w as an io.WriteCloser. If w implements +// io.WriteCloser, w is returned. Otherwise, w is wrapped in a +// no-op decorator that implements io.WriteCloser. +// +// ToWriteCloser is the missing sibling of io.NopCloser, which +// isn't implemented in stdlib. See: https://github.com/golang/go/issues/22823. 
+func ToWriteCloser(w io.Writer) io.WriteCloser {
+	if wc, ok := w.(io.WriteCloser); ok {
+		return wc
+	}
+	return toNopWriteCloser(w)
+}
+
+func toNopWriteCloser(w io.Writer) io.WriteCloser {
+	if _, ok := w.(io.ReaderFrom); ok {
+		return nopWriteCloserReaderFrom{w}
+	}
+	return nopWriteCloser{w}
+}
+
+type nopWriteCloser struct {
+	io.Writer
+}
+
+func (nopWriteCloser) Close() error { return nil }
+
+type nopWriteCloserReaderFrom struct {
+	io.Writer
+}
+
+func (nopWriteCloserReaderFrom) Close() error { return nil }
+
+func (c nopWriteCloserReaderFrom) ReadFrom(r io.Reader) (int64, error) {
+	return c.Writer.(io.ReaderFrom).ReadFrom(r)
+}
diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go
index 1a39e6320..bad6da61a 100644
--- a/libsq/core/progress/progress.go
+++ b/libsq/core/progress/progress.go
@@ -50,6 +50,9 @@ func DefaultColors() *Colors {
 		Message: color.New(color.Faint),
 		Spinner: color.New(color.FgGreen, color.Bold),
 		Size:    color.New(color.Faint),
+		// Percent: color.New(color.FgHiBlue),
+		Percent: color.New(color.FgCyan, color.Faint),
+		// Percent: color.New(color.FgCyan),
 	}
 }
 
@@ -57,6 +60,7 @@ type Colors struct {
 	Message *color.Color
 	Spinner *color.Color
 	Size    *color.Color
+	Percent *color.Color
 }
 
 func (c *Colors) EnableColor(enable bool) {
@@ -68,33 +72,28 @@ func (c *Colors) EnableColor(enable bool) {
 		c.Message.EnableColor()
 		c.Spinner.EnableColor()
 		c.Size.EnableColor()
+		c.Percent.EnableColor()
 		return
 	}
 
 	c.Message.DisableColor()
 	c.Spinner.DisableColor()
 	c.Size.DisableColor()
+	c.Percent.DisableColor()
 }
 
-// Progress represents a container that renders one or more progress bars.
-type Progress struct {
-	p       *mpb.Progress
-	colors  *Colors
-	cleanup *cleanup.Cleanup
-}
-
-// Wait waits for all bars to complete and finally shutdowns container. After
-// this method has been called, there is no way to reuse `*Progress` instance.
-func (p *Progress) Wait() {
-	// Invoking cleanup will call Stop on all the bars.
- _ = p.cleanup.Run() - p.p.Wait() -} +const ( + barWidth = 28 + boxWidth = 64 +) +// New returns a new Progress instance, which is a container for progress bars. +// The returned Progress instance is safe for concurrent use. The caller is +// responsible for calling [Progress.Wait] on the returned Progress. func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors) *Progress { p := mpb.NewWithContext(ctx, mpb.WithOutput(out), - mpb.WithWidth(64), + mpb.WithWidth(boxWidth), mpb.WithRenderDelay(renderDelay(delay)), ) @@ -105,13 +104,45 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors return &Progress{p: p, colors: colors, cleanup: cleanup.New()} } -// NewCountSpinner returns a new indeterminate spinner bar whose label -// metric is the provided unit. The caller is ultimately -// responsible for calling [Spinner.Stop] on the returned Spinner. However, -// the returned Spinner is also added to the Progress's cleanup list, so +// Progress represents a container that renders one or more progress bars. +// The caller is responsible for calling [Progress.Wait] to indicate +// completion. +type Progress struct { + p *mpb.Progress + colors *Colors + cleanup *cleanup.Cleanup +} + +// Wait waits for all bars to complete and finally shuts down the +// container. After this method has been called, there is no way +// to reuse the Progress instance. +func (p *Progress) Wait() { + // Invoking cleanup will call Bar.Stop on all the bars. + _ = p.cleanup.Run() + p.p.Wait() +} + +// NewUnitCounter returns a new indeterminate bar whose label +// metric is the plural of the provided unit. The caller is ultimately +// responsible for calling [Bar.Stop] on the returned Bar. However, +// the returned Bar is also added to the Progress's cleanup list, so // it will be called automatically when the Progress is shut down, but that // may be later than the actual conclusion of the spinner's work. 
-func (p *Progress) NewCountSpinner(msg, unit string) *Spinner { +// +// pbar := p.NewUnitCounter("Ingest records", "record") +// defer pbar.Stop() +// +// for i := 0; i < 100; i++ { +// pbar.IncrBy(1) +// time.Sleep(100 * time.Millisecond) +// } +// +// This produces output similar to: +// +// Ingesting records ∙∙● 87 records +// +// Note that the unit arg is pluralized. +func (p *Progress) NewUnitCounter(msg, unit string) *Bar { if p == nil { return nil } @@ -125,38 +156,38 @@ func (p *Progress) NewCountSpinner(msg, unit string) *Spinner { }) decorator = ColorMeta(decorator, p.colors.Size) - return p.newSpinner(msg, -1, decorator) + return p.newBar(msg, -1, decorator) } // NewByteCounterSpinner returns a new spinner bar whose metric is the count // of bytes processed. If the size is unknown, set arg size to -1. The caller -// is ultimately responsible for calling [Spinner.Stop] on the returned Spinner. -// However, the returned Spinner is also added to the Progress's cleanup list, +// is ultimately responsible for calling [Bar.Stop] on the returned Bar. +// However, the returned Bar is also added to the Progress's cleanup list, // so it will be called automatically when the Progress is shut down, but that // may be later than the actual conclusion of the spinner's work. 
-func (p *Progress) NewByteCounterSpinner(msg string, size int64) *Spinner { +func (p *Progress) NewByteCounterSpinner(msg string, size int64) *Bar { if p == nil { return nil } - var decorator decor.Decorator + var counter decor.Decorator if size < 0 { - decorator = decor.Current(decor.SizeB1024(0), "% .1f") + counter = decor.Current(decor.SizeB1024(0), "% .1f") } else { - decorator = decor.Counters(decor.SizeB1024(0), "% .1f / % .1f") + counter = decor.Counters(decor.SizeB1024(0), "% .1f / % .1f") } - decorator = ColorMeta(decorator, p.colors.Size) + counter = ColorMeta(counter, p.colors.Size) + percent := decor.NewPercentage(" %.1f", decor.WCSyncSpace) + percent = ColorMeta(percent, p.colors.Percent) - return p.newSpinner(msg, size, decorator) + return p.newBar(msg, size, counter, percent) } -func (p *Progress) newSpinner(msg string, total int64, decorators ...decor.Decorator) *Spinner { +func (p *Progress) newBar(msg string, total int64, decorators ...decor.Decorator) *Bar { if p == nil { return nil } - const barWidth = 28 - if total < 0 { total = 0 } @@ -176,29 +207,37 @@ func (p *Progress) newSpinner(msg string, total int64, decorators ...decor.Decor mpb.BarRemoveOnComplete(), ) - spinner := &Spinner{bar: bar} - p.cleanup.Add(spinner.Stop) - return spinner + b := &Bar{bar: bar} + p.cleanup.Add(b.Stop) + return b } -type Spinner struct { +// Bar represents a single progress bar. The caller should invoke +// [Bar.IncrBy] as necessary to increment the bar's progress. When +// the bar is complete, the caller should invoke [Bar.Stop]. All +// methods are safe to call on a nil Bar. +type Bar struct { bar *mpb.Bar } -func (sp *Spinner) IncrBy(n int) { - if sp == nil { +// IncrBy increments progress by amount of n. It is safe to +// call IncrBy on a nil Bar. +func (b *Bar) IncrBy(n int) { + if b == nil { return } - sp.bar.IncrBy(n) + b.bar.IncrBy(n) } -func (sp *Spinner) Stop() { - if sp == nil { +// Stop stops and removes the bar. 
It is safe to call Stop on a nil Bar, +// or to call Stop multiple times. +func (b *Bar) Stop() { + if b == nil { return } - sp.bar.SetTotal(-1, true) - sp.bar.Wait() + b.bar.SetTotal(-1, true) + b.bar.Wait() } func renderDelay(d time.Duration) <-chan struct{} { diff --git a/libsq/core/progress/progress_test.go b/libsq/core/progress/progress_test.go index 4be4bca57..f1dc8ef12 100644 --- a/libsq/core/progress/progress_test.go +++ b/libsq/core/progress/progress_test.go @@ -16,10 +16,12 @@ import ( ) func TestNewWriter(t *testing.T) { + t.Parallel() + const limit = 1000000 ctx := context.Background() - pb := progress.New(ctx, os.Stdout, time.Millisecond, progress.DefaultColors()) + pb := progress.New(ctx, io.Discard, time.Millisecond, progress.DefaultColors()) ctx = progress.NewContext(ctx, pb) src := ioz.LimitRandReader(limit) @@ -34,10 +36,39 @@ func TestNewWriter(t *testing.T) { pb.Wait() } -// TestNewWriter_Closer_type tests that the returned writer -// implements io.ReadCloser, or not, depending upon the type of -// the underlying writer. -func TestNewReader_Closer_type(t *testing.T) { +// TestNewWriter_Closer tests that the returned writer +// implements io.WriteCloser regardless of whether the +// underlying writer does. 
+func TestNewWriter_Closer(t *testing.T) { + t.Parallel() + + ctx := context.Background() + pb := progress.New(ctx, os.Stdout, time.Millisecond, progress.DefaultColors()) + ctx = progress.NewContext(ctx, pb) + defer pb.Wait() + + // bytes.Buffer doesn't implement io.Closer + buf := &bytes.Buffer{} + gotWriter := progress.NewWriter(ctx, "no closer", -1, buf) + require.NotNil(t, gotWriter) + _, isCloser := gotWriter.(io.WriteCloser) + assert.True(t, isCloser, "expected writer to be io.WriteCloser, but was %T", + gotWriter) + + bufCloser := ioz.ToWriteCloser(buf) + gotWriter = progress.NewWriter(ctx, "no closer", -1, bufCloser) + require.NotNil(t, gotWriter) + _, isCloser = gotWriter.(io.WriteCloser) + assert.True(t, isCloser, "expected writer to implement io.WriteCloser, but was %T", + gotWriter) +} + +// TestNewReader_Closer tests that the returned reader +// implements io.ReadCloser regardless of whether the +// underlying writer does. +func TestNewReader_Closer(t *testing.T) { + t.Parallel() + ctx := context.Background() pb := progress.New(ctx, os.Stdout, time.Millisecond, progress.DefaultColors()) ctx = progress.NewContext(ctx, pb) @@ -48,15 +79,13 @@ func TestNewReader_Closer_type(t *testing.T) { gotReader := progress.NewReader(ctx, "no closer", -1, buf) require.NotNil(t, gotReader) _, isCloser := gotReader.(io.ReadCloser) - - assert.False(t, isCloser, "expected reader NOT to be io.ReadCloser but was %T", + assert.True(t, isCloser, "expected reader to be io.ReadCloser but was %T", gotReader) bufCloser := io.NopCloser(buf) gotReader = progress.NewReader(ctx, "closer", -1, bufCloser) require.NotNil(t, gotReader) _, isCloser = gotReader.(io.ReadCloser) - assert.True(t, isCloser, "expected reader to be io.ReadCloser but was %T", gotReader) } diff --git a/libsq/core/progress/progressio.go b/libsq/core/progress/progressio.go index b87616dba..a85332785 100644 --- a/libsq/core/progress/progressio.go +++ b/libsq/core/progress/progressio.go @@ -67,7 +67,7 @@ var _ 
io.WriteCloser = (*progWriter)(nil) type progWriter struct { ctx context.Context w io.Writer - spinner *Spinner + spinner *Bar } // Write implements [io.Writer], but with context awareness. @@ -149,7 +149,7 @@ var _ io.ReadCloser = (*progReader)(nil) type progReader struct { ctx context.Context r io.Reader - spinner *Spinner + spinner *Bar } // Close implements [io.ReadCloser], but with context awareness. diff --git a/libsq/driver/record.go b/libsq/driver/record.go index 6e1dafeba..d2aa36bbe 100644 --- a/libsq/driver/record.go +++ b/libsq/driver/record.go @@ -391,7 +391,7 @@ func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, return nil, err } - spinner := progress.FromContext(ctx).NewCountSpinner(msg, "rec") + pbar := progress.FromContext(ctx).NewUnitCounter(msg, "rec") recCh := make(chan []any, batchSize*8) errCh := make(chan error, 1) @@ -414,7 +414,7 @@ func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, var affected int64 defer func() { - spinner.Stop() + pbar.Stop() if inserter != nil { if err == nil { @@ -467,7 +467,7 @@ func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, } bi.written.Add(affected) - spinner.IncrBy(int(affected)) + pbar.IncrBy(int(affected)) if rec == nil { // recCh is closed (coincidentally exactly on the @@ -510,7 +510,7 @@ func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, } bi.written.Add(affected) - spinner.IncrBy(int(affected)) + pbar.IncrBy(int(affected)) // We're done return diff --git a/libsq/source/files.go b/libsq/source/files.go index 50626cc14..fdd0f691b 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -11,6 +11,8 @@ import ( "sync" "time" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/h2non/filetype" "github.com/h2non/filetype/matchers" "golang.org/x/sync/errgroup" @@ -177,7 +179,8 @@ func (fs *Files) addFile(ctx context.Context, f *os.File, key string) (fscache.R CloseReader: 
true, } - if err = copier.Copy(ctx, size, w, f); err != nil { + df := ioz.DelayReader(f, time.Millisecond, true) // FIXME: Delete + if err = copier.Copy(ctx, size, w, df); err != nil { lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) return nil, errz.Err(err) } @@ -283,7 +286,7 @@ func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, erro // openLocation returns a file for loc. It is the caller's // responsibility to close the returned file. -func (fs *Files) openLocation(_ context.Context, loc string) (*os.File, error) { +func (fs *Files) openLocation(ctx context.Context, loc string) (*os.File, error) { var fpath string var ok bool var err error @@ -299,8 +302,7 @@ func (fs *Files) openLocation(_ context.Context, loc string) (*os.File, error) { } // It's a remote file - // TODO: fetch should take ctx to allow for cancellation - fpath, err = fs.fetch(u.String()) + fpath, err = fs.fetch(ctx, u.String()) if err != nil { return nil, err } @@ -323,7 +325,7 @@ func (fs *Files) openFile(fpath string) (*os.File, error) { // fetch ensures that loc exists locally as a file. This may // entail downloading the file via HTTPS etc. -func (fs *Files) fetch(loc string) (fpath string, err error) { +func (fs *Files) fetch(ctx context.Context, loc string) (fpath string, err error) { // This impl is a vestigial abomination from an early // experiment. 
@@ -346,7 +348,7 @@ func (fs *Files) fetch(loc string) (fpath string, err error) { fetchr := &fetcher.Fetcher{} // TOOD: ultimately should be passing a real context here - err = fetchr.Fetch(context.Background(), u.String(), dlFile) + err = fetchr.Fetch(ctx, u.String(), dlFile) if err != nil { return "", errz.Err(err) } From 387782be0b4dff27f230c4b2197cf3f4bf902353 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 28 Nov 2023 13:21:28 -0700 Subject: [PATCH 027/195] progress.Progress largely done --- cli/cmd_src_test.go | 30 +++ cli/output.go | 3 +- libsq/core/ioz/contextio/contextio_test.go | 5 +- libsq/core/progress/progress.go | 210 +++++++++++++++------ libsq/source/files.go | 3 +- 5 files changed, 183 insertions(+), 68 deletions(-) create mode 100644 cli/cmd_src_test.go diff --git a/cli/cmd_src_test.go b/cli/cmd_src_test.go new file mode 100644 index 000000000..59ac08567 --- /dev/null +++ b/cli/cmd_src_test.go @@ -0,0 +1,30 @@ +package cli_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/neilotoole/sq/cli/testrun" + "github.com/neilotoole/sq/testh" + "github.com/neilotoole/sq/testh/sakila" +) + +func TestCmdSrc(t *testing.T) { + ctx := context.Background() + th := testh.New(t) + _ = th + + tr := testrun.New(ctx, t, nil).Add() + // err := tr.Exec("src") + // require.NoError(t, err) + + tr.Reset().Add(*th.Source(sakila.CSVActor)) + err := tr.Exec("src") + require.NoError(t, err) + + err = tr.Reset().Exec(".data | .[0:5]") + require.NoError(t, err) + t.Log(tr.Out.String()) +} diff --git a/cli/output.go b/cli/output.go index f96f940d8..7d1470361 100644 --- a/cli/output.go +++ b/cli/output.go @@ -7,8 +7,6 @@ import ( "strings" "time" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/fatih/color" colorable "github.com/mattn/go-colorable" wordwrap "github.com/mitchellh/go-wordwrap" @@ -27,6 +25,7 @@ import ( "github.com/neilotoole/sq/cli/output/xmlw" 
"github.com/neilotoole/sq/cli/output/yamlw" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/core/progress" diff --git a/libsq/core/ioz/contextio/contextio_test.go b/libsq/core/ioz/contextio/contextio_test.go index 88d25801e..97dbfb273 100644 --- a/libsq/core/ioz/contextio/contextio_test.go +++ b/libsq/core/ioz/contextio/contextio_test.go @@ -6,10 +6,11 @@ import ( "io" "testing" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" ) // TestNewWriter_Closer tests that the returned writer diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index bad6da61a..c7851a46a 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -3,6 +3,7 @@ package progress import ( "context" "io" + "sync" "time" humanize "github.com/dustin/go-humanize" @@ -25,7 +26,9 @@ func NewContext(ctx context.Context, prog *Progress) context.Context { return context.WithValue(ctx, runKey{}, prog) } -// FromContext extracts the Progress added to ctx via NewContext. +// FromContext returns the [Progress] added to ctx via NewContext, +// or returns nil. Note that it is safe to invoke the methods +// of a nil [Progress]. 
func FromContext(ctx context.Context) *Progress { if ctx == nil { return nil @@ -41,45 +44,6 @@ func FromContext(ctx context.Context) *Progress { } return nil - - // return ctx.Value(runKey{}).(*Progress) -} - -func DefaultColors() *Colors { - return &Colors{ - Message: color.New(color.Faint), - Spinner: color.New(color.FgGreen, color.Bold), - Size: color.New(color.Faint), - // Percent: color.New(color.FgHiBlue), - Percent: color.New(color.FgCyan, color.Faint), - // Percent: color.New(color.FgCyan), - } -} - -type Colors struct { - Message *color.Color - Spinner *color.Color - Size *color.Color - Percent *color.Color -} - -func (c *Colors) EnableColor(enable bool) { - if c == nil { - return - } - - if enable { - c.Message.EnableColor() - c.Spinner.EnableColor() - c.Size.EnableColor() - c.Percent.EnableColor() - return - } - - c.Message.DisableColor() - c.Spinner.DisableColor() - c.Size.DisableColor() - c.Percent.EnableColor() } const ( @@ -90,36 +54,84 @@ const ( // New returns a new Progress instance, which is a container for progress bars. // The returned Progress instance is safe for concurrent use. The caller is // responsible for calling [Progress.Wait] on the returned Progress. +// The Progress is lazily initialized, and thus the delay clock doesn't +// start ticking until the first call to one of the Progress.NewX methods. 
func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors) *Progress { - p := mpb.NewWithContext(ctx, - mpb.WithOutput(out), - mpb.WithWidth(boxWidth), - mpb.WithRenderDelay(renderDelay(delay)), - ) + var cancelFn context.CancelFunc + ctx, cancelFn = context.WithCancel(ctx) if colors == nil { colors = DefaultColors() } - return &Progress{p: p, colors: colors, cleanup: cleanup.New()} + prog := &Progress{ + ctx: ctx, + mu: sync.Mutex{}, + colors: colors, + cleanup: cleanup.New(), + cancelFn: cancelFn, + } + + prog.pcInit = func() { + opts := []mpb.ContainerOption{ + mpb.WithOutput(out), + mpb.WithWidth(boxWidth), + } + if delay > 0 { + opts = append(opts, mpb.WithRenderDelay(renderDelay(ctx, delay))) + } + prog.pc = mpb.NewWithContext(ctx, opts...) + prog.pcInit = nil + } + return prog } // Progress represents a container that renders one or more progress bars. // The caller is responsible for calling [Progress.Wait] to indicate // completion. type Progress struct { - p *mpb.Progress + // mu guards ALL public methods. + mu sync.Mutex + + ctx context.Context + + // pc is the underlying progress container. It is lazily initialized + // by pcInit. Any method that accesses pc must be certain that + // pcInit has been called. + pc *mpb.Progress + + // pcInit is the func that lazily initializes pc. + pcInit func() + colors *Colors cleanup *cleanup.Cleanup + + cancelFn context.CancelFunc } // Wait waits for all bars to complete and finally shuts down the // container. After this method has been called, there is no way // to reuse the Progress instance. func (p *Progress) Wait() { + if p == nil { + return + } + + p.mu.Lock() + defer p.mu.Unlock() + + if p.pc == nil { + return + } + + if p.cleanup.Len() == 0 { + return + } + + p.cancelFn() // Invoking cleanup will call Bar.Stop on all the bars. 
_ = p.cleanup.Run() - p.p.Wait() + p.pc.Wait() } // NewUnitCounter returns a new indeterminate bar whose label @@ -147,6 +159,9 @@ func (p *Progress) NewUnitCounter(msg, unit string) *Bar { return nil } + p.mu.Lock() + defer p.mu.Unlock() + decorator := decor.Any(func(statistics decor.Statistics) string { s := humanize.Comma(statistics.Current) if unit != "" { @@ -155,7 +170,7 @@ func (p *Progress) NewUnitCounter(msg, unit string) *Bar { return s }) - decorator = ColorMeta(decorator, p.colors.Size) + decorator = colorize(decorator, p.colors.Size) return p.newBar(msg, -1, decorator) } @@ -170,38 +185,51 @@ func (p *Progress) NewByteCounterSpinner(msg string, size int64) *Bar { return nil } + p.mu.Lock() + defer p.mu.Unlock() + var counter decor.Decorator if size < 0 { counter = decor.Current(decor.SizeB1024(0), "% .1f") } else { counter = decor.Counters(decor.SizeB1024(0), "% .1f / % .1f") } - counter = ColorMeta(counter, p.colors.Size) + counter = colorize(counter, p.colors.Size) + percent := decor.NewPercentage(" %.1f", decor.WCSyncSpace) - percent = ColorMeta(percent, p.colors.Percent) + percent = colorize(percent, p.colors.Percent) return p.newBar(msg, size, counter, percent) } +// newBar returns a new Bar. This function must only be called from +// inside the mutex. 
func (p *Progress) newBar(msg string, total int64, decorators ...decor.Decorator) *Bar { if p == nil { return nil } + select { + case <-p.ctx.Done(): + return nil + default: + } + + if p.pc == nil { + p.pcInit() + } + if total < 0 { total = 0 } - style := mpb.SpinnerStyle("∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙") - style = style.Meta(func(s string) string { - return p.colors.Spinner.Sprint(s) - }) + style := spinnerStyle(p.colors.Spinner) - bar := p.p.New(total, + bar := p.pc.New(total, style, mpb.BarWidth(barWidth), mpb.PrependDecorators( - ColorMeta(decor.Name(msg, decor.WCSyncWidthR), p.colors.Message), + colorize(decor.Name(msg, decor.WCSyncWidthR), p.colors.Message), ), mpb.AppendDecorators(decorators...), mpb.BarRemoveOnComplete(), @@ -212,6 +240,16 @@ func (p *Progress) newBar(msg string, total int64, decorators ...decor.Decorator return b } +func spinnerStyle(c *color.Color) mpb.SpinnerStyleComposer { + style := mpb.SpinnerStyle("∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙") + if c != nil { + style = style.Meta(func(s string) string { + return c.Sprint(s) + }) + } + return style +} + // Bar represents a single progress bar. The caller should invoke // [Bar.IncrBy] as necessary to increment the bar's progress. When // the bar is complete, the caller should invoke [Bar.Stop]. All @@ -237,19 +275,67 @@ func (b *Bar) Stop() { } b.bar.SetTotal(-1, true) + b.bar.Abort(true) b.bar.Wait() } -func renderDelay(d time.Duration) <-chan struct{} { +// renderDelay returns a channel that will be closed after d, +// or if ctx is done. 
+func renderDelay(ctx context.Context, d time.Duration) <-chan struct{} { ch := make(chan struct{}) - time.AfterFunc(d, func() { - close(ch) - }) + + go func() { + defer close(ch) + t := time.NewTimer(d) + defer t.Stop() + select { + case <-ctx.Done(): + case <-t.C: + } + }() return ch } -func ColorMeta(decorator decor.Decorator, c *color.Color) decor.Decorator { +func colorize(decorator decor.Decorator, c *color.Color) decor.Decorator { return decor.Meta(decorator, func(s string) string { return c.Sprint(s) }) } + +// DefaultColors returns the default colors used for the progress bars. +func DefaultColors() *Colors { + return &Colors{ + Message: color.New(color.Faint), + Spinner: color.New(color.FgGreen, color.Bold), + Size: color.New(color.Faint), + Percent: color.New(color.FgCyan, color.Faint), + } +} + +// Colors is the set of colors used for the progress bars. +type Colors struct { + Message *color.Color + Spinner *color.Color + Size *color.Color + Percent *color.Color +} + +// EnableColor enables or disables color for the progress bars. 
+func (c *Colors) EnableColor(enable bool) { + if c == nil { + return + } + + if enable { + c.Message.EnableColor() + c.Spinner.EnableColor() + c.Size.EnableColor() + c.Percent.EnableColor() + return + } + + c.Message.DisableColor() + c.Spinner.DisableColor() + c.Size.DisableColor() + c.Percent.EnableColor() +} diff --git a/libsq/source/files.go b/libsq/source/files.go index fdd0f691b..b0a8a278d 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -11,8 +11,6 @@ import ( "sync" "time" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/h2non/filetype" "github.com/h2non/filetype/matchers" "golang.org/x/sync/errgroup" @@ -21,6 +19,7 @@ import ( "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" From 03e80a39dc25a1563e5e177082dbf9113320bc5f Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 28 Nov 2023 21:14:00 -0700 Subject: [PATCH 028/195] dialing in progress pkg --- cli/cmd_xtest.go | 2 +- libsq/core/progress/progress.go | 91 ++++++++++++++++++++----------- libsq/core/progress/progressio.go | 4 +- 3 files changed, 63 insertions(+), 34 deletions(-) diff --git a/cli/cmd_xtest.go b/cli/cmd_xtest.go index 625711535..75048f96e 100644 --- a/cli/cmd_xtest.go +++ b/cli/cmd_xtest.go @@ -50,7 +50,7 @@ func execXTestMbp(cmd *cobra.Command, _ []string) error { func doBigRead2(ctx context.Context) error { pb := progress.FromContext(ctx) - spinner := pb.NewByteCounterSpinner("Ingest data test", -1) + spinner := pb.NewByteCounter("Ingest data test", -1) defer spinner.Stop() maxSleep := 100 * time.Millisecond diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index c7851a46a..20d9187c4 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -47,13 +47,16 @@ 
func FromContext(ctx context.Context) *Progress { } const ( - barWidth = 28 - boxWidth = 64 + barWidth = 28 + boxWidth = 64 + refreshRate = 150 * time.Millisecond ) // New returns a new Progress instance, which is a container for progress bars. -// The returned Progress instance is safe for concurrent use. The caller is +// The returned Progress instance is safe for concurrent use, and all of its +// public methods can be safely invoked on a nil Progress. The caller is // responsible for calling [Progress.Wait] on the returned Progress. +// Arg delay specifies a duration to wait before rendering the progress bar. // The Progress is lazily initialized, and thus the delay clock doesn't // start ticking until the first call to one of the Progress.NewX methods. func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors) *Progress { @@ -64,7 +67,7 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors colors = DefaultColors() } - prog := &Progress{ + p := &Progress{ ctx: ctx, mu: sync.Mutex{}, colors: colors, @@ -72,18 +75,20 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors cancelFn: cancelFn, } - prog.pcInit = func() { + p.pcInit = func() { opts := []mpb.ContainerOption{ mpb.WithOutput(out), mpb.WithWidth(boxWidth), + mpb.WithRefreshRate(refreshRate), + mpb.WithAutoRefresh(), // Needed for color in Windows, apparently } if delay > 0 { opts = append(opts, mpb.WithRenderDelay(renderDelay(ctx, delay))) } - prog.pc = mpb.NewWithContext(ctx, opts...) - prog.pcInit = nil + p.pc = mpb.NewWithContext(ctx, opts...) + p.pcInit = nil } - return prog + return p } // Progress represents a container that renders one or more progress bars. @@ -141,19 +146,19 @@ func (p *Progress) Wait() { // it will be called automatically when the Progress is shut down, but that // may be later than the actual conclusion of the spinner's work. 
// -// pbar := p.NewUnitCounter("Ingest records", "record") -// defer pbar.Stop() +// bar := p.NewUnitCounter("Ingest records", "rec") +// defer bar.Stop() // // for i := 0; i < 100; i++ { -// pbar.IncrBy(1) +// bar.IncrBy(1) // time.Sleep(100 * time.Millisecond) // } // // This produces output similar to: // -// Ingesting records ∙∙● 87 records +// Ingesting records ∙∙● 87 recs // -// Note that the unit arg is pluralized. +// Note that the unit arg is automatically pluralized. func (p *Progress) NewUnitCounter(msg, unit string) *Bar { if p == nil { return nil @@ -169,18 +174,20 @@ func (p *Progress) NewUnitCounter(msg, unit string) *Bar { } return s }) - decorator = colorize(decorator, p.colors.Size) - return p.newBar(msg, -1, decorator) + + style := spinnerStyle(p.colors.Filler) + + return p.newBar(msg, -1, style, decorator) } -// NewByteCounterSpinner returns a new spinner bar whose metric is the count +// NewByteCounter returns a new progress bar whose metric is the count // of bytes processed. If the size is unknown, set arg size to -1. The caller // is ultimately responsible for calling [Bar.Stop] on the returned Bar. // However, the returned Bar is also added to the Progress's cleanup list, // so it will be called automatically when the Progress is shut down, but that -// may be later than the actual conclusion of the spinner's work. -func (p *Progress) NewByteCounterSpinner(msg string, size int64) *Bar { +// may be later than the actual conclusion of the Bar's work. 
+func (p *Progress) NewByteCounter(msg string, size int64) *Bar { if p == nil { return nil } @@ -188,23 +195,28 @@ func (p *Progress) NewByteCounterSpinner(msg string, size int64) *Bar { p.mu.Lock() defer p.mu.Unlock() + var style mpb.BarFillerBuilder var counter decor.Decorator + var percent decor.Decorator if size < 0 { + style = spinnerStyle(p.colors.Filler) counter = decor.Current(decor.SizeB1024(0), "% .1f") } else { + style = barStyle(p.colors.Filler) counter = decor.Counters(decor.SizeB1024(0), "% .1f / % .1f") + percent = decor.NewPercentage(" %.1f", decor.WCSyncSpace) + percent = colorize(percent, p.colors.Percent) } counter = colorize(counter, p.colors.Size) - percent := decor.NewPercentage(" %.1f", decor.WCSyncSpace) - percent = colorize(percent, p.colors.Percent) - - return p.newBar(msg, size, counter, percent) + return p.newBar(msg, size, style, counter, percent) } // newBar returns a new Bar. This function must only be called from // inside the mutex. -func (p *Progress) newBar(msg string, total int64, decorators ...decor.Decorator) *Bar { +func (p *Progress) newBar(msg string, total int64, + style mpb.BarFillerBuilder, decorators ...decor.Decorator, +) *Bar { if p == nil { return nil } @@ -223,8 +235,6 @@ func (p *Progress) newBar(msg string, total int64, decorators ...decor.Decorator total = 0 } - style := spinnerStyle(p.colors.Spinner) - bar := p.pc.New(total, style, mpb.BarWidth(barWidth), @@ -241,7 +251,9 @@ func (p *Progress) newBar(msg string, total int64, decorators ...decor.Decorator } func spinnerStyle(c *color.Color) mpb.SpinnerStyleComposer { - style := mpb.SpinnerStyle("∙∙∙", "●∙∙", "∙●∙", "∙∙●", "∙∙∙") + // TODO: should use ascii chars? + frames := []string{"∙∙∙", "●∙∙", "●∙∙", "∙●∙", "∙●∙", "∙∙●", "∙∙●", "∙∙∙"} + style := mpb.SpinnerStyle(frames...) 
if c != nil { style = style.Meta(func(s string) string { return c.Sprint(s) @@ -250,6 +262,23 @@ func spinnerStyle(c *color.Color) mpb.SpinnerStyleComposer { return style } +func barStyle(c *color.Color) mpb.BarStyleComposer { + clr := func(s string) string { + return c.Sprint(s) + } + + frames := []string{"∙", "●", "●", "●", "∙"} + //frames := []string{"∙∙∙", "●∙∙", "●∙∙", "∙●∙", "∙●∙", "∙∙●", "∙∙●", "∙∙∙"} + + return mpb.BarStyle(). + Lbound(" ").Rbound(" "). + Filler("∙").FillerMeta(clr). + //Refiller("x").RefillerMeta(clr). + Padding(" "). + //Tip(`-`, `\`, `|`, `/`).TipMeta(clr). + Tip(frames...).TipMeta(clr) +} + // Bar represents a single progress bar. The caller should invoke // [Bar.IncrBy] as necessary to increment the bar's progress. When // the bar is complete, the caller should invoke [Bar.Stop]. All @@ -306,7 +335,7 @@ func colorize(decorator decor.Decorator, c *color.Color) decor.Decorator { func DefaultColors() *Colors { return &Colors{ Message: color.New(color.Faint), - Spinner: color.New(color.FgGreen, color.Bold), + Filler: color.New(color.FgGreen, color.Bold, color.Faint), Size: color.New(color.Faint), Percent: color.New(color.FgCyan, color.Faint), } @@ -315,7 +344,7 @@ func DefaultColors() *Colors { // Colors is the set of colors used for the progress bars. 
type Colors struct { Message *color.Color - Spinner *color.Color + Filler *color.Color Size *color.Color Percent *color.Color } @@ -328,14 +357,14 @@ func (c *Colors) EnableColor(enable bool) { if enable { c.Message.EnableColor() - c.Spinner.EnableColor() + c.Filler.EnableColor() c.Size.EnableColor() c.Percent.EnableColor() return } c.Message.DisableColor() - c.Spinner.DisableColor() + c.Filler.DisableColor() c.Size.DisableColor() - c.Percent.EnableColor() + c.Percent.DisableColor() } diff --git a/libsq/core/progress/progressio.go b/libsq/core/progress/progressio.go index a85332785..462ee823f 100644 --- a/libsq/core/progress/progressio.go +++ b/libsq/core/progress/progressio.go @@ -54,7 +54,7 @@ func NewWriter(ctx context.Context, msg string, size int64, w io.Writer) io.Writ return contextio.NewWriter(ctx, w) } - spinner := pb.NewByteCounterSpinner(msg, size) + spinner := pb.NewByteCounter(msg, size) return &progCopier{progWriter{ ctx: ctx, w: spinner.bar.ProxyWriter(w), @@ -135,7 +135,7 @@ func NewReader(ctx context.Context, msg string, size int64, r io.Reader) io.Read return contextio.NewReader(ctx, r) } - spinner := pb.NewByteCounterSpinner(msg, size) + spinner := pb.NewByteCounter(msg, size) pr := &progReader{ ctx: ctx, r: spinner.bar.ProxyReader(r), From 12a18b0a6a71835689a9b4ecb38da663860d88d3 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 30 Nov 2023 09:15:44 -0700 Subject: [PATCH 029/195] Switched to using fscache.MapFile mechanism --- go.mod | 2 +- libsq/source/files.go | 34 +++++++++++++++++----------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/go.mod b/go.mod index ccc4a9ab8..20692ed57 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/alessio/shellescape v1.4.2 github.com/antlr4-go/antlr/v4 v4.13.0 github.com/c2h5oh/datasize v0.0.0-20220606134207-859f65c6625b + github.com/dustin/go-humanize v1.0.1 github.com/ecnepsnai/osquery v1.0.1 github.com/emirpasic/gods v1.18.1 
github.com/fatih/color v1.16.0 @@ -59,7 +60,6 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/djherbis/atime v1.1.0 // indirect github.com/djherbis/stream v1.4.0 // indirect - github.com/dustin/go-humanize v1.0.1 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/huandu/xstrings v1.4.0 // indirect diff --git a/libsq/source/files.go b/libsq/source/files.go index b0a8a278d..10269cb54 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -143,29 +143,29 @@ func (fs *Files) addFile(ctx context.Context, f *os.File, key string) (fscache.R log := lg.FromContext(ctx) log.Debug("Adding file", lga.Key, key, lga.Path, f.Name()) - r, w, err := fs.fcache.Get(key) - if err != nil { - return nil, errz.Err(err) - } - if w == nil { - lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) - return nil, errz.Errorf("failed to add to fscache (possibly previously added): %s", key) + if key != StdinHandle { + if fs.fcache.Exists(key) { + return nil, errz.Errorf("file already exists in cache: %s", key) + } + + if err := fs.fcache.MapFile(f.Name()); err != nil { + return nil, errz.Wrapf(err, "failed to map file into fscache: %s", f.Name()) + } + + r, _, err := fs.fcache.Get(key) + return r, errz.Err(err) } - fi, err := f.Stat() + // Special handling for stdin + r, w, err := fs.fcache.Get(StdinHandle) if err != nil { return nil, errz.Err(err) } - size := fi.Size() - if size == 0 { - size = -1 + if w == nil { + lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) + return nil, errz.Errorf("fscache: no writer for %s", StdinHandle) } - // TODO: Problematically, we copy the entire contents of f into fscache. - // This is probably necessary for piped data on stdin, but for files - // that already exist on the file system, it would be nice if the cacheFile - // could be mapped directly to the filesystem file. 
This might require - // hacking on the fscache impl. copier := fscache.AsyncFiller{ Message: "Reading source data", Log: log.With(lga.Action, "Cache fill"), @@ -179,7 +179,7 @@ func (fs *Files) addFile(ctx context.Context, f *os.File, key string) (fscache.R } df := ioz.DelayReader(f, time.Millisecond, true) // FIXME: Delete - if err = copier.Copy(ctx, size, w, df); err != nil { + if err = copier.Copy(ctx, -1, w, df); err != nil { lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) return nil, errz.Err(err) } From 016d6493a011e642ef42d783efc552b44c60a059 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 2 Dec 2023 08:11:41 -0700 Subject: [PATCH 030/195] wip --- cli/logging.go | 4 +- cli/source.go | 2 +- drivers/csv/csv.go | 4 +- drivers/csv/detect_type.go | 7 + drivers/xlsx/detect.go | 26 +- go.mod | 1 + go.sum | 2 + libsq/core/ioz/ioz.go | 50 ++ libsq/core/lg/devlog/devlog.go | 55 +++ libsq/core/lg/devlog/tint/LICENSE | 21 + libsq/core/lg/devlog/tint/README.md | 88 ++++ libsq/core/lg/devlog/tint/buffer.go | 46 ++ libsq/core/lg/devlog/tint/handler.go | 454 +++++++++++++++++ libsq/core/lg/devlog/tint/handler_test.go | 572 ++++++++++++++++++++++ libsq/core/progress/progressio.go | 70 ++- libsq/source/files.go | 332 +++++++++++-- libsq/source/files_test.go | 6 +- 17 files changed, 1681 insertions(+), 59 deletions(-) create mode 100644 libsq/core/lg/devlog/devlog.go create mode 100644 libsq/core/lg/devlog/tint/LICENSE create mode 100644 libsq/core/lg/devlog/tint/README.md create mode 100644 libsq/core/lg/devlog/tint/buffer.go create mode 100644 libsq/core/lg/devlog/tint/handler.go create mode 100644 libsq/core/lg/devlog/tint/handler_test.go diff --git a/cli/logging.go b/cli/logging.go index 27d75fcbb..35e3221c1 100644 --- a/cli/logging.go +++ b/cli/logging.go @@ -2,6 +2,7 @@ package cli import ( "context" + "github.com/neilotoole/sq/libsq/core/lg/devlog" "io" "log/slog" "os" @@ -91,7 +92,8 @@ func defaultLogging(ctx context.Context, osArgs []string, cfg 
*config.Config, } closer = logFile.Close - h = newJSONHandler(logFile, lvl) + h = devlog.NewHandler(logFile, lvl) + //h = newJSONHandler(logFile, lvl) return slog.New(h), h, closer, nil } diff --git a/cli/source.go b/cli/source.go index fcb177d53..e96da1a4f 100644 --- a/cli/source.go +++ b/cli/source.go @@ -187,7 +187,7 @@ func checkStdinSource(ctx context.Context, ru *run.Run) (*source.Source, error) } if typ == drivertype.None { - typ, err = ru.Files.TypeStdin(ctx) + typ, err = ru.Files.DetectStdinType(ctx) if err != nil { return nil, err } diff --git a/drivers/csv/csv.go b/drivers/csv/csv.go index ae5c60c51..a720dd22e 100644 --- a/drivers/csv/csv.go +++ b/drivers/csv/csv.go @@ -79,7 +79,9 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er allowCache := driver.OptIngestCache.Get(options.FromContext(ctx)) ingestFn := func(ctx context.Context, destPool driver.Pool) error { - return ingestCSV(ctx, src, d.files.OpenFunc(src), destPool) + openFn := d.files.OpenFunc(src) + log.Debug("Ingest func invoked", lga.Src, src) + return ingestCSV(ctx, src, openFn, destPool) } backingPool, err := d.scratcher.OpenIngest(ctx, src, ingestFn, allowCache) diff --git a/drivers/csv/detect_type.go b/drivers/csv/detect_type.go index 7fd7dd29d..7e9db5171 100644 --- a/drivers/csv/detect_type.go +++ b/drivers/csv/detect_type.go @@ -4,7 +4,9 @@ import ( "context" "encoding/csv" "errors" + "github.com/neilotoole/sq/libsq/core/lg/lga" "io" + "time" "github.com/neilotoole/sq/cli/output/csvw" "github.com/neilotoole/sq/libsq/core/errz" @@ -73,6 +75,11 @@ const ( // isCSV returns a score indicating the confidence that cr is reading // legitimate CSV, where a score <= 0 is not CSV, a score >= 1 is definitely CSV. 
func isCSV(ctx context.Context, cr *csv.Reader) (score float32) { + start := time.Now() + lg.FromContext(ctx).Debug("isCSV invoked", lga.Timestamp, start) + defer func() { + lg.FromContext(ctx).Debug("isCSV complete", "elapsed", time.Since(start), "score", score) + }() const ( maxRecords int = 100 ) diff --git a/drivers/xlsx/detect.go b/drivers/xlsx/detect.go index b3fcfe03d..b9885bc9a 100644 --- a/drivers/xlsx/detect.go +++ b/drivers/xlsx/detect.go @@ -2,11 +2,12 @@ package xlsx import ( "context" + "errors" + "github.com/h2non/filetype" + "github.com/h2non/filetype/matchers" "io" "slices" - excelize "github.com/xuri/excelize/v2" - "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/kind" "github.com/neilotoole/sq/libsq/core/lg" @@ -23,6 +24,8 @@ var _ source.DriverDetectFunc = DetectXLSX func DetectXLSX(ctx context.Context, openFn source.FileOpenFunc) (detected drivertype.Type, score float32, err error, ) { + const detectBufSize = 4096 + log := lg.FromContext(ctx) var r io.ReadCloser r, err = openFn(ctx) @@ -31,14 +34,23 @@ func DetectXLSX(ctx context.Context, openFn source.FileOpenFunc) (detected drive } defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) - f, err := excelize.OpenReader(r) - if err != nil { - return drivertype.None, 0, nil + buf := make([]byte, detectBufSize) + if _, err = r.Read(buf); err != nil { + return drivertype.None, 0, errz.Err(err) } - defer lg.WarnIfCloseError(log, lgm.CloseFileReader, f) + t, err := filetype.Document(buf) + if err != nil && !errors.Is(err, filetype.ErrUnknownBuffer) { + return drivertype.None, 0, errz.Err(err) + } + + switch t { + case matchers.TypeXlsx, matchers.TypeXls: + return Type, 1.0, nil + default: + return drivertype.None, 0, nil + } - return Type, 1.0, nil } func detectHeaderRow(ctx context.Context, sheet *xSheet) (hasHeader bool, err error) { diff --git a/go.mod b/go.mod index 20692ed57..6b1717746 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( 
github.com/google/uuid v1.4.0 github.com/h2non/filetype v1.1.3 github.com/jackc/pgx/v5 v5.5.0 + github.com/lmittmann/tint v1.0.3 github.com/mattn/go-colorable v0.1.13 github.com/mattn/go-isatty v0.0.20 github.com/mattn/go-runewidth v0.0.15 diff --git a/go.sum b/go.sum index ecd3e4b48..0215ae65a 100644 --- a/go.sum +++ b/go.sum @@ -97,6 +97,8 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/lmittmann/tint v1.0.3 h1:W5PHeA2D8bBJVvabNfQD/XW9HPLZK1XoPZH0cq8NouQ= +github.com/lmittmann/tint v1.0.3/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 663a7b3dc..aa5d8f698 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -37,6 +37,56 @@ func Close(ctx context.Context, c io.Closer) { lg.WarnIfError(log, "Close", err) } +// CopyAsync asynchronously copies from r to w, invoking +// non-nil callback when done. +func CopyAsync(w io.Writer, r io.Reader, callback func(written int64, err error)) { + go func() { + written, err := io.Copy(w, r) + if callback != nil { + callback(written, err) + } + }() +} + +// CopyAsyncFull asynchronously copies from r to w, invoking callback when done. +// If arg close is true and w is an io.WriterCloser, w is closed on successful +// completion of the copy (but it is not closed if an error occurs during write). +// If callback is nil, it is ignored. 
+func CopyAsyncFull(w io.Writer, r io.Reader, close bool, callback func(written int64, err error)) { + go func() { + written, err := io.Copy(w, r) + if err != nil { + if callback != nil { + callback(written, err) + } + return + } + // err is nil + if !close { + if callback != nil { + callback(written, nil) + } + return + } + + // err is nil, and close is true + wc, ok := w.(io.WriteCloser) + if !ok { + // It's not a write closer... this is basically a programming + // error to have set close to true when w is not an io.WriteCloser, + // but we'll handle it generously. + if callback != nil { + callback(written, nil) + } + return + } + + // err is nil, close is true, and w is a WriteCloser + err = wc.Close() + callback(written, err) + }() +} + // PrintFile reads file from name and writes it to stdout. func PrintFile(name string) error { return FPrintFile(os.Stdout, name) diff --git a/libsq/core/lg/devlog/devlog.go b/libsq/core/lg/devlog/devlog.go new file mode 100644 index 000000000..800420cab --- /dev/null +++ b/libsq/core/lg/devlog/devlog.go @@ -0,0 +1,55 @@ +package devlog + +import ( + "github.com/neilotoole/sq/libsq/core/lg/devlog/tint" + "io" + "log/slog" + "path/filepath" + "strconv" + "strings" +) + +const shortTimeFormat = `15:04:05.000000` + +// New returns a developer-friendly logger that +// logs to w. +func New(w io.Writer, lvl slog.Leveler) *slog.Logger { + h := NewHandler(w, lvl) + return slog.New(h) +} + +func NewHandler(w io.Writer, lvl slog.Leveler) slog.Handler { + h := tint.NewHandler(w, &tint.Options{ + Level: lvl, + TimeFormat: shortTimeFormat, + AddSource: true, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + switch a.Key { + case "pid": + return slog.Attr{} + default: + return a + } + }, + }) + return h +} + +// replaceSourceShort prints a dev-friendly "source" field. 
+func replaceSourceShort(_ []string, a slog.Attr) slog.Attr { + if src, ok := a.Value.Any().(*slog.Source); ok { + s := filepath.Join(filepath.Base(filepath.Dir(src.File)), filepath.Base(src.File)) + s += ":" + strconv.Itoa(src.Line) + + fn := src.Function + parts := strings.Split(src.Function, "/") + if len(parts) > 0 { + fn = parts[len(parts)-1] + } + + s += ":" + fn + //a.Key = "src" + a.Value = slog.StringValue(s) + } + return a +} diff --git a/libsq/core/lg/devlog/tint/LICENSE b/libsq/core/lg/devlog/tint/LICENSE new file mode 100644 index 000000000..3f49589dc --- /dev/null +++ b/libsq/core/lg/devlog/tint/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 lmittmann + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/libsq/core/lg/devlog/tint/README.md b/libsq/core/lg/devlog/tint/README.md new file mode 100644 index 000000000..88a77acb4 --- /dev/null +++ b/libsq/core/lg/devlog/tint/README.md @@ -0,0 +1,88 @@ +# `tint`: 🌈 **slog.Handler** that writes tinted logs + +[![Go Reference](https://pkg.go.dev/badge/github.com/lmittmann/tint.svg)](https://pkg.go.dev/github.com/lmittmann/tint#section-documentation) +[![Go Report Card](https://goreportcard.com/badge/github.com/lmittmann/tint)](https://goreportcard.com/report/github.com/lmittmann/tint) + + + + + + +
+
+ +Package `tint` implements a zero-dependency [`slog.Handler`](https://pkg.go.dev/log/slog#Handler) +that writes tinted (colorized) logs. Its output format is inspired by the `zerolog.ConsoleWriter` and +[`slog.TextHandler`](https://pkg.go.dev/log/slog#TextHandler). + +The output format can be customized using [`Options`](https://pkg.go.dev/github.com/lmittmann/tint#Options) +which is a drop-in replacement for [`slog.HandlerOptions`](https://pkg.go.dev/log/slog#HandlerOptions). + +``` +go get github.com/lmittmann/tint +``` + +## Usage + +```go +w := os.Stderr + +// create a new logger +logger := slog.New(tint.NewHandler(w, nil)) + +// set global logger with custom options +slog.SetDefault(slog.New( + tint.NewHandler(w, &tint.Options{ + Level: slog.LevelDebug, + TimeFormat: time.Kitchen, + }), +)) +``` + +### Customize Attributes + +`ReplaceAttr` can be used to alter or drop attributes. If set, it is called on +each non-group attribute before it is logged. See [`slog.HandlerOptions`](https://pkg.go.dev/log/slog#HandlerOptions) +for details. + +```go +// create a new logger that doesn't write the time +w := os.Stderr +logger := slog.New( + tint.NewHandler(w, &tint.Options{ + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if a.Key == slog.TimeKey && len(groups) == 0 { + return slog.Attr{} + } + return a + }, + }), +) +``` + +### Automatically Enable Colors + +Colors are enabled by default and can be disabled using the `Options.NoColor` +attribute. To automatically enable colors based on the terminal capabilities, +use e.g. the [`go-isatty`](https://github.com/mattn/go-isatty) package. + +```go +w := os.Stderr +logger := slog.New( + tint.NewHandler(w, &tint.Options{ + NoColor: !isatty.IsTerminal(w.Fd()), + }), +) +``` + +### Windows Support + +Color support on Windows can be added by using e.g. the +[`go-colorable`](https://github.com/mattn/go-colorable) package. 
+ +```go +w := os.Stderr +logger := slog.New( + tint.NewHandler(colorable.NewColorable(w), nil), +) +``` diff --git a/libsq/core/lg/devlog/tint/buffer.go b/libsq/core/lg/devlog/tint/buffer.go new file mode 100644 index 000000000..4d7321a6c --- /dev/null +++ b/libsq/core/lg/devlog/tint/buffer.go @@ -0,0 +1,46 @@ +package tint + +import "sync" + +type buffer []byte + +var bufPool = sync.Pool{ + New: func() any { + b := make(buffer, 0, 1024) + return (*buffer)(&b) + }, +} + +func newBuffer() *buffer { + return bufPool.Get().(*buffer) +} + +func (b *buffer) Free() { + // To reduce peak allocation, return only smaller buffers to the pool. + const maxBufferSize = 16 << 10 + if cap(*b) <= maxBufferSize { + *b = (*b)[:0] + bufPool.Put(b) + } +} +func (b *buffer) Write(bytes []byte) (int, error) { + *b = append(*b, bytes...) + return len(bytes), nil +} + +func (b *buffer) WriteByte(char byte) error { + *b = append(*b, char) + return nil +} + +func (b *buffer) WriteString(str string) (int, error) { + *b = append(*b, str...) + return len(str), nil +} + +func (b *buffer) WriteStringIf(ok bool, str string) (int, error) { + if !ok { + return 0, nil + } + return b.WriteString(str) +} diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go new file mode 100644 index 000000000..89ad8ad00 --- /dev/null +++ b/libsq/core/lg/devlog/tint/handler.go @@ -0,0 +1,454 @@ +/* +Package tint implements a zero-dependency [slog.Handler] that writes tinted +(colorized) logs. The output format is inspired by the [zerolog.ConsoleWriter] +and [slog.TextHandler]. + +The output format can be customized using [Options], which is a drop-in +replacement for [slog.HandlerOptions]. + +# Customize Attributes + +Options.ReplaceAttr can be used to alter or drop attributes. If set, it is +called on each non-group attribute before it is logged. +See [slog.HandlerOptions] for details. 
+ + w := os.Stderr + logger := slog.New( + tint.NewHandler(w, &tint.Options{ + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if a.Key == slog.TimeKey && len(groups) == 0 { + return slog.Attr{} + } + return a + }, + }), + ) + +# Automatically Enable Colors + +Colors are enabled by default and can be disabled using the Options.NoColor +attribute. To automatically enable colors based on the terminal capabilities, +use e.g. the [go-isatty] package. + + w := os.Stderr + logger := slog.New( + tint.NewHandler(w, &tint.Options{ + NoColor: !isatty.IsTerminal(w.Fd()), + }), + ) + +# Windows Support + +Color support on Windows can be added by using e.g. the [go-colorable] package. + + w := os.Stderr + logger := slog.New( + tint.NewHandler(colorable.NewColorable(w), nil), + ) + +[zerolog.ConsoleWriter]: https://pkg.go.dev/github.com/rs/zerolog#ConsoleWriter +[go-isatty]: https://pkg.go.dev/github.com/mattn/go-isatty +[go-colorable]: https://pkg.go.dev/github.com/mattn/go-colorable +*/ +package tint + +import ( + "context" + "encoding" + "fmt" + "io" + "log/slog" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "time" + "unicode" +) + +// ANSI modes +// See: https://gist.github.com/JBlond/2fea43a3049b38287e5e9cefc87b2124 +const ( + ansiReset = "\033[0m" + ansiFaint = "\033[2m" + ansiResetFaint = "\033[22m" + ansiBrightRed = "\033[91m" + ansiBrightGreen = "\033[92m" + ansiBrightYellow = "\033[93m" + ansiBlue = "\033[34m" + ansiBrightRedFaint = "\033[91;2m" +) + +const errKey = "err" + +var ( + defaultLevel = slog.LevelInfo + defaultTimeFormat = time.StampMilli +) + +// Options for a slog.Handler that writes tinted logs. A zero Options consists +// entirely of default values. +// +// Options can be used as a drop-in replacement for [slog.HandlerOptions]. 
+type Options struct { + // Enable source code location (Default: false) + AddSource bool + + // Minimum level to log (Default: slog.LevelInfo) + Level slog.Leveler + + // ReplaceAttr is called to rewrite each non-group attribute before it is logged. + // See https://pkg.go.dev/log/slog#HandlerOptions for details. + ReplaceAttr func(groups []string, attr slog.Attr) slog.Attr + + // Time format (Default: time.StampMilli) + TimeFormat string + + // Disable color (Default: false) + NoColor bool +} + +// NewHandler creates a [slog.Handler] that writes tinted logs to Writer w, +// using the default options. If opts is nil, the default options are used. +func NewHandler(w io.Writer, opts *Options) slog.Handler { + h := &handler{ + w: w, + level: defaultLevel, + timeFormat: defaultTimeFormat, + } + if opts == nil { + return h + } + + h.addSource = opts.AddSource + if opts.Level != nil { + h.level = opts.Level + } + h.replaceAttr = opts.ReplaceAttr + if opts.TimeFormat != "" { + h.timeFormat = opts.TimeFormat + } + h.noColor = opts.NoColor + return h +} + +// handler implements a [slog.Handler]. 
+type handler struct { + attrsPrefix string + groupPrefix string + groups []string + + mu sync.Mutex + w io.Writer + + addSource bool + level slog.Leveler + replaceAttr func([]string, slog.Attr) slog.Attr + timeFormat string + noColor bool +} + +func (h *handler) clone() *handler { + return &handler{ + attrsPrefix: h.attrsPrefix, + groupPrefix: h.groupPrefix, + groups: h.groups, + w: h.w, + addSource: h.addSource, + level: h.level, + replaceAttr: h.replaceAttr, + timeFormat: h.timeFormat, + noColor: h.noColor, + } +} + +func (h *handler) Enabled(_ context.Context, level slog.Level) bool { + return level >= h.level.Level() +} + +func (h *handler) Handle(_ context.Context, r slog.Record) error { + // get a buffer from the sync pool + buf := newBuffer() + defer buf.Free() + + rep := h.replaceAttr + + // write time + if !r.Time.IsZero() { + val := r.Time.Round(0) // strip monotonic to match Attr behavior + if rep == nil { + h.appendTime(buf, r.Time) + buf.WriteByte(' ') + } else if a := rep(nil /* groups */, slog.Time(slog.TimeKey, val)); a.Key != "" { + if a.Value.Kind() == slog.KindTime { + h.appendTime(buf, a.Value.Time()) + } else { + h.appendValue(buf, a.Value, false) + } + buf.WriteByte(' ') + } + } + + // write level + if rep == nil { + h.appendLevel(buf, r.Level) + buf.WriteByte(' ') + } else if a := rep(nil /* groups */, slog.Any(slog.LevelKey, r.Level)); a.Key != "" { + h.appendValue(buf, a.Value, false) + buf.WriteByte(' ') + } + + // write source + if h.addSource { + fs := runtime.CallersFrames([]uintptr{r.PC}) + f, _ := fs.Next() + if f.File != "" { + src := &slog.Source{ + Function: f.Function, + File: f.File, + Line: f.Line, + } + + if rep == nil { + h.appendSource(buf, src) + buf.WriteByte(' ') + } else if a := rep(nil /* groups */, slog.Any(slog.SourceKey, src)); a.Key != "" { + h.appendValue(buf, a.Value, false) + buf.WriteByte(' ') + } + } + } + + // write message + if rep == nil { + buf.WriteString(r.Message) + buf.WriteByte(' ') + } else if a := 
rep(nil /* groups */, slog.String(slog.MessageKey, r.Message)); a.Key != "" { + h.appendValue(buf, a.Value, false) + buf.WriteByte(' ') + } + + // write handler attributes + if len(h.attrsPrefix) > 0 { + buf.WriteString(h.attrsPrefix) + } + + // write attributes + r.Attrs(func(attr slog.Attr) bool { + h.appendAttr(buf, attr, h.groupPrefix, h.groups) + return true + }) + + if len(*buf) == 0 { + return nil + } + (*buf)[len(*buf)-1] = '\n' // replace last space with newline + + h.mu.Lock() + defer h.mu.Unlock() + + _, err := h.w.Write(*buf) + return err +} + +func (h *handler) WithAttrs(attrs []slog.Attr) slog.Handler { + if len(attrs) == 0 { + return h + } + h2 := h.clone() + + buf := newBuffer() + defer buf.Free() + + // write attributes to buffer + for _, attr := range attrs { + h.appendAttr(buf, attr, h.groupPrefix, h.groups) + } + h2.attrsPrefix = h.attrsPrefix + string(*buf) + return h2 +} + +func (h *handler) WithGroup(name string) slog.Handler { + if name == "" { + return h + } + h2 := h.clone() + h2.groupPrefix += name + "." 
+ h2.groups = append(h2.groups, name) + return h2 +} + +func (h *handler) appendTime(buf *buffer, t time.Time) { + buf.WriteStringIf(!h.noColor, ansiFaint) + *buf = t.AppendFormat(*buf, h.timeFormat) + buf.WriteStringIf(!h.noColor, ansiReset) +} + +func (h *handler) appendLevel(buf *buffer, level slog.Level) { + switch { + case level < slog.LevelInfo: + buf.WriteString("DBG") + appendLevelDelta(buf, level-slog.LevelDebug) + case level < slog.LevelWarn: + buf.WriteStringIf(!h.noColor, ansiBrightGreen) + buf.WriteString("INF") + appendLevelDelta(buf, level-slog.LevelInfo) + buf.WriteStringIf(!h.noColor, ansiReset) + case level < slog.LevelError: + buf.WriteStringIf(!h.noColor, ansiBrightYellow) + buf.WriteString("WRN") + appendLevelDelta(buf, level-slog.LevelWarn) + buf.WriteStringIf(!h.noColor, ansiReset) + default: + buf.WriteStringIf(!h.noColor, ansiBrightRed) + buf.WriteString("ERR") + appendLevelDelta(buf, level-slog.LevelError) + buf.WriteStringIf(!h.noColor, ansiReset) + } +} + +func appendLevelDelta(buf *buffer, delta slog.Level) { + if delta == 0 { + return + } else if delta > 0 { + buf.WriteByte('+') + } + *buf = strconv.AppendInt(*buf, int64(delta), 10) +} + +func (h *handler) appendSource(buf *buffer, src *slog.Source) { + dir, file := filepath.Split(src.File) + + fn := src.Function + parts := strings.Split(src.Function, "/") + if len(parts) > 0 { + fn = parts[len(parts)-1] + } + + if fn != "" { + buf.WriteStringIf(!h.noColor, ansiBlue) + buf.WriteString(fn) + buf.WriteStringIf(!h.noColor, ansiReset) + buf.WriteByte(' ') + } + buf.WriteStringIf(!h.noColor, ansiFaint) + buf.WriteString(filepath.Join(filepath.Base(dir), file)) + buf.WriteByte(':') + buf.WriteString(strconv.Itoa(src.Line)) + buf.WriteStringIf(!h.noColor, ansiReset) + +} + +func (h *handler) appendAttr(buf *buffer, attr slog.Attr, groupsPrefix string, groups []string) { + attr.Value = attr.Value.Resolve() + if rep := h.replaceAttr; rep != nil && attr.Value.Kind() != slog.KindGroup { + attr = 
rep(groups, attr) + attr.Value = attr.Value.Resolve() + } + + if attr.Equal(slog.Attr{}) { + return + } + + if attr.Value.Kind() == slog.KindGroup { + if attr.Key != "" { + groupsPrefix += attr.Key + "." + groups = append(groups, attr.Key) + } + for _, groupAttr := range attr.Value.Group() { + h.appendAttr(buf, groupAttr, groupsPrefix, groups) + } + } else if err, ok := attr.Value.Any().(tintError); ok { + // append tintError + h.appendTintError(buf, err, groupsPrefix) + buf.WriteByte(' ') + } else { + h.appendKey(buf, attr.Key, groupsPrefix) + h.appendValue(buf, attr.Value, true) + buf.WriteByte(' ') + } +} + +func (h *handler) appendKey(buf *buffer, key, groups string) { + buf.WriteStringIf(!h.noColor, ansiFaint) + appendString(buf, groups+key, true) + buf.WriteByte('=') + buf.WriteStringIf(!h.noColor, ansiReset) +} + +func (h *handler) appendValue(buf *buffer, v slog.Value, quote bool) { + switch v.Kind() { + case slog.KindString: + appendString(buf, v.String(), quote) + case slog.KindInt64: + *buf = strconv.AppendInt(*buf, v.Int64(), 10) + case slog.KindUint64: + *buf = strconv.AppendUint(*buf, v.Uint64(), 10) + case slog.KindFloat64: + *buf = strconv.AppendFloat(*buf, v.Float64(), 'g', -1, 64) + case slog.KindBool: + *buf = strconv.AppendBool(*buf, v.Bool()) + case slog.KindDuration: + appendString(buf, v.Duration().String(), quote) + case slog.KindTime: + appendString(buf, v.Time().String(), quote) + case slog.KindAny: + switch cv := v.Any().(type) { + case slog.Level: + h.appendLevel(buf, cv) + case encoding.TextMarshaler: + data, err := cv.MarshalText() + if err != nil { + break + } + appendString(buf, string(data), quote) + case *slog.Source: + h.appendSource(buf, cv) + default: + appendString(buf, fmt.Sprint(v.Any()), quote) + } + } +} + +func (h *handler) appendTintError(buf *buffer, err error, groupsPrefix string) { + buf.WriteStringIf(!h.noColor, ansiBrightRedFaint) + appendString(buf, groupsPrefix+errKey, true) + buf.WriteByte('=') + 
buf.WriteStringIf(!h.noColor, ansiResetFaint) + appendString(buf, err.Error(), true) + buf.WriteStringIf(!h.noColor, ansiReset) +} + +func appendString(buf *buffer, s string, quote bool) { + if quote && needsQuoting(s) { + *buf = strconv.AppendQuote(*buf, s) + } else { + buf.WriteString(s) + } +} + +func needsQuoting(s string) bool { + if len(s) == 0 { + return true + } + for _, r := range s { + if unicode.IsSpace(r) || r == '"' || r == '=' || !unicode.IsPrint(r) { + return true + } + } + return false +} + +type tintError struct{ error } + +// Err returns a tinted (colorized) [slog.Attr] that will be written in red color +// by the [tint.Handler]. When used with any other [slog.Handler], it behaves as +// +// slog.Any("err", err) +func Err(err error) slog.Attr { + if err != nil { + err = tintError{err} + } + return slog.Any(errKey, err) +} diff --git a/libsq/core/lg/devlog/tint/handler_test.go b/libsq/core/lg/devlog/tint/handler_test.go new file mode 100644 index 000000000..5000e4a1c --- /dev/null +++ b/libsq/core/lg/devlog/tint/handler_test.go @@ -0,0 +1,572 @@ +package tint_test + +import ( + "bytes" + "context" + "errors" + "io" + "log/slog" + "os" + "slices" + "strconv" + "strings" + "testing" + "time" + + "github.com/lmittmann/tint" +) + +var faketime = time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) + +func Example() { + slog.SetDefault(slog.New(tint.NewHandler(os.Stderr, &tint.Options{ + Level: slog.LevelDebug, + TimeFormat: time.Kitchen, + }))) + + slog.Info("Starting server", "addr", ":8080", "env", "production") + slog.Debug("Connected to DB", "db", "myapp", "host", "localhost:5432") + slog.Warn("Slow request", "method", "GET", "path", "/users", "duration", 497*time.Millisecond) + slog.Error("DB connection lost", tint.Err(errors.New("connection reset")), "db", "myapp") + // Output: +} + +// Run test with "faketime" tag: +// +// TZ="" go test -tags=faketime +func TestHandler(t *testing.T) { + if !faketime.Equal(time.Now()) { + t.Skip(`skipping 
test; run with "-tags=faketime"`) + } + + tests := []struct { + Opts *tint.Options + F func(l *slog.Logger) + Want string + }{ + { + F: func(l *slog.Logger) { + l.Info("test", "key", "val") + }, + Want: `Nov 10 23:00:00.000 INF test key=val`, + }, + { + F: func(l *slog.Logger) { + l.Error("test", tint.Err(errors.New("fail"))) + }, + Want: `Nov 10 23:00:00.000 ERR test err=fail`, + }, + { + F: func(l *slog.Logger) { + l.Info("test", slog.Group("group", slog.String("key", "val"), tint.Err(errors.New("fail")))) + }, + Want: `Nov 10 23:00:00.000 INF test group.key=val group.err=fail`, + }, + { + F: func(l *slog.Logger) { + l.WithGroup("group").Info("test", "key", "val") + }, + Want: `Nov 10 23:00:00.000 INF test group.key=val`, + }, + { + F: func(l *slog.Logger) { + l.With("key", "val").Info("test", "key2", "val2") + }, + Want: `Nov 10 23:00:00.000 INF test key=val key2=val2`, + }, + { + F: func(l *slog.Logger) { + l.Info("test", "k e y", "v a l") + }, + Want: `Nov 10 23:00:00.000 INF test "k e y"="v a l"`, + }, + { + F: func(l *slog.Logger) { + l.WithGroup("g r o u p").Info("test", "key", "val") + }, + Want: `Nov 10 23:00:00.000 INF test "g r o u p.key"=val`, + }, + { + F: func(l *slog.Logger) { + l.Info("test", "slice", []string{"a", "b", "c"}, "map", map[string]int{"a": 1, "b": 2, "c": 3}) + }, + Want: `Nov 10 23:00:00.000 INF test slice="[a b c]" map="map[a:1 b:2 c:3]"`, + }, + { + Opts: &tint.Options{ + AddSource: true, + }, + F: func(l *slog.Logger) { + l.Info("test", "key", "val") + }, + Want: `Nov 10 23:00:00.000 INF tint/handler_test.go:100 test key=val`, + }, + { + Opts: &tint.Options{ + TimeFormat: time.Kitchen, + }, + F: func(l *slog.Logger) { + l.Info("test", "key", "val") + }, + Want: `11:00PM INF test key=val`, + }, + { + Opts: &tint.Options{ + ReplaceAttr: drop(slog.TimeKey), + }, + F: func(l *slog.Logger) { + l.Info("test", "key", "val") + }, + Want: `INF test key=val`, + }, + { + Opts: &tint.Options{ + ReplaceAttr: drop(slog.LevelKey), + }, + F: 
func(l *slog.Logger) { + l.Info("test", "key", "val") + }, + Want: `Nov 10 23:00:00.000 test key=val`, + }, + { + Opts: &tint.Options{ + ReplaceAttr: drop(slog.MessageKey), + }, + F: func(l *slog.Logger) { + l.Info("test", "key", "val") + }, + Want: `Nov 10 23:00:00.000 INF key=val`, + }, + { + Opts: &tint.Options{ + ReplaceAttr: drop(slog.TimeKey, slog.LevelKey, slog.MessageKey), + }, + F: func(l *slog.Logger) { + l.Info("test", "key", "val") + }, + Want: `key=val`, + }, + { + Opts: &tint.Options{ + ReplaceAttr: drop("key"), + }, + F: func(l *slog.Logger) { + l.Info("test", "key", "val") + }, + Want: `Nov 10 23:00:00.000 INF test`, + }, + { + Opts: &tint.Options{ + ReplaceAttr: drop("key"), + }, + F: func(l *slog.Logger) { + l.WithGroup("group").Info("test", "key", "val", "key2", "val2") + }, + Want: `Nov 10 23:00:00.000 INF test group.key=val group.key2=val2`, + }, + { + Opts: &tint.Options{ + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if a.Key == "key" && len(groups) == 1 && groups[0] == "group" { + return slog.Attr{} + } + return a + }, + }, + F: func(l *slog.Logger) { + l.WithGroup("group").Info("test", "key", "val", "key2", "val2") + }, + Want: `Nov 10 23:00:00.000 INF test group.key2=val2`, + }, + { + Opts: &tint.Options{ + ReplaceAttr: replace(slog.IntValue(42), slog.TimeKey), + }, + F: func(l *slog.Logger) { + l.Info("test", "key", "val") + }, + Want: `42 INF test key=val`, + }, + { + Opts: &tint.Options{ + ReplaceAttr: replace(slog.StringValue("INFO"), slog.LevelKey), + }, + F: func(l *slog.Logger) { + l.Info("test", "key", "val") + }, + Want: `Nov 10 23:00:00.000 INFO test key=val`, + }, + { + Opts: &tint.Options{ + ReplaceAttr: replace(slog.IntValue(42), slog.MessageKey), + }, + F: func(l *slog.Logger) { + l.Info("test", "key", "val") + }, + Want: `Nov 10 23:00:00.000 INF 42 key=val`, + }, + { + Opts: &tint.Options{ + ReplaceAttr: replace(slog.IntValue(42), "key"), + }, + F: func(l *slog.Logger) { + l.With("key", "val").Info("test", 
"key2", "val2") + }, + Want: `Nov 10 23:00:00.000 INF test key=42 key2=val2`, + }, + { + Opts: &tint.Options{ + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + return slog.Attr{} + }, + }, + F: func(l *slog.Logger) { + l.Info("test", "key", "val") + }, + Want: ``, + }, + { + F: func(l *slog.Logger) { + l.Info("test", "key", "") + }, + Want: `Nov 10 23:00:00.000 INF test key=""`, + }, + { + F: func(l *slog.Logger) { + l.Info("test", "", "val") + }, + Want: `Nov 10 23:00:00.000 INF test ""=val`, + }, + { + F: func(l *slog.Logger) { + l.Info("test", "", "") + }, + Want: `Nov 10 23:00:00.000 INF test ""=""`, + }, + + { // https://github.com/lmittmann/tint/issues/8 + F: func(l *slog.Logger) { + l.Log(context.TODO(), slog.LevelInfo+1, "test") + }, + Want: `Nov 10 23:00:00.000 INF+1 test`, + }, + { + Opts: &tint.Options{ + Level: slog.LevelDebug - 1, + }, + F: func(l *slog.Logger) { + l.Log(context.TODO(), slog.LevelDebug-1, "test") + }, + Want: `Nov 10 23:00:00.000 DBG-1 test`, + }, + { // https://github.com/lmittmann/tint/issues/12 + F: func(l *slog.Logger) { + l.Error("test", slog.Any("error", errors.New("fail"))) + }, + Want: `Nov 10 23:00:00.000 ERR test error=fail`, + }, + { // https://github.com/lmittmann/tint/issues/15 + F: func(l *slog.Logger) { + l.Error("test", tint.Err(nil)) + }, + Want: `Nov 10 23:00:00.000 ERR test err=`, + }, + { // https://github.com/lmittmann/tint/pull/26 + Opts: &tint.Options{ + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if a.Key == slog.TimeKey && len(groups) == 0 { + return slog.Time(slog.TimeKey, a.Value.Time().Add(24*time.Hour)) + } + return a + }, + }, + F: func(l *slog.Logger) { + l.Error("test") + }, + Want: `Nov 11 23:00:00.000 ERR test`, + }, + { // https://github.com/lmittmann/tint/pull/27 + F: func(l *slog.Logger) { + l.Info("test", "a", "b", slog.Group("", slog.String("c", "d")), "e", "f") + }, + Want: `Nov 10 23:00:00.000 INF test a=b c=d e=f`, + }, + { // 
https://github.com/lmittmann/tint/pull/30 + // drop built-in attributes in a grouped log + Opts: &tint.Options{ + ReplaceAttr: drop(slog.TimeKey, slog.LevelKey, slog.MessageKey, slog.SourceKey), + AddSource: true, + }, + F: func(l *slog.Logger) { + l.WithGroup("group").Info("test", "key", "val") + }, + Want: `group.key=val`, + }, + { // https://github.com/lmittmann/tint/issues/36 + Opts: &tint.Options{ + ReplaceAttr: func(g []string, a slog.Attr) slog.Attr { + if len(g) == 0 && a.Key == slog.LevelKey { + _ = a.Value.Any().(slog.Level) + } + return a + }, + }, + F: func(l *slog.Logger) { + l.Info("test") + }, + Want: `Nov 10 23:00:00.000 INF test`, + }, + { // https://github.com/lmittmann/tint/issues/37 + Opts: &tint.Options{ + AddSource: true, + ReplaceAttr: func(g []string, a slog.Attr) slog.Attr { + return a + }, + }, + F: func(l *slog.Logger) { + l.Info("test") + }, + Want: `Nov 10 23:00:00.000 INF tint/handler_test.go:327 test`, + }, + { // https://github.com/lmittmann/tint/issues/44 + F: func(l *slog.Logger) { + l = l.WithGroup("group") + l.Error("test", tint.Err(errTest)) + }, + Want: `Nov 10 23:00:00.000 ERR test group.err=fail`, + }, + } + + for i, test := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + var buf bytes.Buffer + if test.Opts == nil { + test.Opts = &tint.Options{} + } + test.Opts.NoColor = true + l := slog.New(tint.NewHandler(&buf, test.Opts)) + test.F(l) + + got := strings.TrimRight(buf.String(), "\n") + if test.Want != got { + t.Fatalf("(-want +got)\n- %s\n+ %s", test.Want, got) + } + }) + } +} + +// drop returns a ReplaceAttr that drops the given keys. 
+func drop(keys ...string) func([]string, slog.Attr) slog.Attr { + return func(groups []string, a slog.Attr) slog.Attr { + if len(groups) > 0 { + return a + } + + for _, key := range keys { + if a.Key == key { + a = slog.Attr{} + } + } + return a + } +} + +func replace(new slog.Value, keys ...string) func([]string, slog.Attr) slog.Attr { + return func(groups []string, a slog.Attr) slog.Attr { + if len(groups) > 0 { + return a + } + + for _, key := range keys { + if a.Key == key { + a.Value = new + } + } + return a + } +} + +func TestReplaceAttr(t *testing.T) { + tests := [][]any{ + {}, + {"key", "val"}, + {"key", "val", slog.Group("group", "key2", "val2")}, + {"key", "val", slog.Group("group", "key2", "val2", slog.Group("group2", "key3", "val3"))}, + } + + type replaceAttrParams struct { + Groups []string + Attr slog.Attr + } + + replaceAttrRecorder := func(record *[]replaceAttrParams) func([]string, slog.Attr) slog.Attr { + return func(groups []string, a slog.Attr) slog.Attr { + *record = append(*record, replaceAttrParams{groups, a}) + return a + } + } + + for i, test := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + slogRecord := make([]replaceAttrParams, 0) + slogLogger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{ + ReplaceAttr: replaceAttrRecorder(&slogRecord), + })) + slogLogger.Log(context.TODO(), slog.LevelInfo, "", test...) + + tintRecord := make([]replaceAttrParams, 0) + tintLogger := slog.New(tint.NewHandler(io.Discard, &tint.Options{ + ReplaceAttr: replaceAttrRecorder(&tintRecord), + })) + tintLogger.Log(context.TODO(), slog.LevelInfo, "", test...) + + if !slices.EqualFunc(slogRecord, tintRecord, func(a, b replaceAttrParams) bool { + return slices.Equal(a.Groups, b.Groups) && a.Attr.Equal(b.Attr) + }) { + t.Fatalf("(-want +got)\n- %v\n+ %v", slogRecord, tintRecord) + } + }) + } +} + +// See https://github.com/golang/exp/blob/master/slog/benchmarks/benchmarks_test.go#L25 +// +// Run e.g.: +// +// go test -bench=. 
-count=10 | benchstat -col /h /dev/stdin +func BenchmarkLogAttrs(b *testing.B) { + handler := []struct { + Name string + H slog.Handler + }{ + {"tint", tint.NewHandler(io.Discard, nil)}, + {"text", slog.NewTextHandler(io.Discard, nil)}, + {"json", slog.NewJSONHandler(io.Discard, nil)}, + {"discard", new(discarder)}, + } + + benchmarks := []struct { + Name string + F func(*slog.Logger) + }{ + { + "5 args", + func(logger *slog.Logger) { + logger.LogAttrs(context.TODO(), slog.LevelInfo, testMessage, + slog.String("string", testString), + slog.Int("status", testInt), + slog.Duration("duration", testDuration), + slog.Time("time", testTime), + slog.Any("error", errTest), + ) + }, + }, + { + "5 args custom level", + func(logger *slog.Logger) { + logger.LogAttrs(context.TODO(), slog.LevelInfo+1, testMessage, + slog.String("string", testString), + slog.Int("status", testInt), + slog.Duration("duration", testDuration), + slog.Time("time", testTime), + slog.Any("error", errTest), + ) + }, + }, + { + "10 args", + func(logger *slog.Logger) { + logger.LogAttrs(context.TODO(), slog.LevelInfo, testMessage, + slog.String("string", testString), + slog.Int("status", testInt), + slog.Duration("duration", testDuration), + slog.Time("time", testTime), + slog.Any("error", errTest), + slog.String("string", testString), + slog.Int("status", testInt), + slog.Duration("duration", testDuration), + slog.Time("time", testTime), + slog.Any("error", errTest), + ) + }, + }, + { + "40 args", + func(logger *slog.Logger) { + logger.LogAttrs(context.TODO(), slog.LevelInfo, testMessage, + slog.String("string", testString), + slog.Int("status", testInt), + slog.Duration("duration", testDuration), + slog.Time("time", testTime), + slog.Any("error", errTest), + slog.String("string", testString), + slog.Int("status", testInt), + slog.Duration("duration", testDuration), + slog.Time("time", testTime), + slog.Any("error", errTest), + slog.String("string", testString), + slog.Int("status", testInt), + 
slog.Duration("duration", testDuration), + slog.Time("time", testTime), + slog.Any("error", errTest), + slog.String("string", testString), + slog.Int("status", testInt), + slog.Duration("duration", testDuration), + slog.Time("time", testTime), + slog.Any("error", errTest), + slog.String("string", testString), + slog.Int("status", testInt), + slog.Duration("duration", testDuration), + slog.Time("time", testTime), + slog.Any("error", errTest), + slog.String("string", testString), + slog.Int("status", testInt), + slog.Duration("duration", testDuration), + slog.Time("time", testTime), + slog.Any("error", errTest), + slog.String("string", testString), + slog.Int("status", testInt), + slog.Duration("duration", testDuration), + slog.Time("time", testTime), + slog.Any("error", errTest), + slog.String("string", testString), + slog.Int("status", testInt), + slog.Duration("duration", testDuration), + slog.Time("time", testTime), + slog.Any("error", errTest), + ) + }, + }, + } + + for _, h := range handler { + b.Run("h="+h.Name, func(b *testing.B) { + for _, bench := range benchmarks { + b.Run(bench.Name, func(b *testing.B) { + b.ReportAllocs() + logger := slog.New(h.H) + for i := 0; i < b.N; i++ { + bench.F(logger) + } + }) + } + }) + } +} + +// discarder is a slog.Handler that discards all records. +type discarder struct{} + +func (*discarder) Enabled(context.Context, slog.Level) bool { return true } +func (*discarder) Handle(context.Context, slog.Record) error { return nil } +func (d *discarder) WithAttrs(attrs []slog.Attr) slog.Handler { return d } +func (d *discarder) WithGroup(name string) slog.Handler { return d } + +var ( + testMessage = "Test logging, but use a somewhat realistic message length." 
+ testTime = time.Date(2022, time.May, 1, 0, 0, 0, 0, time.UTC) + testString = "7e3b3b2aaeff56a7108fe11e154200dd/7819479873059528190" + testInt = 32768 + testDuration = 23 * time.Second + errTest = errors.New("fail") +) diff --git a/libsq/core/progress/progressio.go b/libsq/core/progress/progressio.go index 462ee823f..39cc65103 100644 --- a/libsq/core/progress/progressio.go +++ b/libsq/core/progress/progressio.go @@ -27,7 +27,7 @@ import ( "github.com/neilotoole/sq/libsq/core/ioz/contextio" ) -// NewWriter returns an [io.Writer] that wraps w, is context-aware, and +// NewWriter returns a [progress.Writer] that wraps w, is context-aware, and // generates a progress bar as bytes are written to w. It is expected that ctx // contains a *progress.Progress, as returned by progress.FromContext. If not, // this function delegates to contextio.NewWriter: the returned writer will @@ -35,23 +35,30 @@ import ( // // Context state is checked BEFORE every Write. // -// The returned [io.Writer] implements [io.ReaderFrom] to allow [io.Copy] to select -// the best strategy while still checking the context state before every chunk transfer. +// The returned [progress.Writer] implements [io.ReaderFrom] to allow [io.Copy] +// to select the best strategy while still checking the context state before +// every chunk transfer. // -// The returned [io.Writer] also implements [io.Closer], even if the underlying -// writer does not. This is necessary because we need a means of stopping the -// progress bar when writing is complete. If the underlying writer does -// implement [io.Closer], it will be closed when the returned writer is closed. +// The returned [progress.Writer] also implements [io.Closer], even if the +// underlying writer does not. This is necessary because we need a means of +// stopping the progress bar when writing is complete. If the underlying writer +// does implement [io.Closer], it will be closed when the returned writer is +// closed. 
// -// If size is unknown, set to -1. -func NewWriter(ctx context.Context, msg string, size int64, w io.Writer) io.Writer { +// The caller is expected to close the returned writer, which results in the +// progress bar being removed. However, the progress bar can also be removed +// independently of closing the writer by invoking [Writer.Stop]. +// +// If size is unknown, set to -1; this will result in an indeterminate progress +// spinner instead of a bar. +func NewWriter(ctx context.Context, msg string, size int64, w io.Writer) Writer { if w, ok := w.(*progCopier); ok && ctx == w.ctx { return w } pb := FromContext(ctx) if pb == nil { - return contextio.NewWriter(ctx, w) + return writerWrapper{contextio.NewWriter(ctx, w)} } spinner := pb.NewByteCounter(msg, size) @@ -197,10 +204,53 @@ func (r *progReader) Read(p []byte) (n int, err error) { var _ io.ReaderFrom = (*progCopier)(nil) +// Writer is an [io.WriteCloser] as returned by [NewWriter]. +type Writer interface { + io.WriteCloser + + // Stop stops and removes the progress bar. Typically this is accomplished + // by invoking Writer.Close, but there are circumstances where it may + // be desirable to stop the progress bar without closing the underlying + // writer. + Stop() +} + +var _ Writer = (*writerWrapper)(nil) + +// writerWrapper wraps an io.Writer to implement [progress.Writer]. +type writerWrapper struct { + io.Writer +} + +// Close implements [io.WriteCloser]. If the underlying +// writer implements [io.Closer], it will be closed. +func (w writerWrapper) Close() error { + if c, ok := w.Writer.(io.Closer); ok { + return c.Close() + } + return nil +} + +// Stop implements [Writer] and is no-op. +func (w writerWrapper) Stop() { + return +} + +var _ Writer = (*progCopier)(nil) + type progCopier struct { progWriter } +// Stop implements [progress.Writer]. 
+func (w *progCopier) Stop() { + if w == nil || w.spinner == nil { + return + } + + w.spinner.Stop() +} + // ReadFrom implements interface [io.ReaderFrom], but with context awareness. // // This should allow efficient copying allowing writer or reader to define the chunk size. diff --git a/libsq/source/files.go b/libsq/source/files.go index 10269cb54..c4fb01f70 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -98,31 +98,39 @@ func (fs *Files) Size(ctx context.Context, src *Source) (size int64, err error) // AddStdin copies f to fs's cache: the stdin data in f // is later accessible via fs.Open(src) where src.Handle -// is StdinHandle; f's type can be detected via TypeStdin. +// is StdinHandle; f's type can be detected via DetectStdinType. // Note that f is closed by this method. -// -// REVISIT: it's possible we'll ditch AddStdin and TypeStdin -// in some future version; this mechanism is a stopgap. func (fs *Files) AddStdin(ctx context.Context, f *os.File) error { + //return fs.AddStdinOld(ctx, f) fs.mu.Lock() defer fs.mu.Unlock() - // We don't need r, but we're responsible for closing it. - r, err := fs.addFile(ctx, f, StdinHandle) // f is closed by addFile - if err != nil { - return err - } + err := fs.addStdinViaCopyAsync(ctx, f) // f is closed by addFile + return errz.Wrap(err, "failed to read stdin") - return r.Close() } -// TypeStdin detects the type of stdin as previously added +// +//func (fs *Files) addStdinOld(ctx context.Context, f *os.File) error { +// fs.mu.Lock() +// defer fs.mu.Unlock() +// +// // We don't need r, but we're responsible for closing it. +// r, err := fs.addFileOld(ctx, f, StdinHandle) // f is closed by addFile +// if err != nil { +// return err +// } +// +// return r.Close() +//} + +// DetectStdinType detects the type of stdin as previously added // by AddStdin. An error is returned if AddStdin was not // first invoked. If the type cannot be detected, TypeNone and // nil are returned. 
-func (fs *Files) TypeStdin(ctx context.Context) (drivertype.Type, error) { +func (fs *Files) DetectStdinType(ctx context.Context) (drivertype.Type, error) { if !fs.fcache.Exists(StdinHandle) { - return drivertype.None, errz.New("must invoke AddStdin before invoking TypeStdin") + return drivertype.None, errz.New("must invoke Files.AddStdin before invoking DetectStdinType") } typ, ok, err := fs.detectType(ctx, StdinHandle) @@ -137,39 +145,192 @@ func (fs *Files) TypeStdin(ctx context.Context) (drivertype.Type, error) { return typ, nil } -// add file copies f to fs's cache, returning a reader which the -// caller is responsible for closing. f is closed by this method. -func (fs *Files) addFile(ctx context.Context, f *os.File, key string) (fscache.ReadAtCloser, error) { +func (fs *Files) addStdinViaCopyAsync(ctx context.Context, f *os.File) error { log := lg.FromContext(ctx) - log.Debug("Adding file", lga.Key, key, lga.Path, f.Name()) + // Special handling for stdin + r, w, wErrFn, err := fs.fcache.GetWithErr(StdinHandle) + if err != nil { + return errz.Err(err) + } - if key != StdinHandle { - if fs.fcache.Exists(key) { - return nil, errz.Errorf("file already exists in cache: %s", key) - } + defer lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) - if err := fs.fcache.MapFile(f.Name()); err != nil { - return nil, errz.Wrapf(err, "failed to map file into fscache: %s", f.Name()) + if w == nil { + return errz.Errorf("fscache: no writer for %s", StdinHandle) + } + + df := ioz.DelayReader(f, time.Microsecond*500, true) // FIXME: Delete + cr := contextio.NewReader(ctx, df) + pw := progress.NewWriter(ctx, "Reading stdin", -1, w) + + start := time.Now() + ioz.CopyAsync(pw, cr, func(written int64, err error) { + log.Debug("Async stdin cache fill: callback received") + elapsed := time.Since(start) + if err == nil { + log.Debug("Async stdin cache fill: completed", "copied", written, "elapsed", elapsed) + lg.WarnIfCloseError(log, "Close stdin cache", w) + 
lg.WarnIfCloseError(log, "Close stdin reader file", f) + pw.Stop() + return } - r, _, err := fs.fcache.Get(key) - return r, errz.Err(err) - } + log.Error("Async stdin cache fill: failure", lga.Err, err, "copied", written, "elapsed", elapsed) + pw.Stop() + wErrFn(err) + // We deliberately don't close w here, because wErrFn handles that work. + }) + log.Debug("Async stdin cache fill: dispatched") + // + //copier := fscache.AsyncFiller{ + // Message: "Reading source data", + // Log: lg.FromContext(ctx).With(lga.Action, "Cache fill"), + // NewContextWriterFn: progress.NewWriter, + // // We don't use progress.NewReader here, because that + // // would result in double counting of bytes transferred. + // NewContextReaderFn: func(ctx context.Context, msg string, size int64, r io.Reader) io.Reader { + // return contextio.NewReader(ctx, r) + // }, + // CloseReader: true, + //} + // + //df := ioz.DelayReader(f, time.Millisecond, true) // FIXME: Delete + //if err = copier.Copy(ctx, -1, w, df); err != nil { + // return errz.Err(err) + //} + + return nil +} - // Special handling for stdin - r, w, err := fs.fcache.Get(StdinHandle) +// +//func (fs *Files) addStdinOld(ctx context.Context, f *os.File) error { +// // Special handling for stdin +// r, w, err := fs.fcache.Get(StdinHandle) +// if err != nil { +// return errz.Err(err) +// } +// +// defer lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) +// +// if w == nil { +// return errz.Errorf("fscache: no writer for %s", StdinHandle) +// } +// +// copier := fscache.AsyncFiller{ +// Message: "Reading source data", +// Log: lg.FromContext(ctx).With(lga.Action, "Cache fill"), +// NewContextWriterFn: progress.NewWriter, +// // We don't use progress.NewReader here, because that +// // would result in double counting of bytes transferred. 
+// NewContextReaderFn: func(ctx context.Context, msg string, size int64, r io.Reader) io.Reader { +// return contextio.NewReader(ctx, r) +// }, +// CloseReader: true, +// } +// +// df := ioz.DelayReader(f, time.Millisecond, true) // FIXME: Delete +// if err = copier.Copy(ctx, -1, w, df); err != nil { +// return errz.Err(err) +// } +// +// return nil +//} +// +//// add file copies f to fs's cache, returning a reader which the +//// caller is responsible for closing. f is closed by this method. +//func (fs *Files) addFileOld(ctx context.Context, f *os.File, key string) (fscache.ReadAtCloser, error) { +// log := lg.FromContext(ctx) +// log.Debug("Adding file", lga.Key, key, lga.Path, f.Name()) +// +// r, w, err := fs.fcache.Get(key) +// if err != nil { +// return nil, errz.Err(err) +// } +// if w == nil { +// lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) +// return nil, errz.Errorf("failed to add to fscache (possibly previously added): %s", key) +// } +// +// fi, err := f.Stat() +// if err != nil { +// return nil, errz.Err(err) +// } +// size := fi.Size() +// if size == 0 { +// size = -1 +// } +// +// // TODO: Problematically, we copy the entire contents of f into fscache. +// // This is probably necessary for piped data on stdin, but for files +// // that already exist on the file system, it would be nice if the cacheFile +// // could be mapped directly to the filesystem file. This might require +// // hacking on the fscache impl. +// +// contextFn := func(ctx context.Context, msg string, total int64, w io.Writer) io.Writer { +// return progress.NewWriter(ctx, msg, total, w) +// } +// +// copier := fscache.AsyncFiller{ +// Message: "Reading source data", +// Log: log.With(lga.Action, "Cache fill"), +// //NewContextWriterFn: progress.NewWriter, +// NewContextWriterFn: contextFn, +// // We don't use progress.NewReader here, because that +// // would result in double counting of bytes transferred. 
+// NewContextReaderFn: func(ctx context.Context, msg string, size int64, r io.Reader) io.Reader { +// return contextio.NewReader(ctx, r) +// }, +// CloseReader: true, +// } +// +// df := ioz.DelayReader(f, time.Millisecond, true) // FIXME: Delete +// if err = copier.Copy(ctx, size, w, df); err != nil { +// lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) +// return nil, errz.Err(err) +// } +// +// return r, nil +//} + +// add file copies f to fs's cache, returning a reader which the +// caller is responsible for closing. f is closed by this method. +func (fs *Files) addFileViaAsyncFiller(ctx context.Context, f *os.File, key string) (fscache.ReadAtCloser, error) { + log := lg.FromContext(ctx) + log.Debug("Adding file", lga.Key, key, lga.Path, f.Name()) + + r, w, err := fs.fcache.Get(key) if err != nil { return nil, errz.Err(err) } if w == nil { lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) - return nil, errz.Errorf("fscache: no writer for %s", StdinHandle) + return nil, errz.Errorf("failed to add to fscache (possibly previously added): %s", key) + } + + fi, err := f.Stat() + if err != nil { + return nil, errz.Err(err) + } + size := fi.Size() + if size == 0 { + size = -1 + } + + // TODO: Problematically, we copy the entire contents of f into fscache. + // This is probably necessary for piped data on stdin, but for files + // that already exist on the file system, it would be nice if the cacheFile + // could be mapped directly to the filesystem file. This might require + // hacking on the fscache impl. 
+ + contextFn := func(ctx context.Context, msg string, total int64, w io.Writer) io.Writer { + return progress.NewWriter(ctx, msg, total, w) } copier := fscache.AsyncFiller{ - Message: "Reading source data", - Log: log.With(lga.Action, "Cache fill"), - NewContextWriterFn: progress.NewWriter, + Message: "Reading source data", + Log: log.With(lga.Action, "Cache fill"), + //NewContextWriterFn: progress.NewWriter, + NewContextWriterFn: contextFn, // We don't use progress.NewReader here, because that // would result in double counting of bytes transferred. NewContextReaderFn: func(ctx context.Context, msg string, size int64, r io.Reader) io.Reader { @@ -179,14 +340,88 @@ func (fs *Files) addFile(ctx context.Context, f *os.File, key string) (fscache.R } df := ioz.DelayReader(f, time.Millisecond, true) // FIXME: Delete - if err = copier.Copy(ctx, -1, w, df); err != nil { + if err = copier.Copy(ctx, size, w, df); err != nil { + lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) + return nil, errz.Err(err) + } + + return r, nil +} + +// add file copies f to fs's cache, returning a reader which the +// caller is responsible for closing. f is closed by this method. 
+func (fs *Files) addFileViaCopyAsync(ctx context.Context, f *os.File, key string) (fscache.ReadAtCloser, error) { + log := lg.FromContext(ctx) + log.Debug("Adding file", lga.Key, key, lga.Path, f.Name()) + + r, w, wErrFn, err := fs.fcache.GetWithErr(key) + if err != nil { + return nil, errz.Err(err) + } + if w == nil { lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) + return nil, errz.Errorf("failed to add to fscache (possibly previously added): %s", key) + } + + fi, err := f.Stat() + if err != nil { return nil, errz.Err(err) } + size := fi.Size() + if size == 0 { + size = -1 + } + + df := ioz.DelayReader(f, time.Microsecond*500, true) // FIXME: Delete + cr := contextio.NewReader(ctx, df) + pw := progress.NewWriter(ctx, "Reading stdin", -1, w) + + start := time.Now() + ioz.CopyAsync(pw, cr, func(written int64, err error) { + log.Debug("Async stdin cache fill: callback received") + elapsed := time.Since(start) + if err == nil { + log.Debug("Async stdin cache fill: completed", "copied", written, "elapsed", elapsed) + lg.WarnIfCloseError(log, "Close stdin cache", w) + lg.WarnIfCloseError(log, "Close stdin reader file", f) + pw.Stop() + return + } + log.Error("Async stdin cache fill: failure", lga.Err, err, "copied", written, "elapsed", elapsed) + pw.Stop() + wErrFn(err) + // We deliberately don't close w here, because wErrFn handles that work. + }) + log.Debug("Async stdin cache fill: dispatched") return r, nil } +// add file copies f to fs's cache, returning a reader which the +// caller is responsible for closing. f is closed by this method. +// Do not add stdin via this function; instead use addStdinViaCopyAsync. +func (fs *Files) addFileViaMapFile(ctx context.Context, f *os.File, key string) (fscache.ReadAtCloser, error) { + log := lg.FromContext(ctx) + log.Debug("Adding file", lga.Key, key, lga.Path, f.Name()) + + if key == StdinHandle { + // This is a programming error; the caller should have + // instead invoked addStdinViaCopyAsync. 
Probably should panic here. + return nil, errz.New("illegal to add stdin via Files.addFile") + } + + if fs.fcache.Exists(key) { + return nil, errz.Errorf("file already exists in cache: %s", key) + } + + if err := fs.fcache.MapFile(f.Name()); err != nil { + return nil, errz.Wrapf(err, "failed to map file into fscache: %s", f.Name()) + } + + r, _, err := fs.fcache.Get(key) + return r, errz.Err(err) +} + // Filepath returns the file path of src.Location. // An error is returned the source's driver type // is not a file type (i.e. it is a SQL driver). @@ -217,6 +452,7 @@ func (fs *Files) Open(ctx context.Context, src *Source) (io.ReadCloser, error) { fs.mu.Lock() defer fs.mu.Unlock() + lg.FromContext(ctx).Debug("Files.Open", lga.Src, src) return fs.newReader(ctx, src.Location) } @@ -248,8 +484,11 @@ func (fs *Files) ReadAll(ctx context.Context, src *Source) ([]byte, error) { } func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, error) { + log := lg.FromContext(ctx).With(lga.Loc, loc) + log.Debug("Files.newReader", lga.Loc, loc) if loc == StdinHandle { r, w, err := fs.fcache.Get(StdinHandle) + log.Debug("Returned from fs.fcache.Get", lga.Err, err) if err != nil { return nil, errz.Err(err) } @@ -268,7 +507,10 @@ func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, erro } // Note that addFile closes f - r, err := fs.addFile(ctx, f, loc) + //r, err := fs.addFileViaMapFile(ctx, f, loc) + //r, err := fs.addFileViaAsyncFiller(ctx, f, loc) + r, err := fs.addFileViaCopyAsync(ctx, f, loc) + //r, err := fs.addFileOld(ctx, f, loc) if err != nil { return nil, err } @@ -415,6 +657,9 @@ func (fs *Files) DriverType(ctx context.Context, loc string) (drivertype.Type, e } func (fs *Files) detectType(ctx context.Context, loc string) (typ drivertype.Type, ok bool, err error) { + log := lg.FromContext(ctx).With(lga.Loc, loc) + start := time.Now() + log.Debug("Files.detectType") if len(fs.detectFns) == 0 { return drivertype.None, false, nil } @@ 
-439,12 +684,25 @@ func (fs *Files) detectType(ctx context.Context, loc string) (typ drivertype.Typ } g, gCtx := errgroup.WithContext(ctx) - gCtx = lg.NewContext(gCtx, fs.log) + //gCtx = lg.NewContext(gCtx, fs.log) - for _, detectFn := range fs.detectFns { + for i, detectFn := range fs.detectFns { + i := i detectFn := detectFn g.Go(func() error { + start := time.Now() + defer func() { + lg.FromContext(gCtx).Debug( + "detectType: detectFn complete", + lga.Type, + detectFn, + lga.Index, i, + lga.Elapsed, + time.Since(start), + ) + + }() select { case <-gCtx.Done(): return gCtx.Err() @@ -480,9 +738,11 @@ func (fs *Files) detectType(ctx context.Context, loc string) (typ drivertype.Typ const detectScoreThreshold = 0.5 if highestScore >= detectScoreThreshold { + log.Debug("Type detected", lga.Type, typ, lga.Elapsed, time.Since(start)) return typ, true, nil } + log.Warn("No type detected", lga.Type, typ, lga.Elapsed, time.Since(start)) return drivertype.None, false, nil } diff --git a/libsq/source/files_test.go b/libsq/source/files_test.go index fae2e2828..bc2fd9d2e 100644 --- a/libsq/source/files_test.go +++ b/libsq/source/files_test.go @@ -205,7 +205,7 @@ func TestFiles_Stdin(t *testing.T) { err = fs.AddStdin(th.Context, f) // f is closed by AddStdin require.NoError(t, err) - typ, err := fs.TypeStdin(th.Context) + typ, err := fs.DetectStdinType(th.Context) if tc.wantErr { require.Error(t, err) return @@ -220,7 +220,7 @@ func TestFiles_Stdin_ErrorWrongOrder(t *testing.T) { th := testh.New(t) fs := th.Files() - typ, err := fs.TypeStdin(th.Context) + typ, err := fs.DetectStdinType(th.Context) require.Error(t, err, "should error because AddStdin not yet invoked") require.Equal(t, drivertype.None, typ) @@ -228,7 +228,7 @@ func TestFiles_Stdin_ErrorWrongOrder(t *testing.T) { require.NoError(t, err) require.NoError(t, fs.AddStdin(th.Context, f)) // AddStdin closes f - typ, err = fs.TypeStdin(th.Context) + typ, err = fs.DetectStdinType(th.Context) require.NoError(t, err) 
require.Equal(t, csv.TypeCSV, typ) } From d8620744afd7a88efa03ab9a9585119449daec32 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 2 Dec 2023 08:12:34 -0700 Subject: [PATCH 031/195] wip --- libsq/core/lg/devlog/devlog.go | 29 +---------------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/libsq/core/lg/devlog/devlog.go b/libsq/core/lg/devlog/devlog.go index 800420cab..bceac7f4e 100644 --- a/libsq/core/lg/devlog/devlog.go +++ b/libsq/core/lg/devlog/devlog.go @@ -4,20 +4,12 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/devlog/tint" "io" "log/slog" - "path/filepath" - "strconv" - "strings" ) const shortTimeFormat = `15:04:05.000000` -// New returns a developer-friendly logger that +// NewHandler returns a developer-friendly slog.Handler that // logs to w. -func New(w io.Writer, lvl slog.Leveler) *slog.Logger { - h := NewHandler(w, lvl) - return slog.New(h) -} - func NewHandler(w io.Writer, lvl slog.Leveler) slog.Handler { h := tint.NewHandler(w, &tint.Options{ Level: lvl, @@ -34,22 +26,3 @@ func NewHandler(w io.Writer, lvl slog.Leveler) slog.Handler { }) return h } - -// replaceSourceShort prints a dev-friendly "source" field. 
-func replaceSourceShort(_ []string, a slog.Attr) slog.Attr { - if src, ok := a.Value.Any().(*slog.Source); ok { - s := filepath.Join(filepath.Base(filepath.Dir(src.File)), filepath.Base(src.File)) - s += ":" + strconv.Itoa(src.Line) - - fn := src.Function - parts := strings.Split(src.Function, "/") - if len(parts) > 0 { - fn = parts[len(parts)-1] - } - - s += ":" + fn - //a.Key = "src" - a.Value = slog.StringValue(s) - } - return a -} From 1a3cb2ad923a4d6457efb0417d7e3a98b2df9cb0 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 2 Dec 2023 17:48:02 -0700 Subject: [PATCH 032/195] lots of working parts --- libsq/core/ioz/contextio/contextio_test.go | 2 +- libsq/core/ioz/ioz.go | 59 ++- libsq/core/lg/lga/lga.go | 2 + libsq/core/progress/progress_test.go | 2 +- libsq/core/record/meta.go | 2 +- libsq/source/cache.go | 53 +++ libsq/source/files.go | 488 +++++---------------- libsq/source/location.go | 33 ++ 8 files changed, 264 insertions(+), 377 deletions(-) create mode 100644 libsq/source/cache.go diff --git a/libsq/core/ioz/contextio/contextio_test.go b/libsq/core/ioz/contextio/contextio_test.go index 97dbfb273..56b83db95 100644 --- a/libsq/core/ioz/contextio/contextio_test.go +++ b/libsq/core/ioz/contextio/contextio_test.go @@ -30,7 +30,7 @@ func TestNewWriter_Closer(t *testing.T) { assert.False(t, isCloser, "expected reader NOT to be io.WriteCloser, but was %T", gotWriter) - bufCloser := ioz.ToWriteCloser(buf) + bufCloser := ioz.WriteCloser(buf) gotWriter = contextio.NewWriter(ctx, bufCloser) require.NotNil(t, gotWriter) _, isCloser = gotWriter.(io.WriteCloser) diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index aa5d8f698..4dd13c63b 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -11,6 +11,7 @@ import ( "path/filepath" "strings" "sync" + "sync/atomic" "time" yaml "github.com/goccy/go-yaml" @@ -323,13 +324,13 @@ func (w *notifyOnceWriter) Write(p []byte) (n int, err error) { return w.w.Write(p) } -// ToWriteCloser returns w as an 
io.WriteCloser. If w implements +// WriteCloser returns w as an io.WriteCloser. If w implements // io.WriteCloser, w is returned. Otherwise, w is wrapped in a // no-op decorator that implements io.WriteCloser. // -// ToWriteCloser is the missing sibling of io.NopCloser, which +// WriteCloser is the missing sibling of io.NopCloser, which // isn't implemented in stdlib. See: https://github.com/golang/go/issues/22823. -func ToWriteCloser(w io.Writer) io.WriteCloser { +func WriteCloser(w io.Writer) io.WriteCloser { if wc, ok := w.(io.WriteCloser); ok { return wc } @@ -358,3 +359,55 @@ func (nopWriteCloserReaderFrom) Close() error { return nil } func (c nopWriteCloserReaderFrom) ReadFrom(r io.Reader) (int64, error) { return c.Writer.(io.ReaderFrom).ReadFrom(r) } + +// NewWrittenWriter returns a writer that counts the number of bytes +// written to the underlying writer. The number of bytes written can +// be obtained via [WrittenWriter.Written], which blocks until writing +// has concluded. +func NewWrittenWriter(w io.WriteCloser) *WrittenWriter { + return &WrittenWriter{ + w: w, + c: &atomic.Int64{}, + done: make(chan struct{}), + } +} + +var _ io.Writer = (*WrittenWriter)(nil) + +// WrittenWriter is an io.WriteCloser that counts the number of bytes +// written to the underlying writer. The number of bytes written can +// be obtained via [WrittenWriter.Written], which blocks until writing +// has concluded. +type WrittenWriter struct { + c *atomic.Int64 + w io.WriteCloser + doneOnce sync.Once + done chan struct{} +} + +// Written returns the number of bytes written to the underlying +// writer, blocking until writing concludes, either via invocation of +// Close, or via an error in Write. +func (w *WrittenWriter) Written() int64 { + select { + case <-w.done: + return w.c.Load() + } +} + +// Close implements io.WriteCloser. 
+func (w *WrittenWriter) Close() error { + closeErr := w.w.Close() + w.doneOnce.Do(func() { close(w.done) }) + return closeErr +} + +// Write implements io.WriterCloser. +func (w *WrittenWriter) Write(p []byte) (n int, err error) { + n, err = w.w.Write(p) + w.c.Add(int64(n)) + if err != nil { + w.doneOnce.Do(func() { close(w.done) }) + } + return n, err +} diff --git a/libsq/core/lg/lga/lga.go b/libsq/core/lg/lga/lga.go index 02d806933..eed454dea 100644 --- a/libsq/core/lg/lga/lga.go +++ b/libsq/core/lg/lga/lga.go @@ -13,6 +13,7 @@ const ( Count = "count" Commit = "commit" Conn = "conn" + Copied = "copied" Cleanup = "cleanup" DB = "db" DBType = "db_type" @@ -23,6 +24,7 @@ const ( Env = "env" Err = "error" Expected = "expected" + File = "file" From = "from" Flag = "flag" Handle = "handle" diff --git a/libsq/core/progress/progress_test.go b/libsq/core/progress/progress_test.go index f1dc8ef12..7284f5ad0 100644 --- a/libsq/core/progress/progress_test.go +++ b/libsq/core/progress/progress_test.go @@ -55,7 +55,7 @@ func TestNewWriter_Closer(t *testing.T) { assert.True(t, isCloser, "expected writer to be io.WriteCloser, but was %T", gotWriter) - bufCloser := ioz.ToWriteCloser(buf) + bufCloser := ioz.WriteCloser(buf) gotWriter = progress.NewWriter(ctx, "no closer", -1, bufCloser) require.NotNil(t, gotWriter) _, isCloser = gotWriter.(io.WriteCloser) diff --git a/libsq/core/record/meta.go b/libsq/core/record/meta.go index 3c5587030..ad3e44bea 100644 --- a/libsq/core/record/meta.go +++ b/libsq/core/record/meta.go @@ -84,7 +84,7 @@ func (fm *FieldMeta) MungedName() string { return fm.mungedName } -// Length is documented by sql.ColumnType.Length. +// Length is documented by sql.ColumnType.Written. 
 func (fm *FieldMeta) Length() (length int64, ok bool) {
 	return fm.data.Length, fm.data.HasLength
 }
diff --git a/libsq/source/cache.go b/libsq/source/cache.go
new file mode 100644
index 000000000..3404616ce
--- /dev/null
+++ b/libsq/source/cache.go
@@ -0,0 +1,53 @@
+package source
+
+import (
+	"github.com/neilotoole/sq/libsq/core/errz"
+	"github.com/neilotoole/sq/libsq/core/stringz"
+	"os"
+	"path/filepath"
+)
+
+// CacheDirFor gets the cache dir for handle, creating it if necessary.
+// An error is returned if handle is empty or invalid; stdin gets a unique dir.
+func CacheDirFor(src *Source) (dir string, err error) {
+	handle := src.Handle
+	switch handle {
+	case "":
+		// FIXME: This is surely an error?
+		return "", errz.Errorf("open cache dir: empty handle")
+		//handle = "@cache_" + stringz.UniqN(32)
+	case StdinHandle:
+		// stdin is different input every time, so we need a unique
+		// cache dir. In practice, stdin probably isn't using this function.
+		handle += "_" + stringz.UniqN(32)
+	default:
+		if err = ValidHandle(handle); err != nil {
+			return "", errz.Wrapf(err, "open cache dir: invalid handle: %s", handle)
+		}
+	}
+
+	dir = CacheDirPath()
+	sanitized := Handle2SafePath(handle)
+	hash := src.Hash()
+	dir = filepath.Join(dir, "sources", sanitized, hash)
+	if err = os.MkdirAll(dir, 0o750); err != nil {
+		return "", errz.Wrapf(err, "open cache dir: %s", dir)
+	}
+
+	return dir, nil
+}
+
+// CacheDirPath returns the sq cache dir. This is generally
+// in USER_CACHE_DIR/sq/cache, but could also be in TEMP_DIR/sq/cache
+// or similar. It is not guaranteed that the returned dir exists
+// or is accessible.
+func CacheDirPath() (dir string) {
+	var err error
+	if dir, err = os.UserCacheDir(); err != nil {
+		// Some systems may not have a user cache dir, so we fall back
+		// to the system temp dir.
+ dir = os.TempDir() + } + dir = filepath.Join(dir, "sq", "cache") + return dir +} diff --git a/libsq/source/files.go b/libsq/source/files.go index c4fb01f70..ade8e010e 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -7,7 +7,6 @@ import ( "mime" "net/url" "os" - "path/filepath" "sync" "time" @@ -25,7 +24,6 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/progress" - "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/source/drivertype" "github.com/neilotoole/sq/libsq/source/fetcher" ) @@ -48,6 +46,11 @@ type Files struct { clnup *cleanup.Cleanup fcache *fscache.FSCache detectFns []DriverDetectFunc + + // stdinLength is a func that returns number of bytes read from stdin. + // It is nil if stdin has not been read. The func may block until reading + // of stdin has completed. + stdinLength func() int64 } // NewFiles returns a new Files instance. @@ -74,56 +77,6 @@ func (fs *Files) AddDriverDetectors(detectFns ...DriverDetectFunc) { fs.detectFns = append(fs.detectFns, detectFns...) } -// Size returns the file size of src.Location. This exists -// as a convenience function and something of a replacement -// for using os.Stat to get the file size. -// -// FIXME: This is a terrible way to get the size. It currently -// reads all the bytes. Awful. -func (fs *Files) Size(ctx context.Context, src *Source) (size int64, err error) { - r, err := fs.Open(ctx, src) - if err != nil { - return 0, err - } - - defer lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) - - size, err = io.Copy(io.Discard, r) - if err != nil { - return 0, errz.Err(err) - } - - return size, nil -} - -// AddStdin copies f to fs's cache: the stdin data in f -// is later accessible via fs.Open(src) where src.Handle -// is StdinHandle; f's type can be detected via DetectStdinType. -// Note that f is closed by this method. 
-func (fs *Files) AddStdin(ctx context.Context, f *os.File) error { - //return fs.AddStdinOld(ctx, f) - fs.mu.Lock() - defer fs.mu.Unlock() - - err := fs.addStdinViaCopyAsync(ctx, f) // f is closed by addFile - return errz.Wrap(err, "failed to read stdin") - -} - -// -//func (fs *Files) addStdinOld(ctx context.Context, f *os.File) error { -// fs.mu.Lock() -// defer fs.mu.Unlock() -// -// // We don't need r, but we're responsible for closing it. -// r, err := fs.addFileOld(ctx, f, StdinHandle) // f is closed by addFile -// if err != nil { -// return err -// } -// -// return r.Close() -//} - // DetectStdinType detects the type of stdin as previously added // by AddStdin. An error is returned if AddStdin was not // first invoked. If the type cannot be detected, TypeNone and @@ -145,268 +98,117 @@ func (fs *Files) DetectStdinType(ctx context.Context) (drivertype.Type, error) { return typ, nil } -func (fs *Files) addStdinViaCopyAsync(ctx context.Context, f *os.File) error { - log := lg.FromContext(ctx) - // Special handling for stdin - r, w, wErrFn, err := fs.fcache.GetWithErr(StdinHandle) - if err != nil { - return errz.Err(err) - } - - defer lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) - - if w == nil { - return errz.Errorf("fscache: no writer for %s", StdinHandle) - } - - df := ioz.DelayReader(f, time.Microsecond*500, true) // FIXME: Delete - cr := contextio.NewReader(ctx, df) - pw := progress.NewWriter(ctx, "Reading stdin", -1, w) - - start := time.Now() - ioz.CopyAsync(pw, cr, func(written int64, err error) { - log.Debug("Async stdin cache fill: callback received") - elapsed := time.Since(start) - if err == nil { - log.Debug("Async stdin cache fill: completed", "copied", written, "elapsed", elapsed) - lg.WarnIfCloseError(log, "Close stdin cache", w) - lg.WarnIfCloseError(log, "Close stdin reader file", f) - pw.Stop() - return +// Size returns the file size of src.Location. 
If the source is being +// loaded asynchronously, this function may block until loading completes. +func (fs *Files) Size(ctx context.Context, src *Source) (size int64, err error) { + locTyp := getLocType(src.Location) + switch locTyp { + case locTypeLocalFile: + // It's a filepath + var fi os.FileInfo + if fi, err = os.Stat(src.Location); err != nil { + return 0, errz.Err(err) + } + return fi.Size(), nil + case locTypeRemoteFile: + // FIXME: implement remote file size. + return 0, errz.Errorf("remote file size not implemented: %s", src.Location) + case locTypeSQL: + return 0, errz.Errorf("cannot get size of SQL source: %s", src.Handle) + case locTypeStdin: + // Special handling for stdin. + if fs.stdinLength == nil { + return 0, errz.Errorf("stdin not yet added") + } + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + return fs.stdinLength(), nil } + default: + return 0, errz.Errorf("unknown source location type: %s", RedactLocation(src.Location)) + } +} - log.Error("Async stdin cache fill: failure", lga.Err, err, "copied", written, "elapsed", elapsed) - pw.Stop() - wErrFn(err) - // We deliberately don't close w here, because wErrFn handles that work. - }) - log.Debug("Async stdin cache fill: dispatched") - // - //copier := fscache.AsyncFiller{ - // Message: "Reading source data", - // Log: lg.FromContext(ctx).With(lga.Action, "Cache fill"), - // NewContextWriterFn: progress.NewWriter, - // // We don't use progress.NewReader here, because that - // // would result in double counting of bytes transferred. 
- // NewContextReaderFn: func(ctx context.Context, msg string, size int64, r io.Reader) io.Reader { - // return contextio.NewReader(ctx, r) - // }, - // CloseReader: true, - //} - // - //df := ioz.DelayReader(f, time.Millisecond, true) // FIXME: Delete - //if err = copier.Copy(ctx, -1, w, df); err != nil { - // return errz.Err(err) - //} +// AddStdin copies f to fs's cache: the stdin data in f +// is later accessible via fs.Open(src) where src.Handle +// is StdinHandle; f's type can be detected via DetectStdinType. +// Note that f is closed by this method. +func (fs *Files) AddStdin(ctx context.Context, f *os.File) error { + fs.mu.Lock() + defer fs.mu.Unlock() - return nil + err := fs.addStdin(ctx, f) // f is closed by addStdin + return errz.Wrap(err, "failed to read stdin") } -// -//func (fs *Files) addStdinOld(ctx context.Context, f *os.File) error { -// // Special handling for stdin -// r, w, err := fs.fcache.Get(StdinHandle) -// if err != nil { -// return errz.Err(err) -// } -// -// defer lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) -// -// if w == nil { -// return errz.Errorf("fscache: no writer for %s", StdinHandle) -// } -// -// copier := fscache.AsyncFiller{ -// Message: "Reading source data", -// Log: lg.FromContext(ctx).With(lga.Action, "Cache fill"), -// NewContextWriterFn: progress.NewWriter, -// // We don't use progress.NewReader here, because that -// // would result in double counting of bytes transferred. -// NewContextReaderFn: func(ctx context.Context, msg string, size int64, r io.Reader) io.Reader { -// return contextio.NewReader(ctx, r) -// }, -// CloseReader: true, -// } -// -// df := ioz.DelayReader(f, time.Millisecond, true) // FIXME: Delete -// if err = copier.Copy(ctx, -1, w, df); err != nil { -// return errz.Err(err) -// } -// -// return nil -//} -// -//// add file copies f to fs's cache, returning a reader which the -//// caller is responsible for closing. f is closed by this method. 
-//func (fs *Files) addFileOld(ctx context.Context, f *os.File, key string) (fscache.ReadAtCloser, error) { -// log := lg.FromContext(ctx) -// log.Debug("Adding file", lga.Key, key, lga.Path, f.Name()) -// -// r, w, err := fs.fcache.Get(key) -// if err != nil { -// return nil, errz.Err(err) -// } -// if w == nil { -// lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) -// return nil, errz.Errorf("failed to add to fscache (possibly previously added): %s", key) -// } -// -// fi, err := f.Stat() -// if err != nil { -// return nil, errz.Err(err) -// } -// size := fi.Size() -// if size == 0 { -// size = -1 -// } -// -// // TODO: Problematically, we copy the entire contents of f into fscache. -// // This is probably necessary for piped data on stdin, but for files -// // that already exist on the file system, it would be nice if the cacheFile -// // could be mapped directly to the filesystem file. This might require -// // hacking on the fscache impl. -// -// contextFn := func(ctx context.Context, msg string, total int64, w io.Writer) io.Writer { -// return progress.NewWriter(ctx, msg, total, w) -// } -// -// copier := fscache.AsyncFiller{ -// Message: "Reading source data", -// Log: log.With(lga.Action, "Cache fill"), -// //NewContextWriterFn: progress.NewWriter, -// NewContextWriterFn: contextFn, -// // We don't use progress.NewReader here, because that -// // would result in double counting of bytes transferred. -// NewContextReaderFn: func(ctx context.Context, msg string, size int64, r io.Reader) io.Reader { -// return contextio.NewReader(ctx, r) -// }, -// CloseReader: true, -// } -// -// df := ioz.DelayReader(f, time.Millisecond, true) // FIXME: Delete -// if err = copier.Copy(ctx, size, w, df); err != nil { -// lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) -// return nil, errz.Err(err) -// } -// -// return r, nil -//} - -// add file copies f to fs's cache, returning a reader which the -// caller is responsible for closing. f is closed by this method. 
-func (fs *Files) addFileViaAsyncFiller(ctx context.Context, f *os.File, key string) (fscache.ReadAtCloser, error) { - log := lg.FromContext(ctx) - log.Debug("Adding file", lga.Key, key, lga.Path, f.Name()) +// addStdin synchronously copies f (stdin) to fs's cache. f is closed +// when the async copy completes. This method should only be used +// for stdin; for regular files, use Files.addFile. +func (fs *Files) addStdin(ctx context.Context, f *os.File) error { + log := lg.FromContext(ctx).With(lga.File, f.Name()) - r, w, err := fs.fcache.Get(key) - if err != nil { - return nil, errz.Err(err) - } - if w == nil { - lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) - return nil, errz.Errorf("failed to add to fscache (possibly previously added): %s", key) + if fs.stdinLength != nil { + return errz.Errorf("stdin already added") } - fi, err := f.Stat() + // Special handling for stdin + r, w, wErrFn, err := fs.fcache.GetWithErr(StdinHandle) if err != nil { - return nil, errz.Err(err) - } - size := fi.Size() - if size == 0 { - size = -1 - } - - // TODO: Problematically, we copy the entire contents of f into fscache. - // This is probably necessary for piped data on stdin, but for files - // that already exist on the file system, it would be nice if the cacheFile - // could be mapped directly to the filesystem file. This might require - // hacking on the fscache impl. - - contextFn := func(ctx context.Context, msg string, total int64, w io.Writer) io.Writer { - return progress.NewWriter(ctx, msg, total, w) - } - - copier := fscache.AsyncFiller{ - Message: "Reading source data", - Log: log.With(lga.Action, "Cache fill"), - //NewContextWriterFn: progress.NewWriter, - NewContextWriterFn: contextFn, - // We don't use progress.NewReader here, because that - // would result in double counting of bytes transferred. 
- NewContextReaderFn: func(ctx context.Context, msg string, size int64, r io.Reader) io.Reader { - return contextio.NewReader(ctx, r) - }, - CloseReader: true, - } - - df := ioz.DelayReader(f, time.Millisecond, true) // FIXME: Delete - if err = copier.Copy(ctx, size, w, df); err != nil { - lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) - return nil, errz.Err(err) + return errz.Err(err) } - return r, nil -} - -// add file copies f to fs's cache, returning a reader which the -// caller is responsible for closing. f is closed by this method. -func (fs *Files) addFileViaCopyAsync(ctx context.Context, f *os.File, key string) (fscache.ReadAtCloser, error) { - log := lg.FromContext(ctx) - log.Debug("Adding file", lga.Key, key, lga.Path, f.Name()) + defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) - r, w, wErrFn, err := fs.fcache.GetWithErr(key) - if err != nil { - return nil, errz.Err(err) - } if w == nil { - lg.WarnIfCloseError(fs.log, lgm.CloseFileReader, r) - return nil, errz.Errorf("failed to add to fscache (possibly previously added): %s", key) + // Shouldn't happen + return errz.Errorf("no cache writer for %s", StdinHandle) } - fi, err := f.Stat() - if err != nil { - return nil, errz.Err(err) - } - size := fi.Size() - if size == 0 { - size = -1 - } + lw := ioz.NewWrittenWriter(w) + fs.stdinLength = lw.Written df := ioz.DelayReader(f, time.Microsecond*500, true) // FIXME: Delete cr := contextio.NewReader(ctx, df) - pw := progress.NewWriter(ctx, "Reading stdin", -1, w) + pw := progress.NewWriter(ctx, "Reading stdin", -1, lw) start := time.Now() ioz.CopyAsync(pw, cr, func(written int64, err error) { - log.Debug("Async stdin cache fill: callback received") + defer lg.WarnIfCloseError(log, lgm.CloseFileReader, f) elapsed := time.Since(start) if err == nil { - log.Debug("Async stdin cache fill: completed", "copied", written, "elapsed", elapsed) + log.Debug("Async stdin cache fill: completed", lga.Copied, written, lga.Elapsed, elapsed) lg.WarnIfCloseError(log, 
"Close stdin cache", w) - lg.WarnIfCloseError(log, "Close stdin reader file", f) pw.Stop() return } - log.Error("Async stdin cache fill: failure", lga.Err, err, "copied", written, "elapsed", elapsed) + log.Error("Async stdin cache fill: failure", + lga.Err, err, + lga.Copied, written, + lga.Elapsed, elapsed, + ) pw.Stop() wErrFn(err) - // We deliberately don't close w here, because wErrFn handles that work. + // We deliberately don't close "w" here, because wErrFn handles that work. }) log.Debug("Async stdin cache fill: dispatched") - return r, nil + return nil } -// add file copies f to fs's cache, returning a reader which the +// addFile maps f to fs's cache, returning a reader which the // caller is responsible for closing. f is closed by this method. -// Do not add stdin via this function; instead use addStdinViaCopyAsync. -func (fs *Files) addFileViaMapFile(ctx context.Context, f *os.File, key string) (fscache.ReadAtCloser, error) { +// Do not add stdin via this function; instead use addStdin. +func (fs *Files) addFile(ctx context.Context, f *os.File, key string) (fscache.ReadAtCloser, error) { log := lg.FromContext(ctx) log.Debug("Adding file", lga.Key, key, lga.Path, f.Name()) + defer lg.WarnIfCloseError(log, lgm.CloseFileReader, f) + if key == StdinHandle { // This is a programming error; the caller should have - // instead invoked addStdinViaCopyAsync. Probably should panic here. + // instead invoked addStdin. Probably should panic here. 
return nil, errz.New("illegal to add stdin via Files.addFile") } @@ -486,6 +288,9 @@ func (fs *Files) ReadAll(ctx context.Context, src *Source) ([]byte, error) { func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, error) { log := lg.FromContext(ctx).With(lga.Loc, loc) log.Debug("Files.newReader", lga.Loc, loc) + + locTyp := getLocType(loc) + if loc == StdinHandle { r, w, err := fs.fcache.Get(StdinHandle) log.Debug("Returned from fs.fcache.Get", lga.Err, err) @@ -500,29 +305,27 @@ func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, erro } if !fs.fcache.Exists(loc) { - // cache miss - f, err := fs.openLocation(ctx, loc) + r, _, err := fs.fcache.Get(loc) if err != nil { return nil, err } - // Note that addFile closes f - //r, err := fs.addFileViaMapFile(ctx, f, loc) - //r, err := fs.addFileViaAsyncFiller(ctx, f, loc) - r, err := fs.addFileViaCopyAsync(ctx, f, loc) - //r, err := fs.addFileOld(ctx, f, loc) - if err != nil { - return nil, err - } return r, nil } - r, _, err := fs.fcache.Get(loc) + // cache miss + f, err := fs.openLocation(ctx, loc) if err != nil { return nil, err } + // Note that addFile closes f + r, err := fs.addFile(ctx, f, loc) + if err != nil { + return nil, err + } return r, nil + } // openLocation returns a file for loc. 
It is the caller's @@ -533,24 +336,27 @@ func (fs *Files) openLocation(ctx context.Context, loc string) (*os.File, error) var err error fpath, ok = isFpath(loc) + if ok { + // we have a legitimate fpath + f, err := os.Open(fpath) + return f, errz.Err(err) + } + // It's not a local file path, maybe it's remote (http) + var u *url.URL + u, ok = httpURL(loc) if !ok { - // It's not a local file path, maybe it's remote (http) - var u *url.URL - u, ok = httpURL(loc) - if !ok { - // We're out of luck, it's not a valid file location - return nil, errz.Errorf("invalid src location: %s", loc) - } + // We're out of luck, it's not a valid file location + return nil, errz.Errorf("invalid src location: %s", loc) + } - // It's a remote file - fpath, err = fs.fetch(ctx, u.String()) - if err != nil { - return nil, err - } + // It's a remote file + fpath, err = fs.fetch(ctx, u.String()) + if err != nil { + return nil, err } - // we have a legitimate fpath - return fs.openFile(fpath) + f, err := os.Open(fpath) + return f, errz.Err(err) } // openFile opens the file at fpath. It is the caller's @@ -614,6 +420,7 @@ func (fs *Files) CleanupE(fn func() error) { // DriverType returns the driver type of loc. 
func (fs *Files) DriverType(ctx context.Context, loc string) (drivertype.Type, error) { + log := lg.FromContext(ctx).With(lga.Loc, loc) ploc, err := parseLoc(loc) if err != nil { return drivertype.None, err @@ -626,20 +433,12 @@ func (fs *Files) DriverType(ctx context.Context, loc string) (drivertype.Type, e if ploc.ext != "" { mtype := mime.TypeByExtension(ploc.ext) if mtype == "" { - fs.log.Debug( - "unknown mime type", - lga.Type, mtype, - lga.Loc, loc, - ) + log.Debug("unknown mime type", lga.Type, mtype) } else { if typ, ok := typeFromMediaType(mtype); ok { return typ, nil } - fs.log.Debug( - "unknown driver type for media type", - lga.Type, mtype, - lga.Loc, loc, - ) + log.Debug("unknown driver type for media type", lga.Type, mtype) } } @@ -657,12 +456,11 @@ func (fs *Files) DriverType(ctx context.Context, loc string) (drivertype.Type, e } func (fs *Files) detectType(ctx context.Context, loc string) (typ drivertype.Type, ok bool, err error) { - log := lg.FromContext(ctx).With(lga.Loc, loc) - start := time.Now() - log.Debug("Files.detectType") if len(fs.detectFns) == 0 { return drivertype.None, false, nil } + log := lg.FromContext(ctx).With(lga.Loc, loc) + start := time.Now() type result struct { typ drivertype.Type @@ -684,25 +482,11 @@ func (fs *Files) detectType(ctx context.Context, loc string) (typ drivertype.Typ } g, gCtx := errgroup.WithContext(ctx) - //gCtx = lg.NewContext(gCtx, fs.log) - for i, detectFn := range fs.detectFns { - i := i + for _, detectFn := range fs.detectFns { detectFn := detectFn g.Go(func() error { - start := time.Now() - defer func() { - lg.FromContext(gCtx).Debug( - "detectType: detectFn complete", - lga.Type, - detectFn, - lga.Index, i, - lga.Elapsed, - time.Since(start), - ) - - }() select { case <-gCtx.Done(): return gCtx.Err() @@ -721,9 +505,14 @@ func (fs *Files) detectType(ctx context.Context, loc string) (typ drivertype.Typ }) } + // REVISIT: We shouldn't have to wait for all goroutines to complete. 
+ // This logic could be refactored to return as soon as a single + // goroutine returns a score >= 1.0 (then cancelling the other detector + // goroutines). + err = g.Wait() if err != nil { - fs.log.Error(err.Error()) + log.Error(err.Error()) return drivertype.None, false, errz.Err(err) } close(resultCh) @@ -819,46 +608,3 @@ func httpURL(s string) (u *url.URL, ok bool) { return u, true } - -// CacheDirFor gets the cache dir for handle, creating it if necessary. -// If handle is empty or invalid, a random value is generated. -func CacheDirFor(src *Source) (dir string, err error) { - handle := src.Handle - switch handle { - case "": - handle = "@cache_" + stringz.UniqN(32) - case StdinHandle: - // stdin is different input every time, so we need a unique - // cache dir. - handle += "_" + stringz.UniqN(32) - default: - if err = ValidHandle(handle); err != nil { - return "", errz.Wrapf(err, "open cache dir: invalid handle: %s", handle) - } - } - - dir = CacheDirPath() - sanitized := Handle2SafePath(handle) - hash := src.Hash() - dir = filepath.Join(dir, "sources", sanitized, hash) - if err = os.MkdirAll(dir, 0o750); err != nil { - return "", errz.Wrapf(err, "open cache dir: %s", dir) - } - - return dir, nil -} - -// CacheDirPath returns the sq cache dir. This is generally -// in USER_CACHE_DIR/sq/cache, but could also be in TEMP_DIR/sq/cache -// or similar. It is not guaranteed that the returned dir exists -// or is accessible. -func CacheDirPath() (dir string) { - var err error - if dir, err = os.UserCacheDir(); err != nil { - // Some systems may not have a user cache dir, so we fall back - // to the system temp dir. 
- dir = os.TempDir() - } - dir = filepath.Join(dir, "sq", "cache") - return dir -} diff --git a/libsq/source/location.go b/libsq/source/location.go index 97baa4fc6..f5eff4cfb 100644 --- a/libsq/source/location.go +++ b/libsq/source/location.go @@ -13,6 +13,7 @@ import ( "github.com/neilotoole/sq/libsq/source/drivertype" ) +// dbSchemes is a list of known SQL driver schemes. var dbSchemes = []string{ "mysql", "sqlserver", @@ -320,3 +321,35 @@ func isFpath(loc string) (fpath string, ok bool) { return fpath, true } + +// locType is an enumeration of the various types of source location. +type locType string + +const ( + locTypeStdin = "stdin" + locTypeLocalFile = "local_file" + locTypeSQL = "sql" + locTypeRemoteFile = "remote_file" + locTypeUnknown = "unknown" +) + +// getLocType returns the type of loc, or locTypeUnknown if it +// can't be determined. +func getLocType(loc string) locType { + switch { + case loc == StdinHandle: + // Convention: the "location" of stdin is always "@stdin" + return locTypeStdin + case IsSQLLocation(loc): + return locTypeSQL + case strings.HasPrefix(loc, "http://"), + strings.HasPrefix(loc, "https://"): + return locTypeRemoteFile + default: + } + + if _, err := filepath.Abs(loc); err != nil { + return locTypeUnknown + } + return locTypeLocalFile +} From 3475a99f0121d3e0c6eac79123e6a7ce28c0fd9a Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 2 Dec 2023 18:20:16 -0700 Subject: [PATCH 033/195] moved json driver over to OpenIngest mechanism --- drivers/csv/csv.go | 5 +- drivers/json/{import.go => ingest.go} | 16 +-- .../json/{import_json.go => ingest_json.go} | 2 +- .../json/{import_jsona.go => ingest_jsona.go} | 2 +- .../json/{import_jsonl.go => ingest_jsonl.go} | 2 +- .../json/{import_test.go => ingest_test.go} | 0 drivers/json/internal_test.go | 12 +- drivers/json/json.go | 108 ++++++++++++------ libsq/source/files.go | 47 ++++++++ 9 files changed, 136 insertions(+), 58 deletions(-) rename drivers/json/{import.go => ingest.go} 
(97%) rename drivers/json/{import_json.go => ingest_json.go} (99%) rename drivers/json/{import_jsona.go => ingest_jsona.go} (99%) rename drivers/json/{import_jsonl.go => ingest_jsonl.go} (98%) rename drivers/json/{import_test.go => ingest_test.go} (100%) diff --git a/drivers/csv/csv.go b/drivers/csv/csv.go index a720dd22e..4960d1543 100644 --- a/drivers/csv/csv.go +++ b/drivers/csv/csv.go @@ -84,12 +84,11 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er return ingestCSV(ctx, src, openFn, destPool) } - backingPool, err := d.scratcher.OpenIngest(ctx, src, ingestFn, allowCache) - if err != nil { + var err error + if p.impl, err = d.scratcher.OpenIngest(ctx, src, ingestFn, allowCache); err != nil { return nil, err } - p.impl = backingPool return p, nil } diff --git a/drivers/json/import.go b/drivers/json/ingest.go similarity index 97% rename from drivers/json/import.go rename to drivers/json/ingest.go index a43cafd8d..053c36ef8 100644 --- a/drivers/json/import.go +++ b/drivers/json/ingest.go @@ -1,6 +1,6 @@ package json -// import.go contains functionality common to the +// ingest.go contains functionality common to the // various JSON import mechanisms. import ( @@ -24,10 +24,10 @@ import ( "github.com/neilotoole/sq/libsq/source" ) -// importJob describes a single import job, where the JSON +// ingestJob describes a single ingest job, where the JSON // at fromSrc is read via openFn and the resulting records // are written to destPool. -type importJob struct { +type ingestJob struct { fromSrc *source.Source openFn source.FileOpenFunc destPool driver.Pool @@ -40,16 +40,16 @@ type importJob struct { // imported as fields of the single top-level table, with a // scoped column name. 
// - // TODO: flatten come from src.Options + // TODO: flatten should come from src.Options flatten bool } -type importFunc func(ctx context.Context, job importJob) error +type ingestFunc func(ctx context.Context, job ingestJob) error var ( - _ importFunc = importJSON - _ importFunc = importJSONA - _ importFunc = importJSONL + _ ingestFunc = ingestJSON + _ ingestFunc = ingestJSONA + _ ingestFunc = ingestJSONL ) // getRecMeta returns record.Meta to use with RecordWriter.Open. diff --git a/drivers/json/import_json.go b/drivers/json/ingest_json.go similarity index 99% rename from drivers/json/import_json.go rename to drivers/json/ingest_json.go index beb9e5a45..bca1bb7fc 100644 --- a/drivers/json/import_json.go +++ b/drivers/json/ingest_json.go @@ -132,7 +132,7 @@ func DetectJSON(sampleSize int) source.DriverDetectFunc { } } -func importJSON(ctx context.Context, job importJob) error { +func ingestJSON(ctx context.Context, job ingestJob) error { log := lg.FromContext(ctx) r, err := job.openFn(ctx) diff --git a/drivers/json/import_jsona.go b/drivers/json/ingest_jsona.go similarity index 99% rename from drivers/json/import_jsona.go rename to drivers/json/ingest_jsona.go index 01ae0401f..0c21f3216 100644 --- a/drivers/json/import_jsona.go +++ b/drivers/json/ingest_jsona.go @@ -95,7 +95,7 @@ func DetectJSONA(sampleSize int) source.DriverDetectFunc { } } -func importJSONA(ctx context.Context, job importJob) error { +func ingestJSONA(ctx context.Context, job ingestJob) error { log := lg.FromContext(ctx) predictR, err := job.openFn(ctx) diff --git a/drivers/json/import_jsonl.go b/drivers/json/ingest_jsonl.go similarity index 98% rename from drivers/json/import_jsonl.go rename to drivers/json/ingest_jsonl.go index b029de1d7..07484e02c 100644 --- a/drivers/json/import_jsonl.go +++ b/drivers/json/ingest_jsonl.go @@ -83,7 +83,7 @@ func DetectJSONL(sampleSize int) source.DriverDetectFunc { // DetectJSONL implements source.DriverDetectFunc. 
-func importJSONL(ctx context.Context, job importJob) error { //nolint:gocognit +func ingestJSONL(ctx context.Context, job ingestJob) error { //nolint:gocognit log := lg.FromContext(ctx) r, err := job.openFn(ctx) diff --git a/drivers/json/import_test.go b/drivers/json/ingest_test.go similarity index 100% rename from drivers/json/import_test.go rename to drivers/json/ingest_test.go diff --git a/drivers/json/internal_test.go b/drivers/json/internal_test.go index 71b028e25..c95a64835 100644 --- a/drivers/json/internal_test.go +++ b/drivers/json/internal_test.go @@ -17,23 +17,23 @@ import ( // export for testing. var ( - ImportJSON = importJSON - ImportJSONA = importJSONA - ImportJSONL = importJSONL + ImportJSON = ingestJSON + ImportJSONA = ingestJSONA + ImportJSONL = ingestJSONL ColumnOrderFlat = columnOrderFlat NewImportJob = newImportJob ) -// newImportJob is a constructor for the unexported importJob type. +// newImportJob is a constructor for the unexported ingestJob type. // If sampleSize <= 0, a default value is used. func newImportJob(fromSrc *source.Source, openFn source.FileOpenFunc, destPool driver.Pool, sampleSize int, flatten bool, -) importJob { +) ingestJob { if sampleSize <= 0 { sampleSize = driver.OptIngestSampleSize.Get(fromSrc.Options) } - return importJob{ + return ingestJob{ fromSrc: fromSrc, openFn: openFn, destPool: destPool, diff --git a/drivers/json/json.go b/drivers/json/json.go index 22c03e53a..0f111d54b 100644 --- a/drivers/json/json.go +++ b/drivers/json/json.go @@ -8,6 +8,7 @@ package json import ( "context" "database/sql" + "github.com/neilotoole/sq/libsq/core/options" "log/slog" "github.com/neilotoole/sq/libsq/core/cleanup" @@ -41,33 +42,31 @@ type Provider struct { // DriverFor implements driver.Provider. 
func (d *Provider) DriverFor(typ drivertype.Type) (driver.Driver, error) { - var importFn importFunc + var ingestFn ingestFunc switch typ { //nolint:exhaustive case TypeJSON: - importFn = importJSON + ingestFn = ingestJSON case TypeJSONA: - importFn = importJSONA + ingestFn = ingestJSONA case TypeJSONL: - importFn = importJSONL + ingestFn = ingestJSONL default: return nil, errz.Errorf("unsupported driver type {%s}", typ) } return &driveri{ - log: d.Log, typ: typ, scratcher: d.Scratcher, files: d.Files, - importFn: importFn, + ingestFn: ingestFn, }, nil } // Driver implements driver.Driver. type driveri struct { - log *slog.Logger typ drivertype.Type - importFn importFunc + ingestFn ingestFunc scratcher driver.ScratchPoolOpener files *source.Files } @@ -93,41 +92,72 @@ func (d *driveri) DriverMetadata() driver.Metadata { // Open implements driver.PoolOpener. func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, error) { - lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) - - p := &pool{log: d.log, src: src, clnup: cleanup.New(), files: d.files} - - r, err := d.files.Open(ctx, src) - if err != nil { - return nil, err - } - - p.impl, err = d.scratcher.OpenScratchFor(ctx, src) - if err != nil { - lg.WarnIfCloseError(d.log, lgm.CloseFileReader, r) - lg.WarnIfFuncError(d.log, lgm.CloseDB, p.clnup.Run) - return nil, err + log := lg.FromContext(ctx) + log.Debug(lgm.OpenSrc, lga.Src, src) + + p := &pool{ + log: log, + src: src, + clnup: cleanup.New(), + files: d.files, } - job := importJob{ - fromSrc: src, - openFn: d.files.OpenFunc(src), - destPool: p.impl, - sampleSize: driver.OptIngestSampleSize.Get(src.Options), - flatten: true, + allowCache := driver.OptIngestCache.Get(options.FromContext(ctx)) + + //r, err := d.files.Open(ctx, src) + //if err != nil { + // return nil, err + //} + // + //p.impl, err = d.scratcher.OpenScratchFor(ctx, src) + //if err != nil { + // lg.WarnIfCloseError(log, lgm.CloseFileReader, r) + // lg.WarnIfFuncError(log, 
lgm.CloseDB, p.clnup.Run) + // return nil, err + //} + + //job := ingestJob{ + // fromSrc: src, + // openFn: d.files.OpenFunc(src), + // destPool: p.impl, + // sampleSize: driver.OptIngestSampleSize.Get(src.Options), + // flatten: true, + //} + + ingestFn := func(ctx context.Context, destPool driver.Pool) error { + job := ingestJob{ + fromSrc: src, + openFn: d.files.OpenFunc(src), + destPool: destPool, + sampleSize: driver.OptIngestSampleSize.Get(src.Options), + flatten: true, + } + + return d.ingestFn(ctx, job) + + //openFn := d.files.OpenFunc(src) + //log.Debug("Ingest func invoked", lga.Src, src) + //return ingestCSV(ctx, src, openFn, destPool) } - err = d.importFn(ctx, job) - if err != nil { - lg.WarnIfCloseError(d.log, lgm.CloseFileReader, r) - lg.WarnIfFuncError(d.log, lgm.CloseDB, p.clnup.Run) + var err error + if p.impl, err = d.scratcher.OpenIngest(ctx, src, ingestFn, allowCache); err != nil { return nil, err } - err = r.Close() - if err != nil { - return nil, err - } + return p, nil + // + //err = d.importFn(ctx, job) + //if err != nil { + // lg.WarnIfCloseError(log, lgm.CloseFileReader, r) + // lg.WarnIfFuncError(log, lgm.CloseDB, p.clnup.Run) + // return nil, err + //} + // + //err = r.Close() + //if err != nil { + // return nil, err + //} return p, nil } @@ -148,13 +178,15 @@ func (d *driveri) ValidateSource(src *source.Source) (*source.Source, error) { // Ping implements driver.Driver. 
func (d *driveri) Ping(ctx context.Context, src *source.Source) error { - d.log.Debug("Ping source", lga.Src, src) + log := lg.FromContext(ctx).With(lga.Src, src) + log.Debug("Ping source") + // FIXME: this should call d.files.Ping r, err := d.files.Open(ctx, src) if err != nil { return err } - defer lg.WarnIfCloseError(d.log, lgm.CloseFileReader, r) + defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) return nil } diff --git a/libsq/source/files.go b/libsq/source/files.go index ade8e010e..39f8ae037 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -135,6 +135,10 @@ func (fs *Files) Size(ctx context.Context, src *Source) (size int64, err error) // is later accessible via fs.Open(src) where src.Handle // is StdinHandle; f's type can be detected via DetectStdinType. // Note that f is closed by this method. +// +// FIXME: AddStdin is probably not necessary, we can just do it +// on the fly in newReader? Or do we provide this because "stdin" +// can be something other than os.Stdin, e.g. via a flag? func (fs *Files) AddStdin(ctx context.Context, f *os.File) error { fs.mu.Lock() defer fs.mu.Unlock() @@ -290,6 +294,49 @@ func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, erro log.Debug("Files.newReader", lga.Loc, loc) locTyp := getLocType(loc) + switch locTyp { + case locTypeUnknown: + return nil, errz.Errorf("unknown source location type: %s", loc) + case locTypeSQL: + return nil, errz.Errorf("cannot read SQL source: %s", loc) + case locTypeStdin: + r, w, err := fs.fcache.Get(StdinHandle) + if err != nil { + return nil, errz.Err(err) + } + if w != nil { + return nil, errz.New("@stdin not cached: has AddStdin been invoked yet?") + } + + return r, nil + } + + // Well, it's either a local or remote file. + // Let's see if it's cached. + if fs.fcache.Exists(loc) { + r, _, err := fs.fcache.Get(loc) + if err != nil { + return nil, err + } + + return r, nil + } + + // It's not cached. 
+ if locTyp == locTypeLocalFile { + f, err := os.Open(loc) + if err != nil { + return nil, errz.Err(err) + } + // fs.addFile closes f, so we don't have to do it. + r, err := fs.addFile(ctx, f, loc) + if err != nil { + return nil, err + } + return r, nil + } + + // It's an uncached remote file. if loc == StdinHandle { r, w, err := fs.fcache.Get(StdinHandle) From 7ac88cd1fbd21697ac6cbfc29b0fe21efb2100dd Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 2 Dec 2023 18:45:26 -0700 Subject: [PATCH 034/195] xlsx import broken: considering getting rid of deferred ingest --- drivers/xlsx/{database.go => pool.go} | 56 +++++++++++++++++++-------- drivers/xlsx/xlsx.go | 25 ++++++------ 2 files changed, 52 insertions(+), 29 deletions(-) rename drivers/xlsx/{database.go => pool.go} (73%) diff --git a/drivers/xlsx/database.go b/drivers/xlsx/pool.go similarity index 73% rename from drivers/xlsx/database.go rename to drivers/xlsx/pool.go index 7d2d12673..f05727b20 100644 --- a/drivers/xlsx/database.go +++ b/drivers/xlsx/pool.go @@ -26,6 +26,7 @@ type pool struct { src *source.Source files *source.Files + scratcher driver.ScratchPoolOpener scratchPool driver.Pool clnup *cleanup.Cleanup @@ -36,6 +37,13 @@ type pool struct { // ingestSheetNames is the list of sheet names to ingest. When empty, // all sheets should be ingested. The key use of ingestSheetNames // is with pool.TableMetadata, so that only the relevant table is ingested. + // + // FIXME: Verify how ingestSheetNames interacts with the deferred + // ingest and caching mechanisms. E.g. if we ingest a single sheet, + // and then later ingest the entire workbook, will the cache DB be + // accurate? + // ACTUALLY: We can get rid of this entirely, because with caching, + // there's no longer really a problem. 
ingestSheetNames []string } @@ -56,25 +64,33 @@ func (p *pool) doIngest(ctx context.Context, includeSheetNames []string) error { // the context being passed down the stack (in particular to ingestXLSX) // has the source's options on it. ctx = options.NewContext(ctx, options.Merge(options.FromContext(ctx), p.src.Options)) - - r, err := p.files.Open(ctx, p.src) - if err != nil { - return err + allowCache := driver.OptIngestCache.Get(options.FromContext(ctx)) + + ingestFn := func(ctx context.Context, destPool driver.Pool) error { + log.Debug("Ingest XLSX", lga.Src, p.src) + //openFn := p.files.OpenFunc(p.src) + r, err := p.files.Open(ctx, p.src) + if err != nil { + return err + } + defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) + + xfile, err := excelize.OpenReader(r, excelize.Options{RawCellValue: false}) + if err != nil { + return err + } + + defer lg.WarnIfCloseError(log, lgm.CloseFileReader, xfile) + + if err = ingestXLSX(ctx, p.src, destPool, xfile, includeSheetNames); err != nil { + lg.WarnIfError(log, lgm.CloseDB, p.clnup.Run()) + return err + } + return nil } - defer lg.WarnIfCloseError(p.log, lgm.CloseFileReader, r) - xfile, err := excelize.OpenReader(r, excelize.Options{RawCellValue: false}) - if err != nil { - return err - } - - defer lg.WarnIfCloseError(log, lgm.CloseFileReader, xfile) - - err = ingestXLSX(ctx, p.src, p.scratchPool, xfile, includeSheetNames) - if err != nil { - lg.WarnIfError(p.log, lgm.CloseDB, p.clnup.Run()) - return err - } + var err error + p.scratchPool, err = p.scratcher.OpenIngest(ctx, p.src, ingestFn, allowCache) return err } @@ -92,6 +108,12 @@ func (p *pool) DB(ctx context.Context) (*sql.DB, error) { // SQLDriver implements driver.Pool. 
func (p *pool) SQLDriver() driver.SQLDriver { + p.mu.Lock() + defer p.mu.Unlock() + + if err := p.checkIngest(ctx); err != nil { + return nil, err + } return p.scratchPool.SQLDriver() } diff --git a/drivers/xlsx/xlsx.go b/drivers/xlsx/xlsx.go index da74b6903..e3999aaff 100644 --- a/drivers/xlsx/xlsx.go +++ b/drivers/xlsx/xlsx.go @@ -61,20 +61,21 @@ func (d *Driver) DriverMetadata() driver.Metadata { func (d *Driver) Open(ctx context.Context, src *source.Source) (driver.Pool, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) - scratchPool, err := d.scratcher.OpenScratchFor(ctx, src) - if err != nil { - return nil, err - } - - clnup := cleanup.New() - clnup.AddE(scratchPool.Close) + //scratchPool, err := d.scratcher.OpenScratchFor(ctx, src) + //if err != nil { + // return nil, err + //} + // + //clnup := cleanup.New() + //clnup.AddE(scratchPool.Close) p := &pool{ - log: d.log, - src: src, - scratchPool: scratchPool, - files: d.files, - clnup: clnup, + log: d.log, + scratcher: d.scratcher, + src: src, + //scratchPool: scratchPool, + files: d.files, + clnup: cleanup.New(), } return p, nil From f2a137a4b41999e67cbb4e96f5f8e5f01ec6e040 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 07:03:25 -0700 Subject: [PATCH 035/195] progress working, except for renderDelay bug --- cli/output.go | 7 +- drivers/xlsx/ingest.go | 3 +- drivers/xlsx/pool.go | 115 ++------------------------------ drivers/xlsx/xlsx.go | 48 ++++++++----- libsq/core/progress/progress.go | 85 +++++++++++++++++------ libsq/driver/sources.go | 2 + testh/testh.go | 6 ++ 7 files changed, 121 insertions(+), 145 deletions(-) diff --git a/cli/output.go b/cli/output.go index 7d1470361..8eef653b0 100644 --- a/cli/output.go +++ b/cli/output.go @@ -2,6 +2,7 @@ package cli import ( "fmt" + "github.com/neilotoole/sq/libsq/core/lg" "io" "os" "strings" @@ -463,7 +464,11 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer prog := progress.New(ctx, errOut2, 
renderDelay, progColors) // On first write to stdout, we remove the progress widget. - out2 = ioz.NotifyOnceWriter(out2, prog.Wait) + //out2 = ioz.NotifyOnceWriter(out2, prog.Wait) + out2 = ioz.NotifyOnceWriter(out2, func() { + lg.FromContext(ctx).Debug("Notify once invoked") + prog.Wait() + }) cmd.SetContext(progress.NewContext(ctx, prog)) } diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index 1e60acebf..60f027025 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -12,7 +12,6 @@ import ( excelize "github.com/xuri/excelize/v2" "golang.org/x/sync/errgroup" - "github.com/neilotoole/sq/libsq" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/kind" "github.com/neilotoole/sq/libsq/core/lg" @@ -192,7 +191,7 @@ func ingestSheetToTable(ctx context.Context, scratchPool driver.Pool, sheetTbl * batchSize := driver.MaxBatchRows(drvr, len(destColKinds)) bi, err := driver.NewBatchInsert( ctx, - libsq.MsgIngestRecords, + "Ingest "+sheet.name, drvr, conn, tblDef.Name, diff --git a/drivers/xlsx/pool.go b/drivers/xlsx/pool.go index f05727b20..86f14c7b1 100644 --- a/drivers/xlsx/pool.go +++ b/drivers/xlsx/pool.go @@ -3,19 +3,12 @@ package xlsx import ( "context" "database/sql" - "log/slog" - "sync" - - excelize "github.com/xuri/excelize/v2" - - "github.com/neilotoole/sq/libsq/core/cleanup" - "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" - "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/metadata" + "log/slog" ) // pool implements driver.Pool. 
It implements a deferred ingest @@ -26,95 +19,18 @@ type pool struct { src *source.Source files *source.Files - scratcher driver.ScratchPoolOpener - scratchPool driver.Pool - clnup *cleanup.Cleanup - - mu sync.Mutex - ingestOnce sync.Once - ingestErr error - - // ingestSheetNames is the list of sheet names to ingest. When empty, - // all sheets should be ingested. The key use of ingestSheetNames - // is with pool.TableMetadata, so that only the relevant table is ingested. - // - // FIXME: Verify how ingestSheetNames interacts with the deferred - // ingest and caching mechanisms. E.g. if we ingest a single sheet, - // and then later ingest the entire workbook, will the cache DB be - // accurate? - // ACTUALLY: We can get rid of this entirely, because with caching, - // there's no longer really a problem. - ingestSheetNames []string -} - -// checkIngest performs data ingestion if not already done. -func (p *pool) checkIngest(ctx context.Context) error { - p.ingestOnce.Do(func() { - p.ingestErr = p.doIngest(ctx, p.ingestSheetNames) - }) - - return p.ingestErr -} - -// doIngest performs data ingest. It must only be invoked from checkIngest. -func (p *pool) doIngest(ctx context.Context, includeSheetNames []string) error { - log := lg.FromContext(ctx) - - // Because of the deferred ingest mechanism, we need to ensure that - // the context being passed down the stack (in particular to ingestXLSX) - // has the source's options on it. 
- ctx = options.NewContext(ctx, options.Merge(options.FromContext(ctx), p.src.Options)) - allowCache := driver.OptIngestCache.Get(options.FromContext(ctx)) - - ingestFn := func(ctx context.Context, destPool driver.Pool) error { - log.Debug("Ingest XLSX", lga.Src, p.src) - //openFn := p.files.OpenFunc(p.src) - r, err := p.files.Open(ctx, p.src) - if err != nil { - return err - } - defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) - - xfile, err := excelize.OpenReader(r, excelize.Options{RawCellValue: false}) - if err != nil { - return err - } - - defer lg.WarnIfCloseError(log, lgm.CloseFileReader, xfile) - - if err = ingestXLSX(ctx, p.src, destPool, xfile, includeSheetNames); err != nil { - lg.WarnIfError(log, lgm.CloseDB, p.clnup.Run()) - return err - } - return nil - } - - var err error - p.scratchPool, err = p.scratcher.OpenIngest(ctx, p.src, ingestFn, allowCache) - return err + backingPool driver.Pool } // DB implements driver.Pool. func (p *pool) DB(ctx context.Context) (*sql.DB, error) { - p.mu.Lock() - defer p.mu.Unlock() - - if err := p.checkIngest(ctx); err != nil { - return nil, err - } - return p.scratchPool.DB(ctx) + return p.backingPool.DB(ctx) } // SQLDriver implements driver.Pool. func (p *pool) SQLDriver() driver.SQLDriver { - p.mu.Lock() - defer p.mu.Unlock() - - if err := p.checkIngest(ctx); err != nil { - return nil, err - } - return p.scratchPool.SQLDriver() + return p.backingPool.SQLDriver() } // Source implements driver.Pool. @@ -124,14 +40,7 @@ func (p *pool) Source() *source.Source { // SourceMetadata implements driver.Pool. 
func (p *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { - p.mu.Lock() - defer p.mu.Unlock() - - if err := p.checkIngest(ctx); err != nil { - return nil, err - } - - md, err := p.scratchPool.SourceMetadata(ctx, noSchema) + md, err := p.backingPool.SourceMetadata(ctx, noSchema) if err != nil { return nil, err } @@ -153,22 +62,12 @@ func (p *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou // TableMetadata implements driver.Pool. func (p *pool) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { - p.mu.Lock() - defer p.mu.Unlock() - - p.ingestSheetNames = []string{tblName} - if err := p.checkIngest(ctx); err != nil { - return nil, err - } - - return p.scratchPool.TableMetadata(ctx, tblName) + return p.backingPool.TableMetadata(ctx, tblName) } // Close implements driver.Pool. func (p *pool) Close() error { p.log.Debug(lgm.CloseDB, lga.Handle, p.src.Handle) - // No need to explicitly invoke c.scratchPool.Close because - // that's already added to c.clnup - return p.clnup.Run() + return p.backingPool.Close() } diff --git a/drivers/xlsx/xlsx.go b/drivers/xlsx/xlsx.go index e3999aaff..9cd862ec5 100644 --- a/drivers/xlsx/xlsx.go +++ b/drivers/xlsx/xlsx.go @@ -3,11 +3,11 @@ package xlsx import ( "context" + "github.com/neilotoole/sq/libsq/core/options" "log/slog" excelize "github.com/xuri/excelize/v2" - "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" @@ -59,23 +59,41 @@ func (d *Driver) DriverMetadata() driver.Metadata { // Open implements driver.PoolOpener. 
func (d *Driver) Open(ctx context.Context, src *source.Source) (driver.Pool, error) { - lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) - - //scratchPool, err := d.scratcher.OpenScratchFor(ctx, src) - //if err != nil { - // return nil, err - //} - // - //clnup := cleanup.New() - //clnup.AddE(scratchPool.Close) + log := lg.FromContext(ctx).With(lga.Src, src) + log.Debug(lgm.OpenSrc, lga.Src, src) p := &pool{ - log: d.log, - scratcher: d.scratcher, - src: src, - //scratchPool: scratchPool, + log: log, + src: src, files: d.files, - clnup: cleanup.New(), + } + + allowCache := driver.OptIngestCache.Get(options.FromContext(ctx)) + + ingestFn := func(ctx context.Context, destPool driver.Pool) error { + log.Debug("Ingest XLSX", lga.Src, p.src) + r, err := p.files.Open(ctx, p.src) + if err != nil { + return err + } + defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) + + xfile, err := excelize.OpenReader(r, excelize.Options{RawCellValue: false}) + if err != nil { + return err + } + + defer lg.WarnIfCloseError(log, lgm.CloseFileReader, xfile) + + if err = ingestXLSX(ctx, p.src, destPool, xfile, nil); err != nil { + return err + } + return nil + } + + var err error + if p.backingPool, err = d.scratcher.OpenIngest(ctx, p.src, ingestFn, allowCache); err != nil { + return nil, err } return p, nil diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 20d9187c4..71125d08b 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -2,7 +2,9 @@ package progress import ( "context" + "github.com/neilotoole/sq/libsq/core/lg" "io" + "os" "sync" "time" @@ -11,8 +13,6 @@ import ( "github.com/fatih/color" mpb "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" - - "github.com/neilotoole/sq/libsq/core/cleanup" ) type runKey struct{} @@ -60,7 +60,11 @@ const ( // The Progress is lazily initialized, and thus the delay clock doesn't // start ticking until the first call to one of the Progress.NewX 
methods. func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors) *Progress { + lg.FromContext(ctx).Error("New progress", "delay", delay) + var cancelFn context.CancelFunc + ogCtx := ctx + _ = ogCtx ctx, cancelFn = context.WithCancel(ctx) if colors == nil { @@ -68,23 +72,33 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors } p := &Progress{ - ctx: ctx, - mu: sync.Mutex{}, - colors: colors, - cleanup: cleanup.New(), + ctx: ctx, + mu: sync.Mutex{}, + colors: colors, + //cleanup: cleanup.New(), cancelFn: cancelFn, + bars: make([]*Bar, 0), } p.pcInit = func() { opts := []mpb.ContainerOption{ + mpb.WithDebugOutput(os.Stdout), mpb.WithOutput(out), mpb.WithWidth(boxWidth), - mpb.WithRefreshRate(refreshRate), - mpb.WithAutoRefresh(), // Needed for color in Windows, apparently + //mpb.WithRefreshRate(refreshRate), + //mpb.WithAutoRefresh(), // Needed for color in Windows, apparently } if delay > 0 { - opts = append(opts, mpb.WithRenderDelay(renderDelay(ctx, delay))) + delayCh := renderDelay(ctx, delay) + opts = append(opts, mpb.WithRenderDelay(delayCh)) + p.delayCh = delayCh + } else { + delayCh := make(chan struct{}) + close(delayCh) + p.delayCh = delayCh } + lg.FromContext(ctx).Debug("Render delay", "delay", delay) + p.pc = mpb.NewWithContext(ctx, opts...) p.pcInit = nil } @@ -108,8 +122,13 @@ type Progress struct { // pcInit is the func that lazily initializes pc. pcInit func() - colors *Colors - cleanup *cleanup.Cleanup + // delayCh controls the rendering delay: rendering can + // start as soon as delayCh is closed. + delayCh <-chan struct{} + + colors *Colors + //cleanup *cleanup.Cleanup + bars []*Bar cancelFn context.CancelFunc } @@ -129,13 +148,20 @@ func (p *Progress) Wait() { return } - if p.cleanup.Len() == 0 { + if len(p.bars) == 0 { return } p.cancelFn() - // Invoking cleanup will call Bar.Stop on all the bars. 
- _ = p.cleanup.Run() + + for _, bar := range p.bars { + bar.bar.Abort(true) + } + + //for _, bar := range p.bars { + // bar.bar.Wait() + //} + p.pc.Wait() } @@ -221,6 +247,8 @@ func (p *Progress) newBar(msg string, total int64, return nil } + lg.FromContext(p.ctx).Debug("New bar", "msg", msg, "total", total) + select { case <-p.ctx.Done(): return nil @@ -245,11 +273,27 @@ func (p *Progress) newBar(msg string, total int64, mpb.BarRemoveOnComplete(), ) - b := &Bar{bar: bar} - p.cleanup.Add(b.Stop) + b := &Bar{p: p, bar: bar} + p.bars = append(p.bars, b) return b } +func (p *Progress) barStopped(b *Bar) { + if p == nil { + return + } + + p.mu.Lock() + defer p.mu.Unlock() + + for i, bar := range p.bars { + if bar == b { + p.bars = append(p.bars[:i], p.bars[i+1:]...) + return + } + } +} + func spinnerStyle(c *color.Color) mpb.SpinnerStyleComposer { // TODO: should use ascii chars? frames := []string{"∙∙∙", "●∙∙", "●∙∙", "∙●∙", "∙●∙", "∙∙●", "∙∙●", "∙∙∙"} @@ -284,12 +328,14 @@ func barStyle(c *color.Color) mpb.BarStyleComposer { // the bar is complete, the caller should invoke [Bar.Stop]. All // methods are safe to call on a nil Bar. type Bar struct { + p *Progress bar *mpb.Bar } // IncrBy increments progress by amount of n. It is safe to // call IncrBy on a nil Bar. func (b *Bar) IncrBy(n int) { + //time.Sleep(time.Millisecond * 10) if b == nil { return } @@ -302,24 +348,25 @@ func (b *Bar) Stop() { if b == nil { return } - b.bar.SetTotal(-1, true) b.bar.Abort(true) b.bar.Wait() + b.p.barStopped(b) } // renderDelay returns a channel that will be closed after d, // or if ctx is done. 
func renderDelay(ctx context.Context, d time.Duration) <-chan struct{} { ch := make(chan struct{}) - + t := time.NewTimer(d) go func() { defer close(ch) - t := time.NewTimer(d) defer t.Stop() select { case <-ctx.Done(): + lg.FromContext(ctx).Debug("Render delay via ctx.Done") case <-t.C: + lg.FromContext(ctx).Debug("Render delay via timer") } }() return ch diff --git a/libsq/driver/sources.go b/libsq/driver/sources.go index eab56f302..d4054caec 100644 --- a/libsq/driver/sources.go +++ b/libsq/driver/sources.go @@ -188,6 +188,7 @@ func (d *Pools) openIngestNoCache(ctx context.Context, src *source.Source, lga.Src, src, lga.Dest, impl.Source(), lga.Elapsed, elapsed, lga.Err, err, ) + lg.WarnIfCloseError(log, lgm.CloseDB, impl) } d.log.Debug("Ingest completed", @@ -256,6 +257,7 @@ func (d *Pools) openIngestCache(ctx context.Context, src *source.Source, lga.Src, src, lga.Dest, impl.Source(), lga.Elapsed, elapsed, lga.Err, err, ) + lg.WarnIfCloseError(log, lgm.CloseDB, impl) return nil, err } diff --git a/testh/testh.go b/testh/testh.go index 4522708f9..3805edd8a 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -134,6 +134,12 @@ func New(t testing.TB, opts ...Option) *Helper { h.Context = lg.NewContext(ctx, h.Log) + // Disable caching in tests, because there's all sorts of confounding + // situations with running tests in parallel with caching enabled, + // due to the fact that caching uses pid-based locking, and parallel tests + // share the same pid. 
+ o := options.Options{driver.OptIngestCache.Key(): false} + h.Context = options.NewContext(h.Context, o) t.Cleanup(h.Close) return h } From 9e8a1d754b7489a9ab8da53a7f6d85fb1d484136 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 07:40:22 -0700 Subject: [PATCH 036/195] Linting --- .golangci.yml | 3 ++ cli/logging.go | 4 +-- cli/output.go | 4 +-- drivers/csv/detect_type.go | 2 +- drivers/html/doc.go | 11 ------- drivers/json/json.go | 40 +------------------------- drivers/userdriver/userdriver.go | 48 ++++++++++++++----------------- drivers/xlsx/detect.go | 6 ++-- drivers/xlsx/pool.go | 4 +-- drivers/xlsx/xlsx.go | 2 +- libsq/core/errz/errz.go | 5 ++++ libsq/core/ioz/ioz.go | 45 ++--------------------------- libsq/core/lg/devlog/devlog.go | 3 +- libsq/core/progress/progress.go | 19 ++++-------- libsq/core/progress/progressio.go | 1 - libsq/source/cache.go | 7 +++-- libsq/source/files.go | 15 +--------- libsq/source/source.go | 15 +++++++++- testh/testh.go | 1 + 19 files changed, 71 insertions(+), 164 deletions(-) delete mode 100644 drivers/html/doc.go diff --git a/.golangci.yml b/.golangci.yml index d6a694648..c321098ff 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -30,6 +30,9 @@ run: # This package is such a mess, and needs to be rewritten completely. - cli/output/tablew/internal + # The tint package is a fork of an upstream package. 
+ - libsq/core/lg/devlog/tint + # Non-committed scratch dir - scratch diff --git a/cli/logging.go b/cli/logging.go index 35e3221c1..02099a93b 100644 --- a/cli/logging.go +++ b/cli/logging.go @@ -2,7 +2,6 @@ package cli import ( "context" - "github.com/neilotoole/sq/libsq/core/lg/devlog" "io" "log/slog" "os" @@ -16,6 +15,7 @@ import ( "github.com/neilotoole/sq/cli/flag" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/devlog" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/userlogdir" "github.com/neilotoole/sq/libsq/core/options" @@ -93,7 +93,7 @@ func defaultLogging(ctx context.Context, osArgs []string, cfg *config.Config, closer = logFile.Close h = devlog.NewHandler(logFile, lvl) - //h = newJSONHandler(logFile, lvl) + // h = newJSONHandler(logFile, lvl) return slog.New(h), h, closer, nil } diff --git a/cli/output.go b/cli/output.go index 8eef653b0..f12abba94 100644 --- a/cli/output.go +++ b/cli/output.go @@ -2,7 +2,6 @@ package cli import ( "fmt" - "github.com/neilotoole/sq/libsq/core/lg" "io" "os" "strings" @@ -27,6 +26,7 @@ import ( "github.com/neilotoole/sq/cli/output/yamlw" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/core/progress" @@ -464,7 +464,7 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer prog := progress.New(ctx, errOut2, renderDelay, progColors) // On first write to stdout, we remove the progress widget. 
- //out2 = ioz.NotifyOnceWriter(out2, prog.Wait) + // out2 = ioz.NotifyOnceWriter(out2, prog.Wait) out2 = ioz.NotifyOnceWriter(out2, func() { lg.FromContext(ctx).Debug("Notify once invoked") prog.Wait() diff --git a/drivers/csv/detect_type.go b/drivers/csv/detect_type.go index 7e9db5171..a569a2f9a 100644 --- a/drivers/csv/detect_type.go +++ b/drivers/csv/detect_type.go @@ -4,13 +4,13 @@ import ( "context" "encoding/csv" "errors" - "github.com/neilotoole/sq/libsq/core/lg/lga" "io" "time" "github.com/neilotoole/sq/cli/output/csvw" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" diff --git a/drivers/html/doc.go b/drivers/html/doc.go deleted file mode 100644 index c05a820b8..000000000 --- a/drivers/html/doc.go +++ /dev/null @@ -1,11 +0,0 @@ -// Package html is the future home of the HTML table import driver. -// -// BRAINDUMP: -// A particular use case is this: -// In your browser, select a table, and copy that HTML. -// Then (on macOS): -// -// > pbpaste | sq .data --json -// -// Should output that HTML table as JSON, etc. 
-package html diff --git a/drivers/json/json.go b/drivers/json/json.go index 0f111d54b..497236d04 100644 --- a/drivers/json/json.go +++ b/drivers/json/json.go @@ -8,7 +8,6 @@ package json import ( "context" "database/sql" - "github.com/neilotoole/sq/libsq/core/options" "log/slog" "github.com/neilotoole/sq/libsq/core/cleanup" @@ -16,6 +15,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" @@ -104,26 +104,6 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er allowCache := driver.OptIngestCache.Get(options.FromContext(ctx)) - //r, err := d.files.Open(ctx, src) - //if err != nil { - // return nil, err - //} - // - //p.impl, err = d.scratcher.OpenScratchFor(ctx, src) - //if err != nil { - // lg.WarnIfCloseError(log, lgm.CloseFileReader, r) - // lg.WarnIfFuncError(log, lgm.CloseDB, p.clnup.Run) - // return nil, err - //} - - //job := ingestJob{ - // fromSrc: src, - // openFn: d.files.OpenFunc(src), - // destPool: p.impl, - // sampleSize: driver.OptIngestSampleSize.Get(src.Options), - // flatten: true, - //} - ingestFn := func(ctx context.Context, destPool driver.Pool) error { job := ingestJob{ fromSrc: src, @@ -134,10 +114,6 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er } return d.ingestFn(ctx, job) - - //openFn := d.files.OpenFunc(src) - //log.Debug("Ingest func invoked", lga.Src, src) - //return ingestCSV(ctx, src, openFn, destPool) } var err error @@ -145,20 +121,6 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er return nil, err } - return p, nil - // - //err = d.importFn(ctx, job) - //if err != nil { - // lg.WarnIfCloseError(log, lgm.CloseFileReader, r) - // 
lg.WarnIfFuncError(log, lgm.CloseDB, p.clnup.Run) - // return nil, err - //} - // - //err = r.Close() - //if err != nil { - // return nil, err - //} - return p, nil } diff --git a/drivers/userdriver/userdriver.go b/drivers/userdriver/userdriver.go index e9c5cafb6..cc22c6d1c 100644 --- a/drivers/userdriver/userdriver.go +++ b/drivers/userdriver/userdriver.go @@ -11,11 +11,11 @@ import ( "io" "log/slog" - "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" @@ -47,7 +47,7 @@ func (p *Provider) DriverFor(typ drivertype.Type) (driver.Driver, error) { typ: typ, def: p.DriverDef, scratcher: p.Scratcher, - importFn: p.ImportFn, + ingestFn: p.ImportFn, files: p.Files, }, nil } @@ -67,7 +67,7 @@ type driveri struct { def *DriverDef files *source.Files scratcher driver.ScratchPoolOpener - importFn ImportFunc + ingestFn ImportFunc } // DriverMetadata implements driver.Driver. @@ -82,30 +82,30 @@ func (d *driveri) DriverMetadata() driver.Metadata { // Open implements driver.PoolOpener. 
func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, error) { - lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) + log := lg.FromContext(ctx).With(lga.Src, src) + log.Debug(lgm.OpenSrc) - clnup := cleanup.New() - - r, err := d.files.Open(ctx, src) - if err != nil { - return nil, err + p := &pool{ + log: d.log, + src: src, } - defer lg.WarnIfCloseError(d.log, lgm.CloseFileReader, r) + allowCache := driver.OptIngestCache.Get(options.FromContext(ctx)) - scratchDB, err := d.scratcher.OpenScratchFor(ctx, src) - if err != nil { - return nil, err + ingestFn := func(ctx context.Context, destPool driver.Pool) error { + r, err := d.files.Open(ctx, src) + if err != nil { + return err + } + defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) + return d.ingestFn(ctx, d.def, r, destPool) } - clnup.AddE(scratchDB.Close) - err = d.importFn(ctx, d.def, r, scratchDB) - if err != nil { - lg.WarnIfFuncError(d.log, lgm.CloseDB, clnup.Run) - return nil, errz.Wrap(err, d.def.Name) + var err error + if p.impl, err = d.scratcher.OpenIngest(ctx, src, ingestFn, allowCache); err != nil { + return nil, err } - - return &pool{log: d.log, src: src, impl: scratchDB, clnup: clnup}, nil + return p, nil } // Truncate implements driver.Driver. @@ -145,10 +145,6 @@ type pool struct { log *slog.Logger src *source.Source impl driver.Pool - - // clnup will ultimately invoke impl.Close to dispose of - // the scratch DB. - clnup *cleanup.Cleanup } // DB implements driver.Pool. @@ -193,7 +189,5 @@ func (d *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou func (d *pool) Close() error { d.log.Debug(lgm.CloseDB, lga.Handle, d.src.Handle) - // We don't need to explicitly invoke c.impl.Close - // because that's already been added to c.cleanup. 
- return d.clnup.Run() + return d.impl.Close() } diff --git a/drivers/xlsx/detect.go b/drivers/xlsx/detect.go index b9885bc9a..eef63b89a 100644 --- a/drivers/xlsx/detect.go +++ b/drivers/xlsx/detect.go @@ -3,11 +3,12 @@ package xlsx import ( "context" "errors" - "github.com/h2non/filetype" - "github.com/h2non/filetype/matchers" "io" "slices" + "github.com/h2non/filetype" + "github.com/h2non/filetype/matchers" + "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/kind" "github.com/neilotoole/sq/libsq/core/lg" @@ -50,7 +51,6 @@ func DetectXLSX(ctx context.Context, openFn source.FileOpenFunc) (detected drive default: return drivertype.None, 0, nil } - } func detectHeaderRow(ctx context.Context, sheet *xSheet) (hasHeader bool, err error) { diff --git a/drivers/xlsx/pool.go b/drivers/xlsx/pool.go index 86f14c7b1..b3f398089 100644 --- a/drivers/xlsx/pool.go +++ b/drivers/xlsx/pool.go @@ -3,12 +3,13 @@ package xlsx import ( "context" "database/sql" + "log/slog" + "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/metadata" - "log/slog" ) // pool implements driver.Pool. It implements a deferred ingest @@ -24,7 +25,6 @@ type pool struct { // DB implements driver.Pool. 
func (p *pool) DB(ctx context.Context) (*sql.DB, error) { - return p.backingPool.DB(ctx) } diff --git a/drivers/xlsx/xlsx.go b/drivers/xlsx/xlsx.go index 9cd862ec5..163ac29a4 100644 --- a/drivers/xlsx/xlsx.go +++ b/drivers/xlsx/xlsx.go @@ -3,7 +3,6 @@ package xlsx import ( "context" - "github.com/neilotoole/sq/libsq/core/options" "log/slog" excelize "github.com/xuri/excelize/v2" @@ -12,6 +11,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go index 1f610071c..b26f1088f 100644 --- a/libsq/core/errz/errz.go +++ b/libsq/core/errz/errz.go @@ -95,3 +95,8 @@ func IsErrContext(err error) bool { func IsErrContextDeadlineExceeded(err error) bool { return errors.Is(err, context.DeadlineExceeded) } + +// Tuple returns t and err, wrapping err with errz.Err. +func Tuple[T any](t T, err error) (T, error) { + return t, Err(err) +} diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 4dd13c63b..54dbc17a3 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -49,45 +49,6 @@ func CopyAsync(w io.Writer, r io.Reader, callback func(written int64, err error) }() } -// CopyAsyncFull asynchronously copies from r to w, invoking callback when done. -// If arg close is true and w is an io.WriterCloser, w is closed on successful -// completion of the copy (but it is not closed if an error occurs during write). -// If callback is nil, it is ignored. 
-func CopyAsyncFull(w io.Writer, r io.Reader, close bool, callback func(written int64, err error)) { - go func() { - written, err := io.Copy(w, r) - if err != nil { - if callback != nil { - callback(written, err) - } - return - } - // err is nil - if !close { - if callback != nil { - callback(written, nil) - } - return - } - - // err is nil, and close is true - wc, ok := w.(io.WriteCloser) - if !ok { - // It's not a write closer... this is basically a programming - // error to have set close to true when w is not an io.WriteCloser, - // but we'll handle it generously. - if callback != nil { - callback(written, nil) - } - return - } - - // err is nil, close is true, and w is a WriteCloser - err = wc.Close() - callback(written, err) - }() -} - // PrintFile reads file from name and writes it to stdout. func PrintFile(name string) error { return FPrintFile(os.Stdout, name) @@ -389,10 +350,8 @@ type WrittenWriter struct { // writer, blocking until writing concludes, either via invocation of // Close, or via an error in Write. func (w *WrittenWriter) Written() int64 { - select { - case <-w.done: - return w.c.Load() - } + <-w.done + return w.c.Load() } // Close implements io.WriteCloser. 
diff --git a/libsq/core/lg/devlog/devlog.go b/libsq/core/lg/devlog/devlog.go index bceac7f4e..810edea60 100644 --- a/libsq/core/lg/devlog/devlog.go +++ b/libsq/core/lg/devlog/devlog.go @@ -1,9 +1,10 @@ package devlog import ( - "github.com/neilotoole/sq/libsq/core/lg/devlog/tint" "io" "log/slog" + + "github.com/neilotoole/sq/libsq/core/lg/devlog/tint" ) const shortTimeFormat = `15:04:05.000000` diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 71125d08b..325a935c9 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -2,7 +2,6 @@ package progress import ( "context" - "github.com/neilotoole/sq/libsq/core/lg" "io" "os" "sync" @@ -13,6 +12,8 @@ import ( "github.com/fatih/color" mpb "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" + + "github.com/neilotoole/sq/libsq/core/lg" ) type runKey struct{} @@ -75,7 +76,7 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors ctx: ctx, mu: sync.Mutex{}, colors: colors, - //cleanup: cleanup.New(), + // cleanup: cleanup.New(), cancelFn: cancelFn, bars: make([]*Bar, 0), } @@ -85,8 +86,8 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors mpb.WithDebugOutput(os.Stdout), mpb.WithOutput(out), mpb.WithWidth(boxWidth), - //mpb.WithRefreshRate(refreshRate), - //mpb.WithAutoRefresh(), // Needed for color in Windows, apparently + // mpb.WithRefreshRate(refreshRate), + // mpb.WithAutoRefresh(), // Needed for color in Windows, apparently } if delay > 0 { delayCh := renderDelay(ctx, delay) @@ -127,7 +128,7 @@ type Progress struct { delayCh <-chan struct{} colors *Colors - //cleanup *cleanup.Cleanup + // cleanup *cleanup.Cleanup bars []*Bar cancelFn context.CancelFunc @@ -158,10 +159,6 @@ func (p *Progress) Wait() { bar.bar.Abort(true) } - //for _, bar := range p.bars { - // bar.bar.Wait() - //} - p.pc.Wait() } @@ -312,14 +309,11 @@ func barStyle(c *color.Color) 
mpb.BarStyleComposer { } frames := []string{"∙", "●", "●", "●", "∙"} - //frames := []string{"∙∙∙", "●∙∙", "●∙∙", "∙●∙", "∙●∙", "∙∙●", "∙∙●", "∙∙∙"} return mpb.BarStyle(). Lbound(" ").Rbound(" "). Filler("∙").FillerMeta(clr). - //Refiller("x").RefillerMeta(clr). Padding(" "). - //Tip(`-`, `\`, `|`, `/`).TipMeta(clr). Tip(frames...).TipMeta(clr) } @@ -335,7 +329,6 @@ type Bar struct { // IncrBy increments progress by amount of n. It is safe to // call IncrBy on a nil Bar. func (b *Bar) IncrBy(n int) { - //time.Sleep(time.Millisecond * 10) if b == nil { return } diff --git a/libsq/core/progress/progressio.go b/libsq/core/progress/progressio.go index 39cc65103..c21c278fd 100644 --- a/libsq/core/progress/progressio.go +++ b/libsq/core/progress/progressio.go @@ -233,7 +233,6 @@ func (w writerWrapper) Close() error { // Stop implements [Writer] and is no-op. func (w writerWrapper) Stop() { - return } var _ Writer = (*progCopier)(nil) diff --git a/libsq/source/cache.go b/libsq/source/cache.go index 3404616ce..f3dd4e190 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -1,10 +1,11 @@ package source import ( - "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/stringz" "os" "path/filepath" + + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/stringz" ) // CacheDirFor gets the cache dir for handle, creating it if necessary. @@ -15,7 +16,7 @@ func CacheDirFor(src *Source) (dir string, err error) { case "": // FIXME: This is surely an error? return "", errz.Errorf("open cache dir: empty handle") - //handle = "@cache_" + stringz.UniqN(32) + // handle = "@cache_" + stringz.UniqN(32) case StdinHandle: // stdin is different input every time, so we need a unique // cache dir. In practice, stdin probably isn't using this function. 
diff --git a/libsq/source/files.go b/libsq/source/files.go index 39f8ae037..0100bc5ba 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -372,7 +372,6 @@ func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, erro return nil, err } return r, nil - } // openLocation returns a file for loc. It is the caller's @@ -385,8 +384,7 @@ func (fs *Files) openLocation(ctx context.Context, loc string) (*os.File, error) fpath, ok = isFpath(loc) if ok { // we have a legitimate fpath - f, err := os.Open(fpath) - return f, errz.Err(err) + return errz.Tuple(os.Open(fpath)) } // It's not a local file path, maybe it's remote (http) var u *url.URL @@ -406,17 +404,6 @@ func (fs *Files) openLocation(ctx context.Context, loc string) (*os.File, error) return f, errz.Err(err) } -// openFile opens the file at fpath. It is the caller's -// responsibility to close the returned file. -func (fs *Files) openFile(fpath string) (*os.File, error) { - f, err := os.OpenFile(fpath, os.O_RDWR, 0o666) - if err != nil { - return nil, errz.Err(err) - } - - return f, nil -} - // fetch ensures that loc exists locally as a file. This may // entail downloading the file via HTTPS etc. func (fs *Files) fetch(ctx context.Context, loc string) (fpath string, err error) { diff --git a/libsq/source/source.go b/libsq/source/source.go index 32f3cfc55..bd639e3ae 100644 --- a/libsq/source/source.go +++ b/libsq/source/source.go @@ -14,6 +14,7 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/source/drivertype" ) @@ -91,10 +92,18 @@ type Source struct { // Options are additional params, typically empty. Options options.Options `yaml:"options,omitempty" json:"options,omitempty"` + + // Ephemeral is a flag that indicates that the source is ephemeral. This + // value is not persisted to config. 
It is used by the Source.Hash method, + // resulting in a different hash value for each ephemeral source. + Ephemeral bool } // Hash returns an SHA256 hash of all fields of s. The Source.Options -// field is ignored. If s is nil, the empty string is returned. +// field is ignored. If s is nil, the empty string is returned. If +// Source.Ephemeral is true, the hash value will be different for +// each invocation. This is useful for preventing cache collisions +// when using ephemeral sources. func (s *Source) Hash() string { if s == nil { return "" @@ -107,6 +116,10 @@ func (s *Source) Hash() string { buf.WriteString(s.Catalog) buf.WriteString(s.Schema) buf.WriteString(s.Options.Hash()) + if s.Ephemeral { + buf.WriteString(stringz.Uniq32()) + } + sum := sha256.Sum256(buf.Bytes()) return fmt.Sprintf("%x", sum) } diff --git a/testh/testh.go b/testh/testh.go index 3805edd8a..1d5d9b7e1 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -311,6 +311,7 @@ func (h *Helper) Source(handle string) *source.Source { src, err := h.coll.Get(handle) require.NoError(t, err, "source %s was not found in %s", handle, testsrc.PathSrcsConfig) + src.Ephemeral = true if src.Type == sqlite3.Type { // This could be easily generalized for CSV/XLSX etc. 
From 091e33b3ad42963dffa3ac0e5f0984df4c222eac Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 08:21:31 -0700 Subject: [PATCH 037/195] Refactoring --- cli/cmd_inspect.go | 2 +- cli/cmd_slq.go | 4 +- cli/cmd_sql.go | 4 +- cli/cmd_tbl.go | 4 +- cli/complete.go | 6 +- cli/diff/source.go | 2 +- cli/diff/table.go | 2 +- cli/run.go | 12 +- cli/run/run.go | 10 +- drivers/csv/csv.go | 20 ++-- drivers/csv/ingest.go | 14 +-- drivers/csv/insert.go | 6 +- drivers/json/ingest.go | 8 +- drivers/json/json.go | 24 ++-- drivers/userdriver/userdriver.go | 28 ++--- drivers/userdriver/xmlud/xmlimport_test.go | 4 +- drivers/xlsx/ingest.go | 24 ++-- drivers/xlsx/xlsx.go | 16 +-- drivers/xlsx/xlsx_test.go | 2 +- libsq/core/lg/devlog/tint/buffer.go | 1 + libsq/core/lg/devlog/tint/handler.go | 1 - libsq/driver/driver.go | 24 ++-- libsq/driver/sources.go | 126 ++++++++++----------- libsq/pipeline.go | 8 +- libsq/query_no_src_test.go | 8 +- libsq/query_test.go | 8 +- testh/testh.go | 34 +++--- 27 files changed, 199 insertions(+), 203 deletions(-) diff --git a/cli/cmd_inspect.go b/cli/cmd_inspect.go index 142bd66b0..2a64d1806 100644 --- a/cli/cmd_inspect.go +++ b/cli/cmd_inspect.go @@ -125,7 +125,7 @@ func execInspect(cmd *cobra.Command, args []string) error { return err } - pool, err := ru.Pools.Open(ctx, src) + pool, err := ru.Sources.Open(ctx, src) if err != nil { return errz.Wrapf(err, "failed to inspect %s", src.Handle) } diff --git a/cli/cmd_slq.go b/cli/cmd_slq.go index ea2477b46..abe14a1d0 100644 --- a/cli/cmd_slq.go +++ b/cli/cmd_slq.go @@ -135,7 +135,7 @@ func execSLQInsert(ctx context.Context, ru *run.Run, mArgs map[string]string, ctx, cancelFn := context.WithCancel(ctx) defer cancelFn() - destPool, err := ru.Pools.Open(ctx, destSrc) + destPool, err := ru.Sources.Open(ctx, destSrc) if err != nil { return err } @@ -204,7 +204,7 @@ func execSLQPrint(ctx context.Context, ru *run.Run, mArgs map[string]string) err // // $ cat something.xlsx | sq @stdin.sheet1 func 
preprocessUserSLQ(ctx context.Context, ru *run.Run, args []string) (string, error) { - log, reg, pools, coll := lg.FromContext(ctx), ru.DriverRegistry, ru.Pools, ru.Config.Collection + log, reg, pools, coll := lg.FromContext(ctx), ru.DriverRegistry, ru.Sources, ru.Config.Collection activeSrc := coll.Active() if len(args) == 0 { diff --git a/cli/cmd_sql.go b/cli/cmd_sql.go index 8801eb634..c4933e2d5 100644 --- a/cli/cmd_sql.go +++ b/cli/cmd_sql.go @@ -120,7 +120,7 @@ func execSQL(cmd *cobra.Command, args []string) error { // to the configured writer. func execSQLPrint(ctx context.Context, ru *run.Run, fromSrc *source.Source) error { args := ru.Args - pool, err := ru.Pools.Open(ctx, fromSrc) + pool, err := ru.Sources.Open(ctx, fromSrc) if err != nil { return err } @@ -140,7 +140,7 @@ func execSQLInsert(ctx context.Context, ru *run.Run, fromSrc, destSrc *source.Source, destTbl string, ) error { args := ru.Args - pools := ru.Pools + pools := ru.Sources ctx, cancelFn := context.WithCancel(ctx) defer cancelFn() diff --git a/cli/cmd_tbl.go b/cli/cmd_tbl.go index 49d3bfefa..52ff32504 100644 --- a/cli/cmd_tbl.go +++ b/cli/cmd_tbl.go @@ -122,7 +122,7 @@ func execTblCopy(cmd *cobra.Command, args []string) error { } var pool driver.Pool - pool, err = ru.Pools.Open(ctx, tblHandles[0].src) + pool, err = ru.Sources.Open(ctx, tblHandles[0].src) if err != nil { return err } @@ -255,7 +255,7 @@ func execTblDrop(cmd *cobra.Command, args []string) (err error) { } var pool driver.Pool - if pool, err = ru.Pools.Open(ctx, tblH.src); err != nil { + if pool, err = ru.Sources.Open(ctx, tblH.src); err != nil { return err } diff --git a/cli/complete.go b/cli/complete.go index 2ab4656de..72d027283 100644 --- a/cli/complete.go +++ b/cli/complete.go @@ -140,7 +140,7 @@ func completeSLQ(cmd *cobra.Command, args []string, toComplete string) ([]string // completeDriverType is a completionFunc that suggests drivers. 
func completeDriverType(cmd *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { ru := getRun(cmd) - if ru.Pools == nil { + if ru.Sources == nil { if err := preRun(cmd, ru); err != nil { lg.Unexpected(logFrom(cmd), err) return nil, cobra.ShellCompDirectiveError @@ -383,7 +383,7 @@ func (c activeSchemaCompleter) complete(cmd *cobra.Command, args []string, toCom ctx, cancelFn := context.WithTimeout(cmd.Context(), OptShellCompletionTimeout.Get(ru.Config.Options)) defer cancelFn() - pool, err := ru.Pools.Open(ctx, src) + pool, err := ru.Sources.Open(ctx, src) if err != nil { lg.Unexpected(log, err) return nil, cobra.ShellCompDirectiveError @@ -759,7 +759,7 @@ func getTableNamesForHandle(ctx context.Context, ru *run.Run, handle string) ([] return nil, err } - pool, err := ru.Pools.Open(ctx, src) + pool, err := ru.Sources.Open(ctx, src) if err != nil { return nil, err } diff --git a/cli/diff/source.go b/cli/diff/source.go index f9e356c0c..caeb6e895 100644 --- a/cli/diff/source.go +++ b/cli/diff/source.go @@ -196,7 +196,7 @@ func fetchSourceMeta(ctx context.Context, ru *run.Run, handle string) (*source.S if err != nil { return nil, nil, err } - pool, err := ru.Pools.Open(ctx, src) + pool, err := ru.Sources.Open(ctx, src) if err != nil { return nil, nil, err } diff --git a/cli/diff/table.go b/cli/diff/table.go index 2b574c308..11a347086 100644 --- a/cli/diff/table.go +++ b/cli/diff/table.go @@ -114,7 +114,7 @@ func buildTableStructureDiff(cfg *Config, showRowCounts bool, td1, td2 *tableDat func fetchTableMeta(ctx context.Context, ru *run.Run, src *source.Source, table string) ( *metadata.Table, error, ) { - pool, err := ru.Pools.Open(ctx, src) + pool, err := ru.Sources.Open(ctx, src) if err != nil { return nil, err } diff --git a/cli/run.go b/cli/run.go index 6f5c1c5c3..1eaf8e3b1 100644 --- a/cli/run.go +++ b/cli/run.go @@ -157,19 +157,19 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { ru.DriverRegistry = driver.NewRegistry(log) dr 
:= ru.DriverRegistry - ru.Pools = driver.NewPools(log, dr, ru.Files, scratchSrcFunc) - ru.Cleanup.AddC(ru.Pools) + ru.Sources = driver.NewSources(log, dr, ru.Files, scratchSrcFunc) + ru.Cleanup.AddC(ru.Sources) dr.AddProvider(sqlite3.Type, &sqlite3.Provider{Log: log}) dr.AddProvider(postgres.Type, &postgres.Provider{Log: log}) dr.AddProvider(sqlserver.Type, &sqlserver.Provider{Log: log}) dr.AddProvider(mysql.Type, &mysql.Provider{Log: log}) - csvp := &csv.Provider{Log: log, Scratcher: ru.Pools, Files: ru.Files} + csvp := &csv.Provider{Log: log, Ingester: ru.Sources, Files: ru.Files} dr.AddProvider(csv.TypeCSV, csvp) dr.AddProvider(csv.TypeTSV, csvp) ru.Files.AddDriverDetectors(csv.DetectCSV, csv.DetectTSV) - jsonp := &json.Provider{Log: log, Scratcher: ru.Pools, Files: ru.Files} + jsonp := &json.Provider{Log: log, Ingester: ru.Sources, Files: ru.Files} dr.AddProvider(json.TypeJSON, jsonp) dr.AddProvider(json.TypeJSONA, jsonp) dr.AddProvider(json.TypeJSONL, jsonp) @@ -180,7 +180,7 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { json.DetectJSONL(sampleSize), ) - dr.AddProvider(xlsx.Type, &xlsx.Provider{Log: log, Scratcher: ru.Pools, Files: ru.Files}) + dr.AddProvider(xlsx.Type, &xlsx.Provider{Log: log, Ingester: ru.Sources, Files: ru.Files}) ru.Files.AddDriverDetectors(xlsx.DetectXLSX) // One day we may have more supported user driver genres. userDriverImporters := map[string]userdriver.ImportFunc{ @@ -210,7 +210,7 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { Log: log, DriverDef: userDriverDef, ImportFn: importFn, - Scratcher: ru.Pools, + Ingester: ru.Sources, Files: ru.Files, } diff --git a/cli/run/run.go b/cli/run/run.go index ebc347645..5b5b01812 100644 --- a/cli/run/run.go +++ b/cli/run/run.go @@ -73,8 +73,8 @@ type Run struct { // Files manages file access. Files *source.Files - // Pools mediates access to db pools. - Pools *driver.Pools + // Sources mediates access to db pools. 
+ Sources *driver.Sources // Writers holds the various writer types that // the CLI uses to print output. @@ -100,9 +100,9 @@ func (ru *Run) Close() error { func NewQueryContext(ru *Run, args map[string]string) *libsq.QueryContext { return &libsq.QueryContext{ Collection: ru.Config.Collection, - PoolOpener: ru.Pools, - JoinPoolOpener: ru.Pools, - ScratchPoolOpener: ru.Pools, + PoolOpener: ru.Sources, + JoinPoolOpener: ru.Sources, + ScratchPoolOpener: ru.Sources, Args: args, } } diff --git a/drivers/csv/csv.go b/drivers/csv/csv.go index 4960d1543..91b0eab5f 100644 --- a/drivers/csv/csv.go +++ b/drivers/csv/csv.go @@ -27,18 +27,18 @@ const ( // Provider implements driver.Provider. type Provider struct { - Log *slog.Logger - Scratcher driver.ScratchPoolOpener - Files *source.Files + Log *slog.Logger + Ingester driver.IngestOpener + Files *source.Files } // DriverFor implements driver.Provider. func (d *Provider) DriverFor(typ drivertype.Type) (driver.Driver, error) { switch typ { //nolint:exhaustive case TypeCSV: - return &driveri{log: d.Log, typ: TypeCSV, scratcher: d.Scratcher, files: d.Files}, nil + return &driveri{log: d.Log, typ: TypeCSV, ingester: d.Ingester, files: d.Files}, nil case TypeTSV: - return &driveri{log: d.Log, typ: TypeTSV, scratcher: d.Scratcher, files: d.Files}, nil + return &driveri{log: d.Log, typ: TypeTSV, ingester: d.Ingester, files: d.Files}, nil } return nil, errz.Errorf("unsupported driver type {%s}", typ) @@ -46,10 +46,10 @@ func (d *Provider) DriverFor(typ drivertype.Type) (driver.Driver, error) { // Driver implements driver.Driver. type driveri struct { - log *slog.Logger - typ drivertype.Type - scratcher driver.ScratchPoolOpener - files *source.Files + log *slog.Logger + typ drivertype.Type + ingester driver.IngestOpener + files *source.Files } // DriverMetadata implements driver.Driver. 
@@ -85,7 +85,7 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er } var err error - if p.impl, err = d.scratcher.OpenIngest(ctx, src, ingestFn, allowCache); err != nil { + if p.impl, err = d.ingester.OpenIngest(ctx, src, allowCache, ingestFn); err != nil { return nil, err } diff --git a/drivers/csv/ingest.go b/drivers/csv/ingest.go index 2fe15db08..d5cd53d0f 100644 --- a/drivers/csv/ingest.go +++ b/drivers/csv/ingest.go @@ -51,7 +51,7 @@ Possible values are: comma, space, pipe, tab, colon, semi, period.`, ) // ingestCSV loads the src CSV data into scratchDB. -func ingestCSV(ctx context.Context, src *source.Source, openFn source.FileOpenFunc, scratchPool driver.Pool) error { +func ingestCSV(ctx context.Context, src *source.Source, openFn source.FileOpenFunc, destPool driver.Pool) error { log := lg.FromContext(ctx) startUTC := time.Now().UTC() @@ -107,17 +107,17 @@ func ingestCSV(ctx context.Context, src *source.Source, openFn source.FileOpenFu // And now we need to create the dest table in scratchDB tblDef := createTblDef(source.MonotableName, header, kinds) - db, err := scratchPool.DB(ctx) + db, err := destPool.DB(ctx) if err != nil { return err } - err = scratchPool.SQLDriver().CreateTable(ctx, db, tblDef) + err = destPool.SQLDriver().CreateTable(ctx, db, tblDef) if err != nil { return errz.Wrap(err, "csv: failed to create dest scratch table") } - recMeta, err := getIngestRecMeta(ctx, scratchPool, tblDef) + recMeta, err := getIngestRecMeta(ctx, destPool, tblDef) if err != nil { return err } @@ -128,9 +128,9 @@ func ingestCSV(ctx context.Context, src *source.Source, openFn source.FileOpenFu insertWriter := libsq.NewDBWriter( libsq.MsgIngestRecords, - scratchPool, + destPool, tblDef.Name, - driver.OptTuningRecChanSize.Get(scratchPool.Source().Options), + driver.OptTuningRecChanSize.Get(destPool.Source().Options), ) err = execInsert(ctx, insertWriter, recMeta, mungers, recs, cr) if err != nil { @@ -145,7 +145,7 @@ func ingestCSV(ctx 
context.Context, src *source.Source, openFn source.FileOpenFu log.Debug("Inserted rows", lga.Count, inserted, lga.Elapsed, time.Since(startUTC).Round(time.Millisecond), - lga.Target, source.Target(scratchPool.Source(), tblDef.Name), + lga.Target, source.Target(destPool.Source(), tblDef.Name), ) return nil } diff --git a/drivers/csv/insert.go b/drivers/csv/insert.go index 70dd14e16..412cdba3d 100644 --- a/drivers/csv/insert.go +++ b/drivers/csv/insert.go @@ -118,13 +118,13 @@ func createTblDef(tblName string, colNames []string, kinds []kind.Kind) *sqlmode } // getIngestRecMeta returns record.Meta to use with RecordWriter.Open. -func getIngestRecMeta(ctx context.Context, scratchPool driver.Pool, tblDef *sqlmodel.TableDef) (record.Meta, error) { - db, err := scratchPool.DB(ctx) +func getIngestRecMeta(ctx context.Context, destPool driver.Pool, tblDef *sqlmodel.TableDef) (record.Meta, error) { + db, err := destPool.DB(ctx) if err != nil { return nil, err } - drvr := scratchPool.SQLDriver() + drvr := destPool.SQLDriver() colTypes, err := drvr.TableColumnTypes(ctx, db, tblDef.Name, tblDef.ColNames()) if err != nil { diff --git a/drivers/json/ingest.go b/drivers/json/ingest.go index 053c36ef8..2afc8e3d6 100644 --- a/drivers/json/ingest.go +++ b/drivers/json/ingest.go @@ -53,18 +53,18 @@ var ( ) // getRecMeta returns record.Meta to use with RecordWriter.Open. 
-func getRecMeta(ctx context.Context, scratchPool driver.Pool, tblDef *sqlmodel.TableDef) (record.Meta, error) { - db, err := scratchPool.DB(ctx) +func getRecMeta(ctx context.Context, pool driver.Pool, tblDef *sqlmodel.TableDef) (record.Meta, error) { + db, err := pool.DB(ctx) if err != nil { return nil, err } - colTypes, err := scratchPool.SQLDriver().TableColumnTypes(ctx, db, tblDef.Name, tblDef.ColNames()) + colTypes, err := pool.SQLDriver().TableColumnTypes(ctx, db, tblDef.Name, tblDef.ColNames()) if err != nil { return nil, err } - destMeta, _, err := scratchPool.SQLDriver().RecordMeta(ctx, colTypes) + destMeta, _, err := pool.SQLDriver().RecordMeta(ctx, colTypes) if err != nil { return nil, err } diff --git a/drivers/json/json.go b/drivers/json/json.go index 497236d04..df5a993a6 100644 --- a/drivers/json/json.go +++ b/drivers/json/json.go @@ -35,9 +35,9 @@ const ( // Provider implements driver.Provider. type Provider struct { - Log *slog.Logger - Scratcher driver.ScratchPoolOpener - Files *source.Files + Log *slog.Logger + Ingester driver.IngestOpener + Files *source.Files } // DriverFor implements driver.Provider. @@ -56,19 +56,19 @@ func (d *Provider) DriverFor(typ drivertype.Type) (driver.Driver, error) { } return &driveri{ - typ: typ, - scratcher: d.Scratcher, - files: d.Files, - ingestFn: ingestFn, + typ: typ, + ingester: d.Ingester, + files: d.Files, + ingestFn: ingestFn, }, nil } // Driver implements driver.Driver. type driveri struct { - typ drivertype.Type - ingestFn ingestFunc - scratcher driver.ScratchPoolOpener - files *source.Files + typ drivertype.Type + ingestFn ingestFunc + ingester driver.IngestOpener + files *source.Files } // DriverMetadata implements driver.Driver. 
@@ -117,7 +117,7 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er } var err error - if p.impl, err = d.scratcher.OpenIngest(ctx, src, ingestFn, allowCache); err != nil { + if p.impl, err = d.ingester.OpenIngest(ctx, src, allowCache, ingestFn); err != nil { return nil, err } diff --git a/drivers/userdriver/userdriver.go b/drivers/userdriver/userdriver.go index cc22c6d1c..fef9e6abe 100644 --- a/drivers/userdriver/userdriver.go +++ b/drivers/userdriver/userdriver.go @@ -31,7 +31,7 @@ type ImportFunc func(ctx context.Context, def *DriverDef, type Provider struct { Log *slog.Logger DriverDef *DriverDef - Scratcher driver.ScratchPoolOpener + Ingester driver.IngestOpener Files *source.Files ImportFn ImportFunc } @@ -43,12 +43,12 @@ func (p *Provider) DriverFor(typ drivertype.Type) (driver.Driver, error) { } return &driveri{ - log: p.Log, - typ: typ, - def: p.DriverDef, - scratcher: p.Scratcher, - ingestFn: p.ImportFn, - files: p.Files, + log: p.Log, + typ: typ, + def: p.DriverDef, + ingester: p.Ingester, + ingestFn: p.ImportFn, + files: p.Files, }, nil } @@ -62,12 +62,12 @@ func (p *Provider) Detectors() []source.DriverDetectFunc { // Driver implements driver.Driver. type driveri struct { - log *slog.Logger - typ drivertype.Type - def *DriverDef - files *source.Files - scratcher driver.ScratchPoolOpener - ingestFn ImportFunc + log *slog.Logger + typ drivertype.Type + def *DriverDef + files *source.Files + ingester driver.IngestOpener + ingestFn ImportFunc } // DriverMetadata implements driver.Driver. 
@@ -102,7 +102,7 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er } var err error - if p.impl, err = d.scratcher.OpenIngest(ctx, src, ingestFn, allowCache); err != nil { + if p.impl, err = d.ingester.OpenIngest(ctx, src, allowCache, ingestFn); err != nil { return nil, err } return p, nil diff --git a/drivers/userdriver/xmlud/xmlimport_test.go b/drivers/userdriver/xmlud/xmlimport_test.go index 607f7d966..cae6b6348 100644 --- a/drivers/userdriver/xmlud/xmlimport_test.go +++ b/drivers/userdriver/xmlud/xmlimport_test.go @@ -34,7 +34,7 @@ func TestImport_Ppl(t *testing.T) { require.Equal(t, xmlud.Genre, udDef.Genre) src := &source.Source{Handle: "@ppl_" + stringz.Uniq8(), Type: drivertype.None} - scratchDB, err := th.Pools().OpenScratchFor(th.Context, src) + scratchDB, err := th.Sources().OpenScratch(th.Context, src) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, scratchDB.Close()) @@ -80,7 +80,7 @@ func TestImport_RSS(t *testing.T) { require.Equal(t, xmlud.Genre, udDef.Genre) src := &source.Source{Handle: "@rss_" + stringz.Uniq8(), Type: drivertype.None} - scratchDB, err := th.Pools().OpenScratchFor(th.Context, src) + scratchDB, err := th.Sources().OpenScratch(th.Context, src) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, scratchDB.Close()) diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index 60f027025..7e1676bd3 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -84,16 +84,16 @@ func (xs *xSheet) loadSampleRows(ctx context.Context, sampleSize int) error { return nil } -// ingestXLSX loads the data in xfile into scratchPool. +// ingestXLSX loads the data in xfile into destPool. // If includeSheetNames is non-empty, only the named sheets are ingested. 
-func ingestXLSX(ctx context.Context, src *source.Source, scratchPool driver.Pool, +func ingestXLSX(ctx context.Context, src *source.Source, destPool driver.Pool, xfile *excelize.File, includeSheetNames []string, ) error { log := lg.FromContext(ctx) start := time.Now() log.Debug("Beginning import from XLSX", lga.Src, src, - lga.Target, scratchPool.Source()) + lga.Target, destPool.Source()) var sheets []*xSheet if len(includeSheetNames) > 0 { @@ -124,18 +124,18 @@ func ingestXLSX(ctx context.Context, src *source.Source, scratchPool driver.Pool } var db *sql.DB - if db, err = scratchPool.DB(ctx); err != nil { + if db, err = destPool.DB(ctx); err != nil { return err } - if err = scratchPool.SQLDriver().CreateTable(ctx, db, sheetTbl.def); err != nil { + if err = destPool.SQLDriver().CreateTable(ctx, db, sheetTbl.def); err != nil { return err } } log.Debug("Tables created (but not yet populated)", lga.Count, len(sheetTbls), - lga.Target, scratchPool.Source(), + lga.Target, destPool.Source(), lga.Elapsed, time.Since(start)) var imported, skipped int @@ -146,7 +146,7 @@ func ingestXLSX(ctx context.Context, src *source.Source, scratchPool driver.Pool continue } - if err = ingestSheetToTable(ctx, scratchPool, sheetTbls[i]); err != nil { + if err = ingestSheetToTable(ctx, destPool, sheetTbls[i]); err != nil { return err } imported++ @@ -156,7 +156,7 @@ func ingestXLSX(ctx context.Context, src *source.Source, scratchPool driver.Pool lga.Count, imported, "skipped", skipped, lga.From, src, - lga.To, scratchPool.Source(), + lga.To, destPool.Source(), lga.Elapsed, time.Since(start), ) @@ -165,7 +165,7 @@ func ingestXLSX(ctx context.Context, src *source.Source, scratchPool driver.Pool // ingestSheetToTable imports the sheet data into the appropriate table // in scratchPool. The scratch table must already exist. 
-func ingestSheetToTable(ctx context.Context, scratchPool driver.Pool, sheetTbl *sheetTable) error { +func ingestSheetToTable(ctx context.Context, destPool driver.Pool, sheetTbl *sheetTable) error { var ( log = lg.FromContext(ctx) startTime = time.Now() @@ -175,7 +175,7 @@ func ingestSheetToTable(ctx context.Context, scratchPool driver.Pool, sheetTbl * destColKinds = tblDef.ColKinds() ) - db, err := scratchPool.DB(ctx) + db, err := destPool.DB(ctx) if err != nil { return err } @@ -186,7 +186,7 @@ func ingestSheetToTable(ctx context.Context, scratchPool driver.Pool, sheetTbl * } defer lg.WarnIfCloseError(log, lgm.CloseDB, conn) - drvr := scratchPool.SQLDriver() + drvr := destPool.SQLDriver() batchSize := driver.MaxBatchRows(drvr, len(destColKinds)) bi, err := driver.NewBatchInsert( @@ -264,7 +264,7 @@ func ingestSheetToTable(ctx context.Context, scratchPool driver.Pool, sheetTbl * log.Debug("Inserted rows from sheet into table", lga.Count, bi.Written(), laSheet, sheet.name, - lga.Target, source.Target(scratchPool.Source(), tblDef.Name), + lga.Target, source.Target(destPool.Source(), tblDef.Name), lga.Elapsed, time.Since(startTime)) return nil diff --git a/drivers/xlsx/xlsx.go b/drivers/xlsx/xlsx.go index 163ac29a4..b7512de5d 100644 --- a/drivers/xlsx/xlsx.go +++ b/drivers/xlsx/xlsx.go @@ -27,9 +27,9 @@ const ( // Provider implements driver.Provider. type Provider struct { - Log *slog.Logger - Files *source.Files - Scratcher driver.ScratchPoolOpener + Log *slog.Logger + Files *source.Files + Ingester driver.IngestOpener } // DriverFor implements driver.Provider. @@ -38,14 +38,14 @@ func (p *Provider) DriverFor(typ drivertype.Type) (driver.Driver, error) { return nil, errz.Errorf("unsupported driver type {%s}", typ) } - return &Driver{log: p.Log, scratcher: p.Scratcher, files: p.Files}, nil + return &Driver{log: p.Log, ingester: p.Ingester, files: p.Files}, nil } // Driver implements driver.Driver. 
type Driver struct { - log *slog.Logger - scratcher driver.ScratchPoolOpener - files *source.Files + log *slog.Logger + ingester driver.IngestOpener + files *source.Files } // DriverMetadata implements driver.Driver. @@ -92,7 +92,7 @@ func (d *Driver) Open(ctx context.Context, src *source.Source) (driver.Pool, err } var err error - if p.backingPool, err = d.scratcher.OpenIngest(ctx, p.src, ingestFn, allowCache); err != nil { + if p.backingPool, err = d.ingester.OpenIngest(ctx, p.src, allowCache, ingestFn); err != nil { return nil, err } diff --git a/drivers/xlsx/xlsx_test.go b/drivers/xlsx/xlsx_test.go index 11635fb37..af4088158 100644 --- a/drivers/xlsx/xlsx_test.go +++ b/drivers/xlsx/xlsx_test.go @@ -162,7 +162,7 @@ func TestOpenFileFormats(t *testing.T) { Location: filepath.Join("testdata", "file_formats", tc.filename), }) - pool, err := th.Pools().Open(th.Context, src) + pool, err := th.Sources().Open(th.Context, src) require.NoError(t, err) db, err := pool.DB(th.Context) if tc.wantErr { diff --git a/libsq/core/lg/devlog/tint/buffer.go b/libsq/core/lg/devlog/tint/buffer.go index 4d7321a6c..178aea7a8 100644 --- a/libsq/core/lg/devlog/tint/buffer.go +++ b/libsq/core/lg/devlog/tint/buffer.go @@ -23,6 +23,7 @@ func (b *buffer) Free() { bufPool.Put(b) } } + func (b *buffer) Write(bytes []byte) (int, error) { *b = append(*b, bytes...) 
return len(bytes), nil diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go index 89ad8ad00..3c2572e48 100644 --- a/libsq/core/lg/devlog/tint/handler.go +++ b/libsq/core/lg/devlog/tint/handler.go @@ -337,7 +337,6 @@ func (h *handler) appendSource(buf *buffer, src *slog.Source) { buf.WriteByte(':') buf.WriteString(strconv.Itoa(src.Line)) buf.WriteStringIf(!h.noColor, ansiReset) - } func (h *handler) appendAttr(buf *buffer, attr slog.Attr, groupsPrefix string, groups []string) { diff --git a/libsq/driver/driver.go b/libsq/driver/driver.go index cea51c14f..48e7cc934 100644 --- a/libsq/driver/driver.go +++ b/libsq/driver/driver.go @@ -169,20 +169,20 @@ type JoinPoolOpener interface { } // ScratchPoolOpener opens a scratch database pool. A scratch database is -// typically a short-lived database used as a target for loading -// non-SQL data (such as CSV). +// a short-lived database used for ephemeral purposes. type ScratchPoolOpener interface { - // OpenScratchFor returns a pool for scratch use. - OpenScratchFor(ctx context.Context, src *source.Source) (Pool, error) - - // OpenCachedFor returns any already cached ingested pool for src. - // If no such cache, or if it's expired, false is returned. - OpenCachedFor(ctx context.Context, src *source.Source) (Pool, bool, error) + // OpenScratch returns a pool for scratch use. + OpenScratch(ctx context.Context, src *source.Source) (Pool, error) +} - // OpenIngest opens a pool for src by executing ingestFn. If allowCache - // is false, ingest always occurs; if true, the cache is consulted first. - OpenIngest(ctx context.Context, src *source.Source, - ingestFn func(ctx context.Context, destPool Pool) error, allowCache bool) (Pool, error) +// IngestOpener opens a pool for ingest use. +type IngestOpener interface { + // OpenIngest opens a pool for src by executing ingestFn, which is + // responsible for ingesting data into dest. 
If allowCache is false, + // ingest always occurs; if true, the cache is consulted first (and + // ingestFn may not be invoked). + OpenIngest(ctx context.Context, src *source.Source, allowCache bool, + ingestFn func(ctx context.Context, dest Pool) error) (Pool, error) } // Driver is the core interface that must be implemented for each type diff --git a/libsq/driver/sources.go b/libsq/driver/sources.go index d4054caec..919f9365b 100644 --- a/libsq/driver/sources.go +++ b/libsq/driver/sources.go @@ -24,21 +24,19 @@ import ( ) var ( - _ PoolOpener = (*Pools)(nil) - _ JoinPoolOpener = (*Pools)(nil) - _ ScratchPoolOpener = (*Pools)(nil) + _ PoolOpener = (*Sources)(nil) + _ JoinPoolOpener = (*Sources)(nil) + _ ScratchPoolOpener = (*Sources)(nil) ) // ScratchSrcFunc is a function that returns a scratch source. // The caller is responsible for invoking cleanFn. type ScratchSrcFunc func(ctx context.Context, name string) (src *source.Source, cleanFn func() error, err error) -// Pools provides a mechanism for getting Pool instances. +// Sources provides a mechanism for getting Pool instances. // Note that at this time instances returned by Open are cached // and then closed by Close. This may be a bad approach. -// -// FIXME: Why not rename driver.Pools to driver.Sources? -type Pools struct { +type Sources struct { log *slog.Logger drvrs Provider mu sync.Mutex @@ -48,11 +46,11 @@ type Pools struct { clnup *cleanup.Cleanup } -// NewPools returns a Pools instances. -func NewPools(log *slog.Logger, drvrs Provider, +// NewSources returns a Sources instances. +func NewSources(log *slog.Logger, drvrs Provider, files *source.Files, scratchSrcFn ScratchSrcFunc, -) *Pools { - return &Pools{ +) *Sources { + return &Sources{ log: log, drvrs: drvrs, mu: sync.Mutex{}, @@ -73,23 +71,23 @@ func NewPools(log *slog.Logger, drvrs Provider, // and needs to be revisited. // // Open implements PoolOpener. 
-func (d *Pools) Open(ctx context.Context, src *source.Source) (Pool, error) { +func (ss *Sources) Open(ctx context.Context, src *source.Source) (Pool, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) - d.mu.Lock() - defer d.mu.Unlock() - return d.doOpen(ctx, src) + ss.mu.Lock() + defer ss.mu.Unlock() + return ss.doOpen(ctx, src) } -func (d *Pools) doOpen(ctx context.Context, src *source.Source) (Pool, error) { +func (ss *Sources) doOpen(ctx context.Context, src *source.Source) (Pool, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) key := src.Handle + "_" + src.Hash() - pool, ok := d.pools[key] + pool, ok := ss.pools[key] if ok { return pool, nil } - drvr, err := d.drvrs.DriverFor(src.Type) + drvr, err := ss.drvrs.DriverFor(src.Type) if err != nil { return nil, err } @@ -103,48 +101,42 @@ func (d *Pools) doOpen(ctx context.Context, src *source.Source) (Pool, error) { return nil, err } - d.clnup.AddC(pool) + ss.clnup.AddC(pool) - d.pools[key] = pool + ss.pools[key] = pool return pool, nil } -// OpenScratchFor returns a scratch database instance. It is not +// OpenScratch returns a scratch database instance. It is not // necessary for the caller to close the returned Pool as // its Close method will be invoked by d.Close. // -// OpenScratchFor implements ScratchPoolOpener. -// -// REVISIT: do we really need to pass a source here? Just a string should do. -// -// FIXME: the problem is with passing src? -// -// FIXME: Add cacheAllowed bool? -func (d *Pools) OpenScratchFor(ctx context.Context, src *source.Source) (Pool, error) { +// OpenScratch implements ScratchPoolOpener. 
+func (ss *Sources) OpenScratch(ctx context.Context, src *source.Source) (Pool, error) { const msgCloseScratch = "Close scratch db" - _, srcCacheDBFilepath, _, err := d.getCachePaths(src) + _, srcCacheDBFilepath, _, err := ss.getCachePaths(src) if err != nil { return nil, err } - scratchSrc, cleanFn, err := d.scratchSrcFn(ctx, srcCacheDBFilepath) + scratchSrc, cleanFn, err := ss.scratchSrcFn(ctx, srcCacheDBFilepath) if err != nil { // if err is non-nil, cleanup is guaranteed to be nil return nil, err } - d.log.Debug("Opening scratch src", lga.Src, scratchSrc) + ss.log.Debug("Opening scratch src", lga.Src, scratchSrc) - backingDrvr, err := d.drvrs.DriverFor(scratchSrc.Type) + backingDrvr, err := ss.drvrs.DriverFor(scratchSrc.Type) if err != nil { - lg.WarnIfFuncError(d.log, msgCloseScratch, cleanFn) + lg.WarnIfFuncError(ss.log, msgCloseScratch, cleanFn) return nil, err } var backingPool Pool backingPool, err = backingDrvr.Open(ctx, scratchSrc) if err != nil { - lg.WarnIfFuncError(d.log, msgCloseScratch, cleanFn) + lg.WarnIfFuncError(ss.log, msgCloseScratch, cleanFn) return nil, err } @@ -152,29 +144,29 @@ func (d *Pools) OpenScratchFor(ctx context.Context, src *source.Source) (Pool, e if !allowCache { // If the ingest cache is disabled, we add the cleanup func // so the scratch DB is deleted when the session ends. - d.clnup.AddE(cleanFn) + ss.clnup.AddE(cleanFn) } return backingPool, nil } // OpenIngest implements driver.ScratchPoolOpener. -func (d *Pools) OpenIngest(ctx context.Context, src *source.Source, - ingestFn func(ctx context.Context, destPool Pool) error, allowCache bool, +func (ss *Sources) OpenIngest(ctx context.Context, src *source.Source, allowCache bool, + ingestFn func(ctx context.Context, dest Pool) error, ) (Pool, error) { if !allowCache || src.Handle == source.StdinHandle { // We don't currently cache stdin. 
- return d.openIngestNoCache(ctx, src, ingestFn) + return ss.openIngestNoCache(ctx, src, ingestFn) } - return d.openIngestCache(ctx, src, ingestFn) + return ss.openIngestCache(ctx, src, ingestFn) } -func (d *Pools) openIngestNoCache(ctx context.Context, src *source.Source, +func (ss *Sources) openIngestNoCache(ctx context.Context, src *source.Source, ingestFn func(ctx context.Context, destPool Pool) error, ) (Pool, error) { log := lg.FromContext(ctx) - impl, err := d.OpenScratchFor(ctx, src) + impl, err := ss.OpenScratch(ctx, src) if err != nil { return nil, err } @@ -191,18 +183,18 @@ func (d *Pools) openIngestNoCache(ctx context.Context, src *source.Source, lg.WarnIfCloseError(log, lgm.CloseDB, impl) } - d.log.Debug("Ingest completed", + ss.log.Debug("Ingest completed", lga.Src, src, lga.Dest, impl.Source(), lga.Elapsed, elapsed) return impl, nil } -func (d *Pools) openIngestCache(ctx context.Context, src *source.Source, +func (ss *Sources) openIngestCache(ctx context.Context, src *source.Source, ingestFn func(ctx context.Context, destPool Pool) error, ) (Pool, error) { log := lg.FromContext(ctx) - lock, err := d.acquireLock(ctx, src) + lock, err := ss.acquireLock(ctx, src) if err != nil { return nil, err } @@ -215,14 +207,14 @@ func (d *Pools) openIngestCache(ctx context.Context, src *source.Source, } }() - cacheDir, _, checksumsPath, err := d.getCachePaths(src) + cacheDir, _, checksumsPath, err := ss.getCachePaths(src) if err != nil { return nil, err } log.Debug("Using cache dir", lga.Path, cacheDir) - ingestFilePath, err := d.files.Filepath(ctx, src) + ingestFilePath, err := ss.files.Filepath(ctx, src) if err != nil { return nil, err } @@ -231,7 +223,7 @@ func (d *Pools) openIngestCache(ctx context.Context, src *source.Source, impl Pool foundCached bool ) - if impl, foundCached, err = d.OpenCachedFor(ctx, src); err != nil { + if impl, foundCached, err = ss.OpenCachedFor(ctx, src); err != nil { return nil, err } if foundCached { @@ -243,7 +235,7 @@ func (d 
*Pools) openIngestCache(ctx context.Context, src *source.Source, log.Debug("Ingest cache MISS: no cache for source", lga.Src, src) - impl, err = d.OpenScratchFor(ctx, src) + impl, err = ss.OpenScratch(ctx, src) if err != nil { return nil, err } @@ -282,7 +274,7 @@ func (d *Pools) openIngestCache(ctx context.Context, src *source.Source, // getCachePaths returns the paths to the cache files for src. // There is no guarantee that these files exist, or are accessible. // It's just the paths. -func (d *Pools) getCachePaths(src *source.Source) (srcCacheDir, cacheDB, checksums string, err error) { +func (ss *Sources) getCachePaths(src *source.Source) (srcCacheDir, cacheDB, checksums string, err error) { if srcCacheDir, err = source.CacheDirFor(src); err != nil { return "", "", "", err } @@ -298,8 +290,8 @@ func (d *Pools) getCachePaths(src *source.Source) (srcCacheDir, cacheDB, checksu // defer lg.WarnIfFuncError(d.log, "failed to unlock cache lock", lock.Unlock) // // The lock acquisition process is retried with backoff. -func (d *Pools) acquireLock(ctx context.Context, src *source.Source) (lockfile.Lockfile, error) { - lock, err := d.getLockfileFor(src) +func (ss *Sources) acquireLock(ctx context.Context, src *source.Source) (lockfile.Lockfile, error) { + lock, err := ss.getLockfileFor(src) if err != nil { return "", err } @@ -321,8 +313,8 @@ func (d *Pools) acquireLock(ctx context.Context, src *source.Source) (lockfile.L // getLockfileFor returns a lockfile for src. It doesn't // actually acquire the lock. -func (d *Pools) getLockfileFor(src *source.Source) (lockfile.Lockfile, error) { - srcCacheDir, _, _, err := d.getCachePaths(src) +func (ss *Sources) getLockfileFor(src *source.Source) (lockfile.Lockfile, error) { + srcCacheDir, _, _, err := ss.getCachePaths(src) if err != nil { return "", err } @@ -335,8 +327,8 @@ func (d *Pools) getLockfileFor(src *source.Source) (lockfile.Lockfile, error) { } // OpenCachedFor implements ScratchPoolOpener. 
-func (d *Pools) OpenCachedFor(ctx context.Context, src *source.Source) (Pool, bool, error) { - _, cacheDBPath, checksumsPath, err := d.getCachePaths(src) +func (ss *Sources) OpenCachedFor(ctx context.Context, src *source.Source) (Pool, bool, error) { + _, cacheDBPath, checksumsPath, err := ss.getCachePaths(src) if err != nil { return nil, false, err } @@ -350,7 +342,7 @@ func (d *Pools) OpenCachedFor(ctx context.Context, src *source.Source) (Pool, bo return nil, false, err } - drvr, err := d.drvrs.DriverFor(src.Type) + drvr, err := ss.drvrs.DriverFor(src.Type) if err != nil { return nil, false, err } @@ -360,11 +352,11 @@ func (d *Pools) OpenCachedFor(ctx context.Context, src *source.Source) (Pool, bo src.Handle, src.Type) } - srcFilepath, err := d.files.Filepath(ctx, src) + srcFilepath, err := ss.files.Filepath(ctx, src) if err != nil { return nil, false, err } - d.log.Debug("Got srcFilepath for src", + ss.log.Debug("Got srcFilepath for src", lga.Src, src, lga.Path, srcFilepath) cachedChecksum, ok := mChecksums[srcFilepath] @@ -387,7 +379,7 @@ func (d *Pools) OpenCachedFor(ctx context.Context, src *source.Source) (Pool, bo return nil, false, nil } - backingType, err := d.files.DriverType(ctx, cacheDBPath) + backingType, err := ss.files.DriverType(ctx, cacheDBPath) if err != nil { return nil, false, err } @@ -398,7 +390,7 @@ func (d *Pools) OpenCachedFor(ctx context.Context, src *source.Source) (Pool, bo Type: backingType, } - backingPool, err := d.doOpen(ctx, backingSrc) + backingPool, err := ss.doOpen(ctx, backingSrc) if err != nil { return nil, false, errz.Wrapf(err, "open cached DB for source %s", src.Handle) } @@ -417,18 +409,18 @@ func (d *Pools) OpenCachedFor(ctx context.Context, src *source.Source) (Pool, bo // to OpenScratch. // // OpenJoin implements JoinPoolOpener. 
-func (d *Pools) OpenJoin(ctx context.Context, srcs ...*source.Source) (Pool, error) { +func (ss *Sources) OpenJoin(ctx context.Context, srcs ...*source.Source) (Pool, error) { var names []string for _, src := range srcs { names = append(names, src.Handle[1:]) } - d.log.Debug("OpenJoin", "sources", strings.Join(names, ",")) - return d.OpenScratchFor(ctx, srcs[0]) + ss.log.Debug("OpenJoin", "sources", strings.Join(names, ",")) + return ss.OpenScratch(ctx, srcs[0]) } // Close closes d, invoking Close on any instances opened via d.Open. -func (d *Pools) Close() error { - d.log.Debug("Closing databases(s)...", lga.Count, d.clnup.Len()) - return d.clnup.Run() +func (ss *Sources) Close() error { + ss.log.Debug("Closing databases(s)...", lga.Count, ss.clnup.Len()) + return ss.clnup.Run() } diff --git a/libsq/pipeline.go b/libsq/pipeline.go index 126298900..df9a1e4d6 100644 --- a/libsq/pipeline.go +++ b/libsq/pipeline.go @@ -187,8 +187,12 @@ func (p *pipeline) prepareNoTable(ctx context.Context, qm *queryModel) error { if src = p.qc.Collection.Active(); src == nil { log.Debug("No active source, will use scratchdb.") // REVISIT: ScratchPoolOpener needs a source, so we just make one up. 
- ephemeralSrc := &source.Source{Type: drivertype.None, Handle: "@scratch" + stringz.Uniq8()} - p.targetPool, err = p.qc.ScratchPoolOpener.OpenScratchFor(ctx, ephemeralSrc) + ephemeralSrc := &source.Source{ + Type: drivertype.None, + Handle: "@scratch" + stringz.Uniq8(), + Ephemeral: true, + } + p.targetPool, err = p.qc.ScratchPoolOpener.OpenScratch(ctx, ephemeralSrc) if err != nil { return err } diff --git a/libsq/query_no_src_test.go b/libsq/query_no_src_test.go index 6662c0396..03b76ae8e 100644 --- a/libsq/query_no_src_test.go +++ b/libsq/query_no_src_test.go @@ -30,13 +30,13 @@ func TestQuery_no_source(t *testing.T) { t.Logf("\nquery: %s\n want: %s", tc.in, tc.want) th := testh.New(t) coll := th.NewCollection() - pools := th.Pools() + sources := th.Sources() qc := &libsq.QueryContext{ Collection: coll, - PoolOpener: pools, - JoinPoolOpener: pools, - ScratchPoolOpener: pools, + PoolOpener: sources, + JoinPoolOpener: sources, + ScratchPoolOpener: sources, } gotSQL, gotErr := libsq.SLQ2SQL(th.Context, qc, tc.in) diff --git a/libsq/query_test.go b/libsq/query_test.go index 4b78e0cff..4d41ac4aa 100644 --- a/libsq/query_test.go +++ b/libsq/query_test.go @@ -163,13 +163,13 @@ func doExecQueryTestCase(t *testing.T, tc queryTestCase) { require.NoError(t, err) th := testh.New(t) - pools := th.Pools() + sources := th.Sources() qc := &libsq.QueryContext{ Collection: coll, - PoolOpener: pools, - JoinPoolOpener: pools, - ScratchPoolOpener: pools, + PoolOpener: sources, + JoinPoolOpener: sources, + ScratchPoolOpener: sources, Args: tc.args, } diff --git a/testh/testh.go b/testh/testh.go index 1d5d9b7e1..3caaf458c 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -99,7 +99,7 @@ type Helper struct { registry *driver.Registry files *source.Files - pools *driver.Pools + sources *driver.Sources run *run.Run initOnce sync.Once @@ -175,20 +175,20 @@ func (h *Helper) init() { h.files.AddDriverDetectors(source.DetectMagicNumber) - h.pools = driver.NewPools(log, h.registry, h.files, 
sqlite3.NewScratchSource) - h.Cleanup.AddC(h.pools) + h.sources = driver.NewSources(log, h.registry, h.files, sqlite3.NewScratchSource) + h.Cleanup.AddC(h.sources) h.registry.AddProvider(sqlite3.Type, &sqlite3.Provider{Log: log}) h.registry.AddProvider(postgres.Type, &postgres.Provider{Log: log}) h.registry.AddProvider(sqlserver.Type, &sqlserver.Provider{Log: log}) h.registry.AddProvider(mysql.Type, &mysql.Provider{Log: log}) - csvp := &csv.Provider{Log: log, Scratcher: h.pools, Files: h.files} + csvp := &csv.Provider{Log: log, Ingester: h.sources, Files: h.files} h.registry.AddProvider(csv.TypeCSV, csvp) h.registry.AddProvider(csv.TypeTSV, csvp) h.files.AddDriverDetectors(csv.DetectCSV, csv.DetectTSV) - jsonp := &json.Provider{Log: log, Scratcher: h.pools, Files: h.files} + jsonp := &json.Provider{Log: log, Ingester: h.sources, Files: h.files} h.registry.AddProvider(json.TypeJSON, jsonp) h.registry.AddProvider(json.TypeJSONA, jsonp) h.registry.AddProvider(json.TypeJSONL, jsonp) @@ -198,7 +198,7 @@ func (h *Helper) init() { json.DetectJSONL(driver.OptIngestSampleSize.Get(nil)), ) - h.registry.AddProvider(xlsx.Type, &xlsx.Provider{Log: log, Scratcher: h.pools, Files: h.files}) + h.registry.AddProvider(xlsx.Type, &xlsx.Provider{Log: log, Ingester: h.sources, Files: h.files}) h.files.AddDriverDetectors(xlsx.DetectXLSX) h.addUserDrivers() @@ -366,7 +366,7 @@ func (h *Helper) NewCollection(handles ...string) *source.Collection { return coll } -// Open opens a driver.Pool for src via h's internal Pools +// Open opens a driver.Pool for src via h's internal Sources // instance: thus subsequent calls to Open may return the // same Pool instance. The opened driver.Pool will be closed // during h.Close. 
@@ -374,7 +374,7 @@ func (h *Helper) Open(src *source.Source) driver.Pool { ctx, cancelFn := context.WithTimeout(h.Context, h.dbOpenTimeout) defer cancelFn() - pool, err := h.Pools().Open(ctx, src) + pool, err := h.Sources().Open(ctx, src) require.NoError(h.T, err) db, err := pool.DB(ctx) @@ -630,9 +630,9 @@ func (h *Helper) QuerySLQ(query string, args map[string]string) (*RecordSink, er qc := &libsq.QueryContext{ Collection: h.coll, - PoolOpener: h.pools, - JoinPoolOpener: h.pools, - ScratchPoolOpener: h.pools, + PoolOpener: h.sources, + JoinPoolOpener: h.sources, + ScratchPoolOpener: h.sources, Args: args, } @@ -740,7 +740,7 @@ func (h *Helper) addUserDrivers() { Log: h.Log, DriverDef: userDriverDef, ImportFn: importFn, - Scratcher: h.pools, + Ingester: h.sources, Files: h.files, } @@ -754,10 +754,10 @@ func (h *Helper) IsMonotable(src *source.Source) bool { return h.DriverFor(src).DriverMetadata().Monotable } -// Pools returns the helper's driver.Pools instance. -func (h *Helper) Pools() *driver.Pools { +// Sources returns the helper's driver.Sources instance. +func (h *Helper) Sources() *driver.Sources { h.init() - return h.pools + return h.sources } // Files returns the helper's Files instance. @@ -768,7 +768,7 @@ func (h *Helper) Files() *source.Files { // SourceMetadata returns metadata for src. func (h *Helper) SourceMetadata(src *source.Source) (*metadata.Source, error) { - pools, err := h.Pools().Open(h.Context, src) + pools, err := h.Sources().Open(h.Context, src) if err != nil { return nil, err } @@ -778,7 +778,7 @@ func (h *Helper) SourceMetadata(src *source.Source) (*metadata.Source, error) { // TableMetadata returns metadata for src's table. 
func (h *Helper) TableMetadata(src *source.Source, tbl string) (*metadata.Table, error) { - pools, err := h.Pools().Open(h.Context, src) + pools, err := h.Sources().Open(h.Context, src) if err != nil { return nil, err } From d943ca8333ec4477dc92efafff41817902335efb Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 08:28:29 -0700 Subject: [PATCH 038/195] Cleanup --- libsq/core/progress/progress.go | 10 ++++++---- libsq/driver/sources.go | 2 +- libsq/source/files.go | 4 ++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 325a935c9..7e46fc13e 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -3,7 +3,6 @@ package progress import ( "context" "io" - "os" "sync" "time" @@ -83,11 +82,10 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors p.pcInit = func() { opts := []mpb.ContainerOption{ - mpb.WithDebugOutput(os.Stdout), mpb.WithOutput(out), mpb.WithWidth(boxWidth), - // mpb.WithRefreshRate(refreshRate), - // mpb.WithAutoRefresh(), // Needed for color in Windows, apparently + mpb.WithRefreshRate(refreshRate), + mpb.WithAutoRefresh(), // Needed for color in Windows, apparently } if delay > 0 { delayCh := renderDelay(ctx, delay) @@ -159,6 +157,10 @@ func (p *Progress) Wait() { bar.bar.Abort(true) } + for _, bar := range p.bars { + bar.bar.Wait() + } + p.pc.Wait() } diff --git a/libsq/driver/sources.go b/libsq/driver/sources.go index 919f9365b..ad1872d69 100644 --- a/libsq/driver/sources.go +++ b/libsq/driver/sources.go @@ -155,7 +155,7 @@ func (ss *Sources) OpenIngest(ctx context.Context, src *source.Source, allowCach ingestFn func(ctx context.Context, dest Pool) error, ) (Pool, error) { if !allowCache || src.Handle == source.StdinHandle { - // We don't currently cache stdin. + // We don't currently cache stdin. Probably we never will? 
return ss.openIngestNoCache(ctx, src, ingestFn) } diff --git a/libsq/source/files.go b/libsq/source/files.go index 0100bc5ba..a71dbb74d 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -173,8 +173,8 @@ func (fs *Files) addStdin(ctx context.Context, f *os.File) error { lw := ioz.NewWrittenWriter(w) fs.stdinLength = lw.Written - df := ioz.DelayReader(f, time.Microsecond*500, true) // FIXME: Delete - cr := contextio.NewReader(ctx, df) + //df := ioz.DelayReader(f, time.Microsecond*500, true) // FIXME: Delete + cr := contextio.NewReader(ctx, f) pw := progress.NewWriter(ctx, "Reading stdin", -1, lw) start := time.Now() From 12bfac5ef6254ce144992c2885e220cbdcdb89a8 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 08:30:38 -0700 Subject: [PATCH 039/195] go.mod now uses github.com/neilotoole/fscache@develop --- go.mod | 2 +- go.sum | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 6b1717746..4274f061e 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/muesli/mango-cobra v1.2.0 github.com/muesli/roff v0.1.0 github.com/ncruces/go-strftime v0.1.9 - github.com/neilotoole/fscache v0.0.0-20231126193352-f79ad8b71381 + github.com/neilotoole/fscache v0.0.0-20231203152859-7f9a76169e92 github.com/neilotoole/shelleditor v0.4.1 github.com/neilotoole/slogt v1.1.0 github.com/nightlyone/lockfile v1.0.0 diff --git a/go.sum b/go.sum index 0215ae65a..96eba52fa 100644 --- a/go.sum +++ b/go.sum @@ -134,6 +134,8 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/neilotoole/fscache v0.0.0-20231126193352-f79ad8b71381 h1:yq4OXuvSTMzvCm2m9FlpUnL8PbVQGW4qv+s8uRRtxK8= github.com/neilotoole/fscache v0.0.0-20231126193352-f79ad8b71381/go.mod h1:GdelWtWN0Gbf2uE+rEZ4GinWtfV6PobHgdrQ4IrB504= +github.com/neilotoole/fscache 
v0.0.0-20231203152859-7f9a76169e92 h1:VyBPnrrbiTiN0IlN6ULtQ+tGVIQLCbEtJPpL7Cv/Jog= +github.com/neilotoole/fscache v0.0.0-20231203152859-7f9a76169e92/go.mod h1:NnvBCFGuEGIP73eSn+XbDluSR0aj61e0P1cIaIEOE04= github.com/neilotoole/shelleditor v0.4.1 h1:74LEw2mVo3jtNw2BjII6RSss9DXgEqAbmCQDDiJvzO0= github.com/neilotoole/shelleditor v0.4.1/go.mod h1:QanOZN4syDMp/L0SKwZb47Mh49mvLWX3ja5YfbYDDjo= github.com/neilotoole/slogt v1.1.0 h1:c7qE92sq+V0yvCuaxph+RQ2jOKL61c4hqS1Bv9W7FZE= @@ -185,6 +187,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= From ec3d1874371a694ecea82e52670e3faf9cb26b76 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 08:41:27 -0700 Subject: [PATCH 040/195] go.mod updates --- go.mod | 38 +++++++++++++++--------------- go.sum | 73 +++++++++++++++++++++++++--------------------------------- 2 files changed, 50 insertions(+), 61 deletions(-) diff --git a/go.mod b/go.mod index 4274f061e..b5abd90ed 100644 --- a/go.mod +++ b/go.mod @@ -32,30 +32,30 @@ require ( github.com/nightlyone/lockfile v1.0.0 github.com/otiai10/copy v1.14.0 github.com/ryboe/q v1.0.20 - github.com/samber/lo v1.38.1 - github.com/segmentio/encoding v0.3.6 + github.com/samber/lo v1.39.0 + github.com/segmentio/encoding v0.3.7 github.com/sethvargo/go-retry v0.2.4 
github.com/shopspring/decimal v1.3.1 github.com/spf13/cobra v1.8.0 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.4 - github.com/vbauerster/mpb/v8 v8.6.2 - github.com/xo/dburl v0.18.2 + github.com/vbauerster/mpb/v8 v8.7.0 + github.com/xo/dburl v0.19.1 github.com/xuri/excelize/v2 v2.8.0 go.uber.org/atomic v1.11.0 go.uber.org/multierr v1.11.0 - golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa + golang.org/x/exp v0.0.0-20231127185646-65229373498e golang.org/x/mod v0.14.0 - golang.org/x/net v0.18.0 + golang.org/x/net v0.19.0 golang.org/x/sync v0.5.0 - golang.org/x/term v0.14.0 + golang.org/x/term v0.15.0 golang.org/x/text v0.14.0 ) require ( github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/Masterminds/goutils v1.1.1 // indirect - github.com/Masterminds/semver/v3 v3.2.1 // indirect + github.com/Masterminds/semver/v3 v3.2.0 // indirect github.com/VividCortex/ewma v1.2.0 // indirect github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d // indirect github.com/davecgh/go-spew v1.1.1 // indirect @@ -63,19 +63,19 @@ require ( github.com/djherbis/stream v1.4.0 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect - github.com/huandu/xstrings v1.4.0 // indirect - github.com/imdario/mergo v0.3.16 // indirect + github.com/huandu/xstrings v1.3.3 // indirect + github.com/imdario/mergo v0.3.11 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect - github.com/mitchellh/copystructure v1.2.0 // indirect - github.com/mitchellh/reflectwalk v1.0.2 // indirect + 
github.com/mitchellh/copystructure v1.0.0 // indirect + github.com/mitchellh/reflectwalk v1.0.0 // indirect github.com/moby/term v0.5.0 // indirect github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect - github.com/muesli/mango v0.2.0 // indirect + github.com/muesli/mango v0.1.0 // indirect github.com/muesli/mango-pflag v0.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/richardlehane/mscfb v1.0.4 // indirect @@ -83,11 +83,11 @@ require ( github.com/rivo/uniseg v0.4.4 // indirect github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/segmentio/asm v1.1.3 // indirect - github.com/spf13/cast v1.5.1 // indirect - github.com/xuri/efp v0.0.0-20231025114914-d1ff6096ae53 // indirect - github.com/xuri/nfp v0.0.0-20230919160717-d98342af3f05 // indirect - golang.org/x/crypto v0.15.0 // indirect - golang.org/x/sys v0.14.0 // indirect - golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect + github.com/spf13/cast v1.3.1 // indirect + github.com/xuri/efp v0.0.0-20230802181842-ad255f2331ca // indirect + github.com/xuri/nfp v0.0.0-20230819163627-dc951e3ffe1a // indirect + golang.org/x/crypto v0.16.0 // indirect + golang.org/x/sys v0.15.0 // indirect + golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 96eba52fa..89d38cb45 100644 --- a/go.sum +++ b/go.sum @@ -14,9 +14,8 @@ github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0 h1:HCc0+LpPfpC github.com/AzureAD/microsoft-authentication-library-for-go v1.1.0/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= +github.com/Masterminds/semver/v3 v3.2.0 h1:3MEsd0SM6jqZojhjLWWeBY+Kcjy9i6MQAeY7YgDP83g= 
github.com/Masterminds/semver/v3 v3.2.0/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= -github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= -github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA= github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM= github.com/VividCortex/ewma v1.2.0 h1:f58SaIzcDXrSy3kWaHNvuJgJ3Nmz59Zji6XoJR/q1ow= @@ -48,8 +47,6 @@ github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= -github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0XL9UY= -github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= @@ -73,12 +70,10 @@ github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/h2non/filetype v1.1.3 h1:FKkx9QbD7HR/zjK1Ia5XiBsq9zdLi5Kf3zGyFTAFkGg= github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy1HndBY= +github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4= github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= -github.com/huandu/xstrings v1.4.0 
h1:D17IlohoQq4UcpqD7fDk80P7l+lwAmlFaBHgOipl2FU= -github.com/huandu/xstrings v1.4.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= +github.com/imdario/mergo v0.3.11 h1:3tnifQM4i+fbajXKBHXWEH+KvNHqojZ778UH75j3bGA= github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= -github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= -github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -110,20 +105,18 @@ github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+ github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/microsoft/go-mssqldb v1.6.0 h1:mM3gYdVwEPFrlg/Dvr2DNVEgYFG7L42l+dGc67NNNpc= github.com/microsoft/go-mssqldb v1.6.0/go.mod h1:00mDtPbeQCRGC1HwOOR5K/gr30P1NcEG0vx6Kbv2aJU= +github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ= github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= -github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= -github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= github.com/mitchellh/go-wordwrap v1.0.1 h1:TLuKupo69TCn6TQSyGxwI1EblZZEsQ0vMlAFQflz0v0= github.com/mitchellh/go-wordwrap v1.0.1/go.mod h1:R62XHJLzvMFRBbcrT7m7WgmE1eOyTSsCt+hzestvNj0= +github.com/mitchellh/reflectwalk v1.0.0 h1:9D+8oIskB4VJBN5SFlmc27fSlIBZaov1Wpk/IfikLNY= github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= -github.com/mitchellh/reflectwalk v1.0.2 
h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= -github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= -github.com/muesli/mango v0.2.0 h1:iNNc0c5VLQ6fsMgAqGQofByNUBH2Q2nEbD6TaI+5yyQ= -github.com/muesli/mango v0.2.0/go.mod h1:5XFpbC8jY5UUv89YQciiXNlbi+iJgt29VDC5xbzrLL4= +github.com/muesli/mango v0.1.0 h1:DZQK45d2gGbql1arsYA4vfg4d7I9Hfx5rX/GCmzsAvI= +github.com/muesli/mango v0.1.0/go.mod h1:5XFpbC8jY5UUv89YQciiXNlbi+iJgt29VDC5xbzrLL4= github.com/muesli/mango-cobra v1.2.0 h1:DQvjzAM0PMZr85Iv9LIMaYISpTOliMEg+uMFtNbYvWg= github.com/muesli/mango-cobra v1.2.0/go.mod h1:vMJL54QytZAJhCT13LPVDfkvCUJ5/4jNUKF/8NC2UjA= github.com/muesli/mango-pflag v0.1.0 h1:UADqbYgpUyRoBja3g6LUL+3LErjpsOwaC9ywvBWe7Sg= @@ -132,8 +125,6 @@ github.com/muesli/roff v0.1.0 h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8= github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/neilotoole/fscache v0.0.0-20231126193352-f79ad8b71381 h1:yq4OXuvSTMzvCm2m9FlpUnL8PbVQGW4qv+s8uRRtxK8= -github.com/neilotoole/fscache v0.0.0-20231126193352-f79ad8b71381/go.mod h1:GdelWtWN0Gbf2uE+rEZ4GinWtfV6PobHgdrQ4IrB504= github.com/neilotoole/fscache v0.0.0-20231203152859-7f9a76169e92 h1:VyBPnrrbiTiN0IlN6ULtQ+tGVIQLCbEtJPpL7Cv/Jog= github.com/neilotoole/fscache v0.0.0-20231203152859-7f9a76169e92/go.mod 
h1:NnvBCFGuEGIP73eSn+XbDluSR0aj61e0P1cIaIEOE04= github.com/neilotoole/shelleditor v0.4.1 h1:74LEw2mVo3jtNw2BjII6RSss9DXgEqAbmCQDDiJvzO0= @@ -165,20 +156,19 @@ github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUz github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryboe/q v1.0.20 h1:FDaGYR2WrXMrWFzklRXWJZvhAPQr07SMLVf3bCAGhVQ= github.com/ryboe/q v1.0.20/go.mod h1:IiqlbBPRrComXDcFXCKyIGle2yPqmgPKLJAMJQZjcgA= -github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= -github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= +github.com/samber/lo v1.39.0 h1:4gTz1wUhNYLhFSKl6O+8peW0v2F4BCY034GRpU9WnuA= +github.com/samber/lo v1.39.0/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc= github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg= -github.com/segmentio/encoding v0.3.6 h1:E6lVLyDPseWEulBmCmAKPanDd3jiyGDo5gMcugCRwZQ= -github.com/segmentio/encoding v0.3.6/go.mod h1:n0JeuIqEQrQoPDGsjo8UNd1iA0U8d8+oHAA4E3G3OxM= +github.com/segmentio/encoding v0.3.7 h1:2rSfoktCoC1viI0DgsD+cvM4x2aze5/gza8B6/Cxqjo= +github.com/segmentio/encoding v0.3.7/go.mod h1:n0JeuIqEQrQoPDGsjo8UNd1iA0U8d8+oHAA4E3G3OxM= github.com/sethvargo/go-retry v0.2.4 h1:T+jHEQy/zKJf5s95UkguisicE0zuF9y7+/vgz08Ocec= github.com/sethvargo/go-retry v0.2.4/go.mod h1:1afjQuvh7s4gflMObvjLPaWgluLLyhA1wmVZ6KLpICw= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= github.com/spf13/cast v1.3.1/go.mod 
h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= -github.com/spf13/cast v1.5.1 h1:R+kOtfhWQE6TVQzY+4D7wJLBgkdVasCEFxSUBYBYIlA= -github.com/spf13/cast v1.5.1/go.mod h1:b9PdjNptOpzXr7Rq1q9gJML/2cdGQAo69NKzQ10KN48= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= @@ -194,18 +184,16 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/vbauerster/mpb/v8 v8.6.2 h1:9EhnJGQRtvgDVCychJgR96EDCOqgg2NsMuk5JUcX4DA= -github.com/vbauerster/mpb/v8 v8.6.2/go.mod h1:oVJ7T+dib99kZ/VBjoBaC8aPXiSAihnzuKmotuihyFo= -github.com/xo/dburl v0.18.2 h1:9xqcVf+JEV7bcUa1OjCsoax06roohYFdye6xkvBKo50= -github.com/xo/dburl v0.18.2/go.mod h1:B7/G9FGungw6ighV8xJNwWYQPMfn3gsi2sn5SE8Bzco= +github.com/vbauerster/mpb/v8 v8.7.0 h1:n2LTGyol7qqNBcLQn8FL5Bga2O8CGF75OOYsJVFsfMg= +github.com/vbauerster/mpb/v8 v8.7.0/go.mod h1:0RgdqeTpu6cDbdWeSaDvEvfgm9O598rBnRZ09HKaV0k= +github.com/xo/dburl v0.19.1 h1:z/K2i8zVf6aRwQ8Szz7MGEUw0VC2472D9SlBqdHDQCU= +github.com/xo/dburl v0.19.1/go.mod h1:B7/G9FGungw6ighV8xJNwWYQPMfn3gsi2sn5SE8Bzco= +github.com/xuri/efp v0.0.0-20230802181842-ad255f2331ca h1:uvPMDVyP7PXMMioYdyPH+0O+Ta/UO1WFfNYMO3Wz0eg= github.com/xuri/efp v0.0.0-20230802181842-ad255f2331ca/go.mod h1:ybY/Jr0T0GTCnYjKqmdwxyxn2BQf2RcQIIvex5QldPI= -github.com/xuri/efp v0.0.0-20231025114914-d1ff6096ae53 h1:Chd9DkqERQQuHpXjR/HSV1jLZA6uaoiwwH3vSuF3IW0= -github.com/xuri/efp v0.0.0-20231025114914-d1ff6096ae53/go.mod h1:ybY/Jr0T0GTCnYjKqmdwxyxn2BQf2RcQIIvex5QldPI= 
github.com/xuri/excelize/v2 v2.8.0 h1:Vd4Qy809fupgp1v7X+nCS/MioeQmYVVzi495UCTqB7U= github.com/xuri/excelize/v2 v2.8.0/go.mod h1:6iA2edBTKxKbZAa7X5bDhcCg51xdOn1Ar5sfoXRGrQg= +github.com/xuri/nfp v0.0.0-20230819163627-dc951e3ffe1a h1:Mw2VNrNNNjDtw68VsEj2+st+oCSn4Uz7vZw6TbhcV1o= github.com/xuri/nfp v0.0.0-20230819163627-dc951e3ffe1a/go.mod h1:WwHg+CVyzlv/TX9xqBFXEZAuxOPxn2k1GNHwG41IIUQ= -github.com/xuri/nfp v0.0.0-20230919160717-d98342af3f05 h1:qhbILQo1K3mphbwKh1vNm4oGezE1eF9fQWmNiIpSfI4= -github.com/xuri/nfp v0.0.0-20230919160717-d98342af3f05/go.mod h1:WwHg+CVyzlv/TX9xqBFXEZAuxOPxn2k1GNHwG41IIUQ= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= @@ -215,10 +203,10 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= -golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= -golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= -golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ= -golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= +golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= +golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/exp v0.0.0-20231127185646-65229373498e h1:Gvh4YaCaXNs6dKTlfgismwWZKyjVZXwOPfIyUaqU3No= +golang.org/x/exp v0.0.0-20231127185646-65229373498e/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= 
golang.org/x/image v0.11.0 h1:ds2RoQvBvYTiJkwpSFDwCcDFNX7DqjL2WsUgTNk0Ooo= golang.org/x/image v0.11.0/go.mod h1:bglhjqbqVuEb9e9+eNR45Jfu7D+T4Qan+NhQk8Ck2P8= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -232,8 +220,8 @@ golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI= -golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= -golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= +golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= +golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -252,16 +240,16 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= -golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term 
v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= -golang.org/x/term v0.14.0 h1:LGK9IlZ8T9jvdy6cTdfKUCltatMFOehAQo9SRC46UQ8= -golang.org/x/term v0.14.0/go.mod h1:TySc+nGkYR6qt8km8wUhuFRTVSMIX3XPR58y2lC8vww= +golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= +golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= @@ -276,12 +264,13 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSmiC7MMxXNOb3PU/VUEz+EhU= -golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= From de1ea334edd84d270ace6ea51eb3162bc7ef81ac Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 09:56:46 -0700 Subject: [PATCH 041/195] Improved error handling --- cli/error.go | 9 ++++++++- go.mod | 4 +++- go.sum | 8 ++++---- libsq/core/ioz/contextio/contextio.go | 19 +++---------------- libsq/core/progress/progressio.go | 20 +++----------------- libsq/libsq.go | 12 +++++++++++- 6 files changed, 32 insertions(+), 40 deletions(-) diff --git a/cli/error.go b/cli/error.go index 4321e35a2..ad218b951 100644 --- a/cli/error.go +++ b/cli/error.go @@ -53,7 +53,14 @@ func printError(ctx context.Context, ru *run.Run, err error) { cmdName = cmd.Name() } - log.Error("EXECUTION FAILED", lga.Err, err, lga.Cmd, cmdName) + if errz.IsErrContext(err) { + // If it's a context error, e.g. the user cancelled, we'll log it as + // a warning instead of as an error. 
+ log.Warn("EXECUTION FAILED", lga.Err, err, lga.Cmd, cmdName) + } else { + log.Error("EXECUTION FAILED", lga.Err, err, lga.Cmd, cmdName) + } + wrtrs := ru.Writers if wrtrs != nil && wrtrs.Error != nil { // If we have an errorWriter, we print to it diff --git a/go.mod b/go.mod index b5abd90ed..830be59c2 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/muesli/mango-cobra v1.2.0 github.com/muesli/roff v0.1.0 github.com/ncruces/go-strftime v0.1.9 - github.com/neilotoole/fscache v0.0.0-20231203152859-7f9a76169e92 + github.com/neilotoole/fscache v0.0.0-20231203162946-c9808f16552e github.com/neilotoole/shelleditor v0.4.1 github.com/neilotoole/slogt v1.1.0 github.com/nightlyone/lockfile v1.0.0 @@ -91,3 +91,5 @@ require ( golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/djherbis/stream v1.4.0 => github.com/neilotoole/djherbis-stream v0.0.0-20231203160853-609f47afedda diff --git a/go.sum b/go.sum index 89d38cb45..855bf2ed6 100644 --- a/go.sum +++ b/go.sum @@ -37,8 +37,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/djherbis/atime v1.1.0 h1:rgwVbP/5by8BvvjBNrbh64Qz33idKT3pSnMSJsxhi0g= github.com/djherbis/atime v1.1.0/go.mod h1:28OF6Y8s3NQWwacXc5eZTsEsiMzp7LF8MbXE+XJPdBE= -github.com/djherbis/stream v1.4.0 h1:aVD46WZUiq5kJk55yxJAyw6Kuera6kmC3i2vEQyW/AE= -github.com/djherbis/stream v1.4.0/go.mod h1:cqjC1ZRq3FFwkGmUtHwcldbnW8f0Q4YuVsGW1eAFtOk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/ecnepsnai/osquery v1.0.1 h1:i96n/3uqcafKZtRYmXVNqekKbfrIm66q179mWZ/Y2Aw= @@ -125,8 +123,10 @@ github.com/muesli/roff v0.1.0 
h1:YD0lalCotmYuF5HhZliKWlIx7IEhiXeSfq7hNjFqGF8= github.com/muesli/roff v0.1.0/go.mod h1:pjAHQM9hdUUwm/krAfrLGgJkXJ+YuhtsfZ42kieB2Ig= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/neilotoole/fscache v0.0.0-20231203152859-7f9a76169e92 h1:VyBPnrrbiTiN0IlN6ULtQ+tGVIQLCbEtJPpL7Cv/Jog= -github.com/neilotoole/fscache v0.0.0-20231203152859-7f9a76169e92/go.mod h1:NnvBCFGuEGIP73eSn+XbDluSR0aj61e0P1cIaIEOE04= +github.com/neilotoole/djherbis-stream v0.0.0-20231203160853-609f47afedda h1:/DeuPyW2WdGnI45Z4Bdct9+OAveM3/gIXFDau5rameY= +github.com/neilotoole/djherbis-stream v0.0.0-20231203160853-609f47afedda/go.mod h1:cqjC1ZRq3FFwkGmUtHwcldbnW8f0Q4YuVsGW1eAFtOk= +github.com/neilotoole/fscache v0.0.0-20231203162946-c9808f16552e h1:iHZVemEeZ5iK67hJAT4eInA3ljwISjyX72rUtL/DCzM= +github.com/neilotoole/fscache v0.0.0-20231203162946-c9808f16552e/go.mod h1:GZrpRf9MA0R0OZy6zovlS0Gs3VPFZ+L0jq45Rz+Sm2g= github.com/neilotoole/shelleditor v0.4.1 h1:74LEw2mVo3jtNw2BjII6RSss9DXgEqAbmCQDDiJvzO0= github.com/neilotoole/shelleditor v0.4.1/go.mod h1:QanOZN4syDMp/L0SKwZb47Mh49mvLWX3ja5YfbYDDjo= github.com/neilotoole/slogt v1.1.0 h1:c7qE92sq+V0yvCuaxph+RQ2jOKL61c4hqS1Bv9W7FZE= diff --git a/libsq/core/ioz/contextio/contextio.go b/libsq/core/ioz/contextio/contextio.go index 535db2489..a600e396f 100644 --- a/libsq/core/ioz/contextio/contextio.go +++ b/libsq/core/ioz/contextio/contextio.go @@ -96,14 +96,7 @@ func (w *writeCloser) Close() error { select { case <-w.ctx.Done(): - ctxErr := w.ctx.Err() - switch { - case closeErr == nil, - errz.IsErrContext(closeErr): - return ctxErr - default: - return errors.Join(ctxErr, closeErr) - } + return w.ctx.Err() default: return closeErr } @@ -164,14 +157,8 @@ func (rc *readCloser) Close() error { select { case <-rc.ctx.Done(): - ctxErr := rc.ctx.Err() - switch { - case closeErr == nil, - 
errz.IsErrContext(closeErr): - return ctxErr - default: - return errors.Join(ctxErr, closeErr) - } + return rc.ctx.Err() + default: return closeErr } diff --git a/libsq/core/progress/progressio.go b/libsq/core/progress/progressio.go index c21c278fd..7e880d3ea 100644 --- a/libsq/core/progress/progressio.go +++ b/libsq/core/progress/progressio.go @@ -20,7 +20,6 @@ package progress import ( "context" - "errors" "io" "github.com/neilotoole/sq/libsq/core/errz" @@ -107,14 +106,8 @@ func (w *progWriter) Close() error { select { case <-w.ctx.Done(): - ctxErr := w.ctx.Err() - switch { - case closeErr == nil, - errz.IsErrContext(closeErr): - return ctxErr - default: - return errors.Join(ctxErr, closeErr) - } + return w.ctx.Err() + default: return closeErr } @@ -174,14 +167,7 @@ func (r *progReader) Close() error { select { case <-r.ctx.Done(): - ctxErr := r.ctx.Err() - switch { - case closeErr == nil, - errz.IsErrContext(closeErr): - return ctxErr - default: - return errors.Join(ctxErr, closeErr) - } + return r.ctx.Err() default: return closeErr } diff --git a/libsq/libsq.go b/libsq/libsq.go index e4e54d3f4..aa0a44b34 100644 --- a/libsq/libsq.go +++ b/libsq/libsq.go @@ -11,6 +11,7 @@ package libsq import ( "context" + "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" @@ -146,7 +147,16 @@ func QuerySQL(ctx context.Context, pool driver.Pool, db sqlz.DB, rows, err := db.QueryContext(ctx, query, args...) if err != nil { - return errz.Wrapf(errw(err), `SQL query against %s failed: %s`, pool.Source().Handle, query) + err = errz.Wrapf(errw(err), `SQL query against %s failed: %s`, pool.Source().Handle, query) + select { + case <-ctx.Done(): + // If the context was cancelled, it's probably more accurate + // to just return the context error. 
+ log.Debug("Error received, but context was done", lga.Err, err) + return ctx.Err() + default: + return err + } } defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows) From c702662527fcd07acb2faa3466195f17793e7bbb Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 11:15:19 -0700 Subject: [PATCH 042/195] Adding cache cmds --- cli/cli.go | 3 + cli/cmd_config_cache.go | 116 +++++++++++++++++++++++++++ cli/logging.go | 19 ++++- cli/options.go | 1 + cli/output/jsonw/configwriter.go | 22 +++++ cli/output/tablew/configwriter.go | 21 +++++ cli/output/writers.go | 7 ++ cli/output/yamlw/configwriter.go | 22 +++++ libsq/core/ioz/ioz.go | 15 ++++ libsq/core/lg/devlog/tint/handler.go | 15 ++++ libsq/core/lg/lga/lga.go | 1 + libsq/driver/sources.go | 7 +- libsq/source/cache.go | 6 +- 13 files changed, 250 insertions(+), 5 deletions(-) create mode 100644 cli/cmd_config_cache.go diff --git a/cli/cli.go b/cli/cli.go index 07211a1c4..e3eca7c19 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -225,6 +225,9 @@ func newCommandTree(ru *run.Run) (rootCmd *cobra.Command) { addCmd(ru, configCmd, newConfigSetCmd()) addCmd(ru, configCmd, newConfigLocationCmd()) addCmd(ru, configCmd, newConfigEditCmd()) + cacheCmd := addCmd(ru, configCmd, newConfigCacheCmd()) + addCmd(ru, cacheCmd, newConfigCacheLocationCmd()) + addCmd(ru, cacheCmd, newConfigCacheInfoCmd()) addCmd(ru, rootCmd, newCompletionCmd()) addCmd(ru, rootCmd, newVersionCmd()) diff --git a/cli/cmd_config_cache.go b/cli/cmd_config_cache.go new file mode 100644 index 000000000..384b2c666 --- /dev/null +++ b/cli/cmd_config_cache.go @@ -0,0 +1,116 @@ +package cli + +import ( + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/source" + "github.com/spf13/cobra" + + "github.com/neilotoole/sq/cli/flag" + "github.com/neilotoole/sq/cli/run" +) + +func newConfigCacheCmd() *cobra.Command { + cmd := 
&cobra.Command{ + Use: "cache", + Args: cobra.NoArgs, + Short: "Manage cache", + Long: `Manage cache.`, + RunE: func(cmd *cobra.Command, args []string) error { + return cmd.Help() + }, + // FIXME: add examples + // Example: ` # Print config location + //$ sq config location + // + //# Show base config + //$ sq config ls + // + //# Show base config including unset and default values. + //$ sq config ls -v + // + //# Show base config in maximum detail (YAML format) + //$ sq config ls -yv + // + //# Get base value of an option + //$ sq config get format + // + //# Get source-specific value of an option + //$ sq config get --src @sakila conn.max-open + // + //# Set base option value + //$ sq config set format json + // + //# Set source-specific option value + //$ sq config set --src @sakila conn.max-open 50 + // + //# Help for an option + //$ sq config set format --help + // + //# Edit base config in $EDITOR + //$ sq config edit + // + //# Edit config for source in $EDITOR + //$ sq config edit @sakila + // + //# Delete option (reset to default value) + //$ sq config set -D log.level`, + } + + return cmd +} + +func newConfigCacheLocationCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "location", + Aliases: []string{"loc"}, + Short: "Print cache location", + Long: "Print cache location.", + Args: cobra.ExactArgs(0), + RunE: execConfigCacheLocation, + Example: ` $ sq config cache location + /Users/neilotoole/Library/Caches/sq`, + } + + addTextFormatFlags(cmd) + cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) + cmd.Flags().BoolP(flag.YAML, flag.YAMLShort, false, flag.YAMLUsage) + return cmd +} + +func execConfigCacheLocation(cmd *cobra.Command, _ []string) error { + dir := source.CacheDirPath() + ru := run.FromContext(cmd.Context()) + return ru.Writers.Config.CacheLocation(dir) +} + +func newConfigCacheInfoCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "info", + Short: "Print cache info", + Long: "Print cache info.", + Args: 
cobra.ExactArgs(0), + RunE: execConfigCacheInfo, + Example: ` $ sq config cache info + /Users/neilotoole/Library/Caches/sq (1.2MB)`, + } + + addTextFormatFlags(cmd) + cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) + cmd.Flags().BoolP(flag.YAML, flag.YAMLShort, false, flag.YAMLUsage) + return cmd +} + +func execConfigCacheInfo(cmd *cobra.Command, _ []string) error { + dir := source.CacheDirPath() + ru := run.FromContext(cmd.Context()) + size, err := ioz.DirSize(dir) + if err != nil { + lg.FromContext(cmd.Context()).Warn("Could not determine cache size", + lga.Path, dir, lga.Err, err) + size = -1 + } + + return ru.Writers.Config.CacheInfo(dir, size) +} diff --git a/cli/logging.go b/cli/logging.go index 02099a93b..39f728928 100644 --- a/cli/logging.go +++ b/cli/logging.go @@ -49,6 +49,16 @@ var ( `Log level, one of: DEBUG, INFO, WARN, ERROR`, "Log level, one of: DEBUG, INFO, WARN, ERROR.", ) + + OptLogDevMode = options.NewBool( + "log.devmode", + "", + false, + 0, + false, + "Log in devmode", + "Log in devmode.", + ) ) // defaultLogging returns a *slog.Logger, its slog.Handler, and @@ -92,7 +102,14 @@ func defaultLogging(ctx context.Context, osArgs []string, cfg *config.Config, } closer = logFile.Close - h = devlog.NewHandler(logFile, lvl) + devMode := OptLogDevMode.Get(cfg.Options) + + if devMode { + h = devlog.NewHandler(logFile, lvl) + } else { + h = newJSONHandler(logFile, lvl) + } + //h = devlog.NewHandler(logFile, lvl) // h = newJSONHandler(logFile, lvl) return slog.New(h), h, closer, nil } diff --git a/cli/options.go b/cli/options.go index 90f0a60dc..1a5365432 100644 --- a/cli/options.go +++ b/cli/options.go @@ -171,6 +171,7 @@ func RegisterDefaultOpts(reg *options.Registry) { OptLogEnabled, OptLogFile, OptLogLevel, + OptLogDevMode, OptDiffNumLines, OptDiffDataFormat, driver.OptConnMaxOpen, diff --git a/cli/output/jsonw/configwriter.go b/cli/output/jsonw/configwriter.go index 4425b461f..3001e7e59 100644 --- 
a/cli/output/jsonw/configwriter.go +++ b/cli/output/jsonw/configwriter.go @@ -21,6 +21,28 @@ func NewConfigWriter(out io.Writer, pr *output.Printing) output.ConfigWriter { return &configWriter{out: out, pr: pr} } +// CacheLocation implements output.ConfigWriter. +func (w *configWriter) CacheLocation(loc string) error { + m := map[string]string{"location": loc} + return writeJSON(w.out, w.pr, m) +} + +// CacheInfo implements output.ConfigWriter. It simply +// delegates to CacheLocation. +func (w *configWriter) CacheInfo(loc string, size int64) error { + type cacheInfo struct { + Location string `json:"location"` + Size *int64 `json:"size,omitempty"` + } + + ci := cacheInfo{Location: loc} + if size != -1 { + ci.Size = &size + } + + return writeJSON(w.out, w.pr, ci) +} + // Location implements output.ConfigWriter. func (w *configWriter) Location(loc, origin string) error { type cfgInfo struct { diff --git a/cli/output/tablew/configwriter.go b/cli/output/tablew/configwriter.go index edd86d413..05ba811f0 100644 --- a/cli/output/tablew/configwriter.go +++ b/cli/output/tablew/configwriter.go @@ -2,6 +2,8 @@ package tablew import ( "fmt" + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/stringz" "io" "github.com/fatih/color" @@ -36,6 +38,25 @@ func (w *configWriter) Location(path, origin string) error { return nil } +// CacheLocation implements output.ConfigWriter. +func (w *configWriter) CacheLocation(loc string) error { + _, err := fmt.Fprintln(w.tbl.out, loc) + return errz.Err(err) +} + +// CacheInfo implements output.ConfigWriter. It simply +// delegates to CacheLocation. +func (w *configWriter) CacheInfo(loc string, size int64) error { + s := loc + " " + if size == -1 { + s += w.tbl.pr.Error.Sprint("(size unavailable)") + } else { + s += w.tbl.pr.Faint.Sprintf("(%s)", stringz.ByteSized(size, 1, "")) + } + _, err := fmt.Fprintln(w.tbl.out, s) + return err +} + // Opt implements output.ConfigWriter. 
func (w *configWriter) Opt(o options.Options, opt options.Opt) error { if o == nil || opt == nil { diff --git a/cli/output/writers.go b/cli/output/writers.go index 530866443..0049d1ca3 100644 --- a/cli/output/writers.go +++ b/cli/output/writers.go @@ -139,6 +139,13 @@ type ConfigWriter interface { // UnsetOption is called when an option is unset. UnsetOption(opt options.Opt) error + + // CacheLocation prints the cache location. + CacheLocation(loc string) error + + // CacheInfo prints cache info. Set arg size to -1 to indicate + // that the size of the cache could not be calculated. + CacheInfo(loc string, size int64) error } // Writers is a container for the various output Writers. diff --git a/cli/output/yamlw/configwriter.go b/cli/output/yamlw/configwriter.go index 83279bd65..774dca852 100644 --- a/cli/output/yamlw/configwriter.go +++ b/cli/output/yamlw/configwriter.go @@ -39,6 +39,28 @@ func (w *configWriter) Location(loc, origin string) error { return writeYAML(w.out, w.p, c) } +// CacheLocation implements output.ConfigWriter. +func (w *configWriter) CacheLocation(loc string) error { + m := map[string]string{"location": loc} + return writeYAML(w.out, w.p, m) +} + +// CacheInfo implements output.ConfigWriter. It simply +// delegates to CacheLocation. +func (w *configWriter) CacheInfo(loc string, size int64) error { + type cacheInfo struct { + Location string `yaml:"location"` + Size *int64 `yaml:"size,omitempty"` + } + + ci := cacheInfo{Location: loc} + if size != -1 { + ci.Size = &size + } + + return writeYAML(w.out, w.p, ci) +} + // Opt implements output.ConfigWriter. 
func (w *configWriter) Opt(o options.Options, opt options.Opt) error { if o == nil || opt == nil { diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 54dbc17a3..be6c9560b 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -370,3 +370,18 @@ func (w *WrittenWriter) Write(p []byte) (n int, err error) { } return n, err } + +// DirSize returns total size of all regular files in path. +func DirSize(path string) (int64, error) { + var size int64 + err := filepath.Walk(path, func(_ string, fi os.FileInfo, err error) error { + if err != nil { + return err + } + if !fi.IsDir() && fi.Mode().IsRegular() { + size += fi.Size() + } + return err + }) + return size, err +} diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go index 3c2572e48..e8a79070e 100644 --- a/libsq/core/lg/devlog/tint/handler.go +++ b/libsq/core/lg/devlog/tint/handler.go @@ -220,12 +220,27 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { } } + msgColor := ansiBrightGreen + switch r.Level { + case slog.LevelDebug: + msgColor = ansiBrightGreen + case slog.LevelWarn: + msgColor = ansiBrightYellow + case slog.LevelError: + msgColor = ansiBrightRed + case slog.LevelInfo: + msgColor = ansiBlue + } // write message if rep == nil { + buf.WriteStringIf(!h.noColor, msgColor) buf.WriteString(r.Message) + buf.WriteStringIf(!h.noColor, ansiReset) buf.WriteByte(' ') } else if a := rep(nil /* groups */, slog.String(slog.MessageKey, r.Message)); a.Key != "" { + buf.WriteStringIf(!h.noColor, msgColor) h.appendValue(buf, a.Value, false) + buf.WriteStringIf(!h.noColor, ansiReset) buf.WriteByte(' ') } diff --git a/libsq/core/lg/lga/lga.go b/libsq/core/lg/lga/lga.go index eed454dea..fa31dca30 100644 --- a/libsq/core/lg/lga/lga.go +++ b/libsq/core/lg/lga/lga.go @@ -32,6 +32,7 @@ const ( Key = "key" Kind = "kind" Loc = "loc" + Lock = "lock" New = "new" Old = "old" Opts = "opts" diff --git a/libsq/driver/sources.go b/libsq/driver/sources.go index 
ad1872d69..340c7b0d0 100644 --- a/libsq/driver/sources.go +++ b/libsq/driver/sources.go @@ -297,7 +297,10 @@ func (ss *Sources) acquireLock(ctx context.Context, src *source.Source) (lockfil } err = retry.Do(ctx, time.Second*5, - lock.TryLock, + func() error { + lg.FromContext(ctx).Debug("Attempting to acquire cache lock", lga.Lock, lock) + return lock.TryLock() + }, func(err error) bool { var temporaryError lockfile.TemporaryError return errors.As(err, &temporaryError) @@ -307,7 +310,7 @@ func (ss *Sources) acquireLock(ctx context.Context, src *source.Source) (lockfil return "", errz.Wrap(err, "failed to get lock") } - lg.FromContext(ctx).Debug("Acquired cache lock", "lock", lock) + lg.FromContext(ctx).Debug("Acquired cache lock", lga.Lock, lock) return lock, nil } diff --git a/libsq/source/cache.go b/libsq/source/cache.go index f3dd4e190..f5455e100 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -47,8 +47,10 @@ func CacheDirPath() (dir string) { if dir, err = os.UserCacheDir(); err != nil { // Some systems may not have a user cache dir, so we fall back // to the system temp dir. 
- dir = os.TempDir() + dir = filepath.Join(os.TempDir(), "sq", "cache") + return dir } - dir = filepath.Join(dir, "sq", "cache") + + dir = filepath.Join(dir, "sq") return dir } From 5505cec577c52384d2624319eec18cf17cebb075 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 12:41:09 -0700 Subject: [PATCH 043/195] Implemented cmd cache tree --- cli/cli.go | 2 + cli/cmd_config_cache.go | 117 +++++++++++++++++++++++++++++- cli/logging.go | 2 +- cli/output/jsonw/configwriter.go | 5 +- cli/output/printing.go | 18 ++++- cli/output/tablew/configwriter.go | 15 +++- cli/output/writers.go | 2 +- cli/output/yamlw/configwriter.go | 5 +- go.mod | 1 + go.sum | 2 + libsq/core/ioz/ioz.go | 9 +++ libsq/libsq.go | 3 +- libsq/source/files.go | 2 +- 13 files changed, 169 insertions(+), 14 deletions(-) diff --git a/cli/cli.go b/cli/cli.go index e3eca7c19..0d801227b 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -228,6 +228,8 @@ func newCommandTree(ru *run.Run) (rootCmd *cobra.Command) { cacheCmd := addCmd(ru, configCmd, newConfigCacheCmd()) addCmd(ru, cacheCmd, newConfigCacheLocationCmd()) addCmd(ru, cacheCmd, newConfigCacheInfoCmd()) + addCmd(ru, cacheCmd, newConfigCacheClearCmd()) + addCmd(ru, cacheCmd, newConfigCacheTreeCmd()) addCmd(ru, rootCmd, newCompletionCmd()) addCmd(ru, rootCmd, newVersionCmd()) diff --git a/cli/cmd_config_cache.go b/cli/cmd_config_cache.go index 384b2c666..084dedbf3 100644 --- a/cli/cmd_config_cache.go +++ b/cli/cmd_config_cache.go @@ -1,9 +1,18 @@ package cli import ( + "github.com/a8m/tree" + "github.com/a8m/tree/ostree" + "io" + "os" + "path/filepath" + + "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/stringz" + "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" "github.com/spf13/cobra" @@ -112,5 +121,111 @@ func 
execConfigCacheInfo(cmd *cobra.Command, _ []string) error { size = -1 } - return ru.Writers.Config.CacheInfo(dir, size) + enabled := driver.OptIngestCache.Get(ru.Config.Options) + return ru.Writers.Config.CacheInfo(dir, enabled, size) +} + +func newConfigCacheClearCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "clear", + Short: "Clear cache", + Long: "Clear cache. May cause issues if another sq instance is running.", + Args: cobra.ExactArgs(0), + RunE: execConfigCacheClear, + Example: ` $ sq config cache clear`, + } + + return cmd +} + +func execConfigCacheClear(cmd *cobra.Command, _ []string) error { + log := lg.FromContext(cmd.Context()) + cacheDir := source.CacheDirPath() + if !ioz.DirExists(cacheDir) { + return nil + } + + // Instead of directly deleting the existing cache dir, we first + // move it to /tmp, and then try to delete it. This should probably + // help with the situation where another sq instance has an open pid + // lock in the cache dir. + tmpLoc := filepath.Join(os.TempDir(), "sq", "dead_cache_"+stringz.Uniq8()) + if err := os.Rename(cacheDir, tmpLoc); err != nil { + return errz.Wrap(err, "clear cache: relocate") + } + + deleteErr := os.RemoveAll(tmpLoc) + if deleteErr != nil { + log.Warn("Could not delete relocated cache dir", lga.Path, tmpLoc, lga.Err, deleteErr) + } + + if err := os.MkdirAll(cacheDir, 0o750); err != nil { + return errz.Wrap(err, "clear cache") + } + + return nil +} + +func newConfigCacheTreeCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "tree", + Short: "Print tree view of cache dir", + Long: "Print tree view of cache dir.", + Args: cobra.ExactArgs(0), + RunE: execConfigCacheTree, + Example: ` $ sq config cache tree`, + } + + return cmd +} + +func execConfigCacheTree(cmd *cobra.Command, _ []string) error { + ru := run.FromContext(cmd.Context()) + cacheDir := source.CacheDirPath() + if !ioz.DirExists(cacheDir) { + return nil + } + return printFileTree(ru.Out, cacheDir) +} + +func printFileTree(w io.Writer, loc 
string) error { + opts := &tree.Options{ + Fs: new(ostree.FS), + OutFile: w, + All: false, + //DirsOnly: false, + //FullPath: false, + //IgnoreCase: false, + //FollowLink: false, + //DeepLevel: 0, + //Pattern: "", + //IPattern: "", + //MatchDirs: false, + //Prune: false, + //ByteSize: false, + //UnitSize: true, + //FileMode: false, + //ShowUid: false, + //ShowGid: false, + //LastMod: false, + //Quotes: false, + //Inodes: false, + //Device: false, + //NoSort: false, + //VerSort: false, + //ModSort: false, + //DirSort: false, + //NameSort: false, + //SizeSort: false, + //CTimeSort: false, + //ReverSort: false, + //NoIndent: false, + Colorize: true, + //Color: nil, + } + + inf := tree.New(loc) + _, _ = inf.Visit(opts) + inf.Print(opts) + return nil } diff --git a/cli/logging.go b/cli/logging.go index 39f728928..ed8f11e56 100644 --- a/cli/logging.go +++ b/cli/logging.go @@ -109,7 +109,7 @@ func defaultLogging(ctx context.Context, osArgs []string, cfg *config.Config, } else { h = newJSONHandler(logFile, lvl) } - //h = devlog.NewHandler(logFile, lvl) + // h = devlog.NewHandler(logFile, lvl) // h = newJSONHandler(logFile, lvl) return slog.New(h), h, closer, nil } diff --git a/cli/output/jsonw/configwriter.go b/cli/output/jsonw/configwriter.go index 3001e7e59..7692ebaa5 100644 --- a/cli/output/jsonw/configwriter.go +++ b/cli/output/jsonw/configwriter.go @@ -29,13 +29,14 @@ func (w *configWriter) CacheLocation(loc string) error { // CacheInfo implements output.ConfigWriter. It simply // delegates to CacheLocation. 
-func (w *configWriter) CacheInfo(loc string, size int64) error { +func (w *configWriter) CacheInfo(loc string, enabled bool, size int64) error { type cacheInfo struct { Location string `json:"location"` + Enabled bool `json:"enabled"` Size *int64 `json:"size,omitempty"` } - ci := cacheInfo{Location: loc} + ci := cacheInfo{Location: loc, Enabled: enabled} if size != -1 { ci.Size = &size } diff --git a/cli/output/printing.go b/cli/output/printing.go index e26b4242f..142b51e8f 100644 --- a/cli/output/printing.go +++ b/cli/output/printing.go @@ -111,9 +111,15 @@ type Printing struct { // DiffNormal is the color for regular diff text. DiffNormal *color.Color + // Disabled is the color for disabled elements. + Disabled *color.Color + // Duration is the color for time duration values. Duration *color.Color + // Enabled is the color for enabled elements. + Enabled *color.Color + // Error is the color for error elements such as an error message. Error *color.Color @@ -152,6 +158,9 @@ type Printing struct { // Success is the color for success elements. Success *color.Color + + // Warning is the color for warning elements. + Warning *color.Color } // NewPrinting returns a Printing instance. 
Color and pretty-print @@ -181,7 +190,9 @@ func NewPrinting() *Printing { DiffNormal: color.New(color.Faint), DiffPlus: color.New(color.FgGreen), DiffSection: color.New(color.FgCyan), + Disabled: color.New(color.FgYellow, color.Faint), Duration: color.New(color.FgGreen, color.Faint), + Enabled: color.New(color.FgGreen, color.Faint), Error: color.New(color.FgRed, color.Bold), Faint: color.New(color.Faint), Handle: color.New(color.FgBlue), @@ -195,6 +206,7 @@ func NewPrinting() *Printing { Punc: color.New(color.Bold), String: color.New(color.FgGreen), Success: color.New(color.FgGreen, color.Bold), + Warning: color.New(color.FgYellow), } pr.EnableColor(true) @@ -229,7 +241,9 @@ func (pr *Printing) Clone() *Printing { pr2.DiffHeader = lo.ToPtr(*pr.DiffHeader) pr2.DiffSection = lo.ToPtr(*pr.DiffSection) pr2.DiffNormal = lo.ToPtr(*pr.DiffNormal) + pr2.Disabled = lo.ToPtr(*pr.Disabled) pr2.Duration = lo.ToPtr(*pr.Duration) + pr2.Enabled = lo.ToPtr(*pr.Enabled) pr2.Error = lo.ToPtr(*pr.Error) pr2.Faint = lo.ToPtr(*pr.Faint) pr2.Handle = lo.ToPtr(*pr.Handle) @@ -243,6 +257,7 @@ func (pr *Printing) Clone() *Printing { pr2.Punc = lo.ToPtr(*pr.Punc) pr2.String = lo.ToPtr(*pr.String) pr2.Success = lo.ToPtr(*pr.Success) + pr2.Warning = lo.ToPtr(*pr.Warning) return pr2 } @@ -272,9 +287,10 @@ func (pr *Printing) colors() []*color.Color { return []*color.Color{ pr.Active, pr.Bold, pr.Bold, pr.Bytes, pr.Datetime, pr.Duration, pr.DiffHeader, pr.DiffMinus, pr.DiffPlus, pr.DiffNormal, pr.DiffSection, + pr.Disabled, pr.Enabled, pr.Error, pr.Faint, pr.Handle, pr.Header, pr.Hilite, pr.Key, pr.Location, pr.Normal, pr.Null, pr.Number, - pr.Punc, pr.String, pr.Success, + pr.Punc, pr.String, pr.Success, pr.Warning, } } diff --git a/cli/output/tablew/configwriter.go b/cli/output/tablew/configwriter.go index 05ba811f0..cc432e252 100644 --- a/cli/output/tablew/configwriter.go +++ b/cli/output/tablew/configwriter.go @@ -2,9 +2,10 @@ package tablew import ( "fmt" + "io" + 
"github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/stringz" - "io" "github.com/fatih/color" "github.com/samber/lo" @@ -46,10 +47,16 @@ func (w *configWriter) CacheLocation(loc string) error { // CacheInfo implements output.ConfigWriter. It simply // delegates to CacheLocation. -func (w *configWriter) CacheInfo(loc string, size int64) error { - s := loc + " " +func (w *configWriter) CacheInfo(loc string, enabled bool, size int64) error { + const sp = " " + s := loc + sp + if enabled { + s += w.tbl.pr.Enabled.Sprint("enabled") + sp + } else { + s += w.tbl.pr.Disabled.Sprint("disabled") + sp + } if size == -1 { - s += w.tbl.pr.Error.Sprint("(size unavailable)") + s += w.tbl.pr.Warning.Sprint("(size unavailable)") } else { s += w.tbl.pr.Faint.Sprintf("(%s)", stringz.ByteSized(size, 1, "")) } diff --git a/cli/output/writers.go b/cli/output/writers.go index 0049d1ca3..bddff1dc4 100644 --- a/cli/output/writers.go +++ b/cli/output/writers.go @@ -145,7 +145,7 @@ type ConfigWriter interface { // CacheInfo prints cache info. Set arg size to -1 to indicate // that the size of the cache could not be calculated. - CacheInfo(loc string, size int64) error + CacheInfo(loc string, enabled bool, size int64) error } // Writers is a container for the various output Writers. diff --git a/cli/output/yamlw/configwriter.go b/cli/output/yamlw/configwriter.go index 774dca852..80fc181e9 100644 --- a/cli/output/yamlw/configwriter.go +++ b/cli/output/yamlw/configwriter.go @@ -47,13 +47,14 @@ func (w *configWriter) CacheLocation(loc string) error { // CacheInfo implements output.ConfigWriter. It simply // delegates to CacheLocation. 
-func (w *configWriter) CacheInfo(loc string, size int64) error { +func (w *configWriter) CacheInfo(loc string, enabled bool, size int64) error { type cacheInfo struct { Location string `yaml:"location"` + Enabled bool `yaml:"enabled"` Size *int64 `yaml:"size,omitempty"` } - ci := cacheInfo{Location: loc} + ci := cacheInfo{Location: loc, Enabled: enabled} if size != -1 { ci.Size = &size } diff --git a/go.mod b/go.mod index 830be59c2..03570c7ff 100644 --- a/go.mod +++ b/go.mod @@ -57,6 +57,7 @@ require ( github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.2.0 // indirect github.com/VividCortex/ewma v1.2.0 // indirect + github.com/a8m/tree v0.0.0-20230208161321-36ae24ddad15 // indirect github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/djherbis/atime v1.1.0 // indirect diff --git a/go.sum b/go.sum index 855bf2ed6..65ddf81a5 100644 --- a/go.sum +++ b/go.sum @@ -20,6 +20,8 @@ github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM= github.com/VividCortex/ewma v1.2.0 h1:f58SaIzcDXrSy3kWaHNvuJgJ3Nmz59Zji6XoJR/q1ow= github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAUnGx7j5l4= +github.com/a8m/tree v0.0.0-20230208161321-36ae24ddad15 h1:t3qDzTv8T15tVVhJHHgY7hX5jiIz67xE2SxWQ2ehjH4= +github.com/a8m/tree v0.0.0-20230208161321-36ae24ddad15/go.mod h1:j5astEcUkZQX8lK+KKlQ3NRQ50f4EE8ZjyZpCz3mrH4= github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d h1:licZJFw2RwpHMqeKTCYkitsPqHNxTmd4SNR5r94FGM8= github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d/go.mod h1:asat636LX7Bqt5lYEZ27JNDcqxfjdBQuJ/MM4CN/Lzo= github.com/alessio/shellescape v1.4.2 h1:MHPfaU+ddJ0/bYWpgIeUnQUqKrlJ1S7BfEYPM4uEoM0= diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 
be6c9560b..429323989 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -385,3 +385,12 @@ func DirSize(path string) (int64, error) { }) return size, err } + +// DirExists returns true if dir exists and is a directory. +func DirExists(dir string) bool { + fi, err := os.Stat(dir) + if err != nil { + return false + } + return fi.IsDir() +} diff --git a/libsq/libsq.go b/libsq/libsq.go index aa0a44b34..74befc8b1 100644 --- a/libsq/libsq.go +++ b/libsq/libsq.go @@ -11,6 +11,7 @@ package libsq import ( "context" + "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/errz" @@ -134,7 +135,7 @@ func SLQ2SQL(ctx context.Context, qc *QueryContext, query string) (targetSQL str // The caller is responsible for closing pool (and db, if non-nil). func QuerySQL(ctx context.Context, pool driver.Pool, db sqlz.DB, recw RecordWriter, query string, args ...any, -) error { +) error { //nolint:funlen log := lg.FromContext(ctx) errw := pool.SQLDriver().ErrWrapFunc() diff --git a/libsq/source/files.go b/libsq/source/files.go index a71dbb74d..e3047c67c 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -173,7 +173,7 @@ func (fs *Files) addStdin(ctx context.Context, f *os.File) error { lw := ioz.NewWrittenWriter(w) fs.stdinLength = lw.Written - //df := ioz.DelayReader(f, time.Microsecond*500, true) // FIXME: Delete + // df := ioz.DelayReader(f, time.Microsecond*500, true) // FIXME: Delete cr := contextio.NewReader(ctx, f) pw := progress.NewWriter(ctx, "Reading stdin", -1, lw) From 0fc01b127aed0518d179e9a4e1fa75a5ba68e846 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 12:46:06 -0700 Subject: [PATCH 044/195] More cache cmds --- cli/cli.go | 2 +- cli/cmd_config_cache.go | 75 +++++++++++++++++++++-------------------- libsq/libsq.go | 4 +-- 3 files changed, 41 insertions(+), 40 deletions(-) diff --git a/cli/cli.go b/cli/cli.go index 0d801227b..75edf71e3 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -225,7 +225,7 @@ 
func newCommandTree(ru *run.Run) (rootCmd *cobra.Command) { addCmd(ru, configCmd, newConfigSetCmd()) addCmd(ru, configCmd, newConfigLocationCmd()) addCmd(ru, configCmd, newConfigEditCmd()) - cacheCmd := addCmd(ru, configCmd, newConfigCacheCmd()) + cacheCmd := addCmd(ru, rootCmd, newCacheCmd()) addCmd(ru, cacheCmd, newConfigCacheLocationCmd()) addCmd(ru, cacheCmd, newConfigCacheInfoCmd()) addCmd(ru, cacheCmd, newConfigCacheClearCmd()) diff --git a/cli/cmd_config_cache.go b/cli/cmd_config_cache.go index 084dedbf3..56f9482b5 100644 --- a/cli/cmd_config_cache.go +++ b/cli/cmd_config_cache.go @@ -1,12 +1,13 @@ package cli import ( - "github.com/a8m/tree" - "github.com/a8m/tree/ostree" "io" "os" "path/filepath" + "github.com/a8m/tree" + "github.com/a8m/tree/ostree" + "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" @@ -20,7 +21,7 @@ import ( "github.com/neilotoole/sq/cli/run" ) -func newConfigCacheCmd() *cobra.Command { +func newCacheCmd() *cobra.Command { cmd := &cobra.Command{ Use: "cache", Args: cobra.NoArgs, @@ -78,7 +79,7 @@ func newConfigCacheLocationCmd() *cobra.Command { Long: "Print cache location.", Args: cobra.ExactArgs(0), RunE: execConfigCacheLocation, - Example: ` $ sq config cache location + Example: ` $ sq cache location /Users/neilotoole/Library/Caches/sq`, } @@ -97,11 +98,11 @@ func execConfigCacheLocation(cmd *cobra.Command, _ []string) error { func newConfigCacheInfoCmd() *cobra.Command { cmd := &cobra.Command{ Use: "info", - Short: "Print cache info", - Long: "Print cache info.", + Short: "Show cache info", + Long: "Show cache info, including location and size.", Args: cobra.ExactArgs(0), RunE: execConfigCacheInfo, - Example: ` $ sq config cache info + Example: ` $ sq cache info /Users/neilotoole/Library/Caches/sq (1.2MB)`, } @@ -132,7 +133,7 @@ func newConfigCacheClearCmd() *cobra.Command { Long: "Clear cache. 
May cause issues if another sq instance is running.", Args: cobra.ExactArgs(0), RunE: execConfigCacheClear, - Example: ` $ sq config cache clear`, + Example: ` $ sq cache clear`, } return cmd @@ -173,7 +174,7 @@ func newConfigCacheTreeCmd() *cobra.Command { Long: "Print tree view of cache dir.", Args: cobra.ExactArgs(0), RunE: execConfigCacheTree, - Example: ` $ sq config cache tree`, + Example: ` $ sq cache tree`, } return cmd @@ -193,35 +194,35 @@ func printFileTree(w io.Writer, loc string) error { Fs: new(ostree.FS), OutFile: w, All: false, - //DirsOnly: false, - //FullPath: false, - //IgnoreCase: false, - //FollowLink: false, - //DeepLevel: 0, - //Pattern: "", - //IPattern: "", - //MatchDirs: false, - //Prune: false, - //ByteSize: false, - //UnitSize: true, - //FileMode: false, - //ShowUid: false, - //ShowGid: false, - //LastMod: false, - //Quotes: false, - //Inodes: false, - //Device: false, - //NoSort: false, - //VerSort: false, - //ModSort: false, - //DirSort: false, - //NameSort: false, - //SizeSort: false, - //CTimeSort: false, - //ReverSort: false, - //NoIndent: false, + // DirsOnly: false, + // FullPath: false, + // IgnoreCase: false, + // FollowLink: false, + // DeepLevel: 0, + // Pattern: "", + // IPattern: "", + // MatchDirs: false, + // Prune: false, + // ByteSize: false, + // UnitSize: true, + // FileMode: false, + // ShowUid: false, + // ShowGid: false, + // LastMod: false, + // Quotes: false, + // Inodes: false, + // Device: false, + // NoSort: false, + // VerSort: false, + // ModSort: false, + // DirSort: false, + // NameSort: false, + // SizeSort: false, + // CTimeSort: false, + // ReverSort: false, + // NoIndent: false, Colorize: true, - //Color: nil, + // Color: nil, } inf := tree.New(loc) diff --git a/libsq/libsq.go b/libsq/libsq.go index 74befc8b1..edc1a56c1 100644 --- a/libsq/libsq.go +++ b/libsq/libsq.go @@ -133,9 +133,9 @@ func SLQ2SQL(ctx context.Context, qc *QueryContext, query string) (targetSQL str // Note that QuerySQL may return 
before recw has finished writing, thus the // caller may wish to wait for recw to complete. // The caller is responsible for closing pool (and db, if non-nil). -func QuerySQL(ctx context.Context, pool driver.Pool, db sqlz.DB, +func QuerySQL(ctx context.Context, pool driver.Pool, db sqlz.DB, //nolint:funlen recw RecordWriter, query string, args ...any, -) error { //nolint:funlen +) error { log := lg.FromContext(ctx) errw := pool.SQLDriver().ErrWrapFunc() From bd1536376f09e7650fb274c8ab3e8e85da4feb56 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 14:40:59 -0700 Subject: [PATCH 045/195] More cache cmds --- cli/cli.go | 13 +- cli/cmd_cache.go | 210 +++++++++++++++++++++++++++++ cli/cmd_config_cache.go | 232 -------------------------------- cli/cmd_root.go | 3 - cli/flag/flag.go | 4 + cli/flags.go | 4 +- cli/output.go | 10 +- libsq/core/ioz/ioz.go | 68 +++++++++- libsq/core/progress/progress.go | 29 ++-- libsq/driver/driver.go | 2 + libsq/driver/sources.go | 20 ++- libsq/source/cache.go | 42 +++--- 12 files changed, 338 insertions(+), 299 deletions(-) create mode 100644 cli/cmd_cache.go delete mode 100644 cli/cmd_config_cache.go diff --git a/cli/cli.go b/cli/cli.go index 75edf71e3..369f4086d 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -77,7 +77,7 @@ func Execute(ctx context.Context, stdin *os.File, stdout, stderr io.Writer, args func ExecuteWith(ctx context.Context, ru *run.Run, args []string) error { ctx = options.NewContext(ctx, ru.Config.Options) log := lg.FromContext(ctx) - log.Debug("EXECUTE", "args", strings.Join(args, " ")) + log.Info("EXECUTE", "args", strings.Join(args, " ")) log.Debug("Build info", "build", buildinfo.Get()) log.Debug("Config", "config.version", ru.Config.Version, @@ -225,11 +225,14 @@ func newCommandTree(ru *run.Run) (rootCmd *cobra.Command) { addCmd(ru, configCmd, newConfigSetCmd()) addCmd(ru, configCmd, newConfigLocationCmd()) addCmd(ru, configCmd, newConfigEditCmd()) + cacheCmd := addCmd(ru, rootCmd, newCacheCmd()) - 
addCmd(ru, cacheCmd, newConfigCacheLocationCmd()) - addCmd(ru, cacheCmd, newConfigCacheInfoCmd()) - addCmd(ru, cacheCmd, newConfigCacheClearCmd()) - addCmd(ru, cacheCmd, newConfigCacheTreeCmd()) + addCmd(ru, cacheCmd, newCacheLocationCmd()) + addCmd(ru, cacheCmd, newCacheInfoCmd()) + addCmd(ru, cacheCmd, newCacheEnableCmd()) + addCmd(ru, cacheCmd, newCacheDisableCmd()) + addCmd(ru, cacheCmd, newCacheClearCmd()) + addCmd(ru, cacheCmd, newCacheTreeCmd()) addCmd(ru, rootCmd, newCompletionCmd()) addCmd(ru, rootCmd, newVersionCmd()) diff --git a/cli/cmd_cache.go b/cli/cmd_cache.go new file mode 100644 index 000000000..ba0015753 --- /dev/null +++ b/cli/cmd_cache.go @@ -0,0 +1,210 @@ +package cli + +import ( + "os" + "path/filepath" + + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/stringz" + "github.com/neilotoole/sq/libsq/driver" + "github.com/neilotoole/sq/libsq/source" + "github.com/spf13/cobra" + + "github.com/neilotoole/sq/cli/flag" + "github.com/neilotoole/sq/cli/run" +) + +func newCacheCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "cache", + Args: cobra.NoArgs, + Short: "Manage cache", + Long: `Manage cache.`, + RunE: func(cmd *cobra.Command, args []string) error { + return cmd.Help() + }, + Example: ` # Print cache location. + $ sq cache location + + # Show cache info. + $ sq cache stat + + $ sq cache enable + + $ sq cache disable + + $ sq cache clear + + # Print tree view of cache dir. 
+ $ sq cache tree`, + } + + return cmd +} + +func newCacheLocationCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "location", + Aliases: []string{"loc"}, + Short: "Print cache location", + Long: "Print cache location.", + Args: cobra.ExactArgs(0), + RunE: execCacheLocation, + Example: ` $ sq cache location + /Users/neilotoole/Library/Caches/sq`, + } + + addTextFormatFlags(cmd) + cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) + cmd.Flags().BoolP(flag.YAML, flag.YAMLShort, false, flag.YAMLUsage) + return cmd +} + +func execCacheLocation(cmd *cobra.Command, _ []string) error { + dir := source.CacheDirPath() + ru := run.FromContext(cmd.Context()) + return ru.Writers.Config.CacheLocation(dir) +} + +func newCacheInfoCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "stat", + Short: "Show cache info", + Long: "Show cache info, including location and size.", + Args: cobra.ExactArgs(0), + RunE: execCacheInfo, + Example: ` $ sq cache stat + /Users/neilotoole/Library/Caches/sq enabled (472.8MB)`, + } + + addTextFormatFlags(cmd) + cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) + cmd.Flags().BoolP(flag.YAML, flag.YAMLShort, false, flag.YAMLUsage) + return cmd +} + +func execCacheInfo(cmd *cobra.Command, _ []string) error { + dir := source.CacheDirPath() + ru := run.FromContext(cmd.Context()) + size, err := ioz.DirSize(dir) + if err != nil { + lg.FromContext(cmd.Context()).Warn("Could not determine cache size", + lga.Path, dir, lga.Err, err) + size = -1 // -1 tells the printer that the size is unavailable. + } + + enabled := driver.OptIngestCache.Get(ru.Config.Options) + return ru.Writers.Config.CacheInfo(dir, enabled, size) +} + +func newCacheClearCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "clear", + Short: "Clear cache", + Long: "Clear cache. 
May cause issues if another sq instance is running.", + Args: cobra.ExactArgs(0), + RunE: execCacheClear, + Example: ` $ sq cache clear`, + } + + return cmd +} + +func execCacheClear(cmd *cobra.Command, _ []string) error { + log := lg.FromContext(cmd.Context()) + cacheDir := source.CacheDirPath() + if !ioz.DirExists(cacheDir) { + return nil + } + + // Instead of directly deleting the existing cache dir, we first + // move it to /tmp, and then try to delete it. This should probably + // help with the situation where another sq instance has an open pid + // lock in the cache dir. + + tmpDir := source.TempDirPath() + if err := ioz.RequireDir(tmpDir); err != nil { + return errz.Wrap(err, "cache clear") + } + relocateDir := filepath.Join(tmpDir, "dead_cache_"+stringz.Uniq8()) + if err := os.Rename(cacheDir, relocateDir); err != nil { + return errz.Wrap(err, "cache clear: relocate") + } + + if err := os.RemoveAll(relocateDir); err != nil { + log.Warn("Could not delete relocated cache dir", lga.Path, relocateDir, lga.Err, err) + } + + // Recreate the cache dir. 
+ if err := ioz.RequireDir(cacheDir); err != nil { + return errz.Wrap(err, "cache clear") + } + + return nil +} + +func newCacheTreeCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "tree", + Short: "Print tree view of cache dir", + Long: "Print tree view of cache dir.", + Args: cobra.ExactArgs(0), + RunE: execCacheTree, + Example: ` # Print cache tree + $ sq cache tree + + # Print cache tree with sizes + $ sq cache tree --size`, + } + + _ = cmd.Flags().BoolP(flag.CacheTreeSize, flag.CacheTreeSizeShort, false, flag.CacheTreeSizeUsage) + return cmd +} + +func execCacheTree(cmd *cobra.Command, _ []string) error { + ru := run.FromContext(cmd.Context()) + cacheDir := source.CacheDirPath() + if !ioz.DirExists(cacheDir) { + return nil + } + + showSize := cmdFlagBool(cmd, flag.CacheTreeSize) + return ioz.PrintTree(ru.Out, cacheDir, showSize) +} + +func newCacheEnableCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "enable", + Short: "Enable caching", + Long: `Disable caching. This is equivalent to: + + $ sq config set ingest.cache true`, + Args: cobra.ExactArgs(0), + RunE: func(cmd *cobra.Command, args []string) error { + return execConfigSet(cmd, []string{driver.OptIngestCache.Key(), "true"}) + }, + Example: ` $ sq cache enable`, + } + + return cmd +} + +func newCacheDisableCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "disable", + Short: "Disable caching", + Long: `Disable caching. 
This is equivalent to: + + $ sq config set ingest.cache false`, + Args: cobra.ExactArgs(0), + RunE: func(cmd *cobra.Command, args []string) error { + return execConfigSet(cmd, []string{driver.OptIngestCache.Key(), "false"}) + }, + Example: ` $ sq cache disable`, + } + + return cmd +} diff --git a/cli/cmd_config_cache.go b/cli/cmd_config_cache.go deleted file mode 100644 index 56f9482b5..000000000 --- a/cli/cmd_config_cache.go +++ /dev/null @@ -1,232 +0,0 @@ -package cli - -import ( - "io" - "os" - "path/filepath" - - "github.com/a8m/tree" - "github.com/a8m/tree/ostree" - - "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lga" - "github.com/neilotoole/sq/libsq/core/stringz" - "github.com/neilotoole/sq/libsq/driver" - "github.com/neilotoole/sq/libsq/source" - "github.com/spf13/cobra" - - "github.com/neilotoole/sq/cli/flag" - "github.com/neilotoole/sq/cli/run" -) - -func newCacheCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "cache", - Args: cobra.NoArgs, - Short: "Manage cache", - Long: `Manage cache.`, - RunE: func(cmd *cobra.Command, args []string) error { - return cmd.Help() - }, - // FIXME: add examples - // Example: ` # Print config location - //$ sq config location - // - //# Show base config - //$ sq config ls - // - //# Show base config including unset and default values. 
- //$ sq config ls -v - // - //# Show base config in maximum detail (YAML format) - //$ sq config ls -yv - // - //# Get base value of an option - //$ sq config get format - // - //# Get source-specific value of an option - //$ sq config get --src @sakila conn.max-open - // - //# Set base option value - //$ sq config set format json - // - //# Set source-specific option value - //$ sq config set --src @sakila conn.max-open 50 - // - //# Help for an option - //$ sq config set format --help - // - //# Edit base config in $EDITOR - //$ sq config edit - // - //# Edit config for source in $EDITOR - //$ sq config edit @sakila - // - //# Delete option (reset to default value) - //$ sq config set -D log.level`, - } - - return cmd -} - -func newConfigCacheLocationCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "location", - Aliases: []string{"loc"}, - Short: "Print cache location", - Long: "Print cache location.", - Args: cobra.ExactArgs(0), - RunE: execConfigCacheLocation, - Example: ` $ sq cache location - /Users/neilotoole/Library/Caches/sq`, - } - - addTextFormatFlags(cmd) - cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) - cmd.Flags().BoolP(flag.YAML, flag.YAMLShort, false, flag.YAMLUsage) - return cmd -} - -func execConfigCacheLocation(cmd *cobra.Command, _ []string) error { - dir := source.CacheDirPath() - ru := run.FromContext(cmd.Context()) - return ru.Writers.Config.CacheLocation(dir) -} - -func newConfigCacheInfoCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "info", - Short: "Show cache info", - Long: "Show cache info, including location and size.", - Args: cobra.ExactArgs(0), - RunE: execConfigCacheInfo, - Example: ` $ sq cache info - /Users/neilotoole/Library/Caches/sq (1.2MB)`, - } - - addTextFormatFlags(cmd) - cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) - cmd.Flags().BoolP(flag.YAML, flag.YAMLShort, false, flag.YAMLUsage) - return cmd -} - -func execConfigCacheInfo(cmd *cobra.Command, _ []string) error 
{ - dir := source.CacheDirPath() - ru := run.FromContext(cmd.Context()) - size, err := ioz.DirSize(dir) - if err != nil { - lg.FromContext(cmd.Context()).Warn("Could not determine cache size", - lga.Path, dir, lga.Err, err) - size = -1 - } - - enabled := driver.OptIngestCache.Get(ru.Config.Options) - return ru.Writers.Config.CacheInfo(dir, enabled, size) -} - -func newConfigCacheClearCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "clear", - Short: "Clear cache", - Long: "Clear cache. May cause issues if another sq instance is running.", - Args: cobra.ExactArgs(0), - RunE: execConfigCacheClear, - Example: ` $ sq cache clear`, - } - - return cmd -} - -func execConfigCacheClear(cmd *cobra.Command, _ []string) error { - log := lg.FromContext(cmd.Context()) - cacheDir := source.CacheDirPath() - if !ioz.DirExists(cacheDir) { - return nil - } - - // Instead of directly deleting the existing cache dir, we first - // move it to /tmp, and then try to delete it. This should probably - // help with the situation where another sq instance has an open pid - // lock in the cache dir. 
- tmpLoc := filepath.Join(os.TempDir(), "sq", "dead_cache_"+stringz.Uniq8()) - if err := os.Rename(cacheDir, tmpLoc); err != nil { - return errz.Wrap(err, "clear cache: relocate") - } - - deleteErr := os.RemoveAll(tmpLoc) - if deleteErr != nil { - log.Warn("Could not delete relocated cache dir", lga.Path, tmpLoc, lga.Err, deleteErr) - } - - if err := os.MkdirAll(cacheDir, 0o750); err != nil { - return errz.Wrap(err, "clear cache") - } - - return nil -} - -func newConfigCacheTreeCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "tree", - Short: "Print tree view of cache dir", - Long: "Print tree view of cache dir.", - Args: cobra.ExactArgs(0), - RunE: execConfigCacheTree, - Example: ` $ sq cache tree`, - } - - return cmd -} - -func execConfigCacheTree(cmd *cobra.Command, _ []string) error { - ru := run.FromContext(cmd.Context()) - cacheDir := source.CacheDirPath() - if !ioz.DirExists(cacheDir) { - return nil - } - return printFileTree(ru.Out, cacheDir) -} - -func printFileTree(w io.Writer, loc string) error { - opts := &tree.Options{ - Fs: new(ostree.FS), - OutFile: w, - All: false, - // DirsOnly: false, - // FullPath: false, - // IgnoreCase: false, - // FollowLink: false, - // DeepLevel: 0, - // Pattern: "", - // IPattern: "", - // MatchDirs: false, - // Prune: false, - // ByteSize: false, - // UnitSize: true, - // FileMode: false, - // ShowUid: false, - // ShowGid: false, - // LastMod: false, - // Quotes: false, - // Inodes: false, - // Device: false, - // NoSort: false, - // VerSort: false, - // ModSort: false, - // DirSort: false, - // NameSort: false, - // SizeSort: false, - // CTimeSort: false, - // ReverSort: false, - // NoIndent: false, - Colorize: true, - // Color: nil, - } - - inf := tree.New(loc) - _, _ = inf.Visit(opts) - inf.Print(opts) - return nil -} diff --git a/cli/cmd_root.go b/cli/cmd_root.go index 90a4de350..9978c8bd3 100644 --- a/cli/cmd_root.go +++ b/cli/cmd_root.go @@ -51,9 +51,6 @@ See docs and more: https://sq.io`, # Output all rows 
from 'actor' table in JSON. $ sq -j .actor - # Alternative way to specify format. - $ sq --format json .actor - # Output in text format (with header). $ sq -th .actor diff --git a/cli/flag/flag.go b/cli/flag/flag.go index e0a895021..b99a459a4 100644 --- a/cli/flag/flag.go +++ b/cli/flag/flag.go @@ -82,6 +82,10 @@ const ( PasswordPromptShort = "p" PasswordPromptUsage = "Read password from stdin or prompt" + CacheTreeSize = "size" + CacheTreeSizeShort = "s" + CacheTreeSizeUsage = "Show sizes in cache tree" + Compact = "compact" CompactShort = "c" CompactUsage = "Compact instead of pretty-printed output" diff --git a/cli/flags.go b/cli/flags.go index 7d10f795e..a92fb89ce 100644 --- a/cli/flags.go +++ b/cli/flags.go @@ -56,10 +56,10 @@ func cmdFlagIsSetTrue(cmd *cobra.Command, name string) bool { return b } -// cmdFlagIsSetTrue returns the bool value of flag name. If the flag +// cmdFlagBool returns the bool value of flag name. If the flag // has not been set, its default value is returned. // Contrast with cmdFlagIsSetTrue. -func cmdFlagBool(cmd *cobra.Command, name string) bool { //nolint:unused +func cmdFlagBool(cmd *cobra.Command, name string) bool { b, err := cmd.Flags().GetBool(name) if err != nil { panic(err) // Should never happen diff --git a/cli/output.go b/cli/output.go index f12abba94..e5511cde6 100644 --- a/cli/output.go +++ b/cli/output.go @@ -107,8 +107,8 @@ command, sq falls back to "text". Available formats: true, 0, true, - "Specify whether a progress bar is shown for long-running operations", - `Specify whether a progress bar is shown for long-running operations.`, + "Progress bar shown for long-running operations", + `Progress bar shown for long-running operations.`, options.TagOutput, ) @@ -118,8 +118,7 @@ command, sq falls back to "text". 
Available formats:
 	0,
 	time.Second*2,
 	"Progress bar render delay",
-	`How long to wait after a long-running operation begins
-before showing a progress bar.`,
+	`Delay before showing a progress bar.`,
 )
 
 OptCompact = options.NewBool(
@@ -464,9 +463,8 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer
 	prog := progress.New(ctx, errOut2, renderDelay, progColors)
 
 	// On first write to stdout, we remove the progress widget.
-	// out2 = ioz.NotifyOnceWriter(out2, prog.Wait)
 	out2 = ioz.NotifyOnceWriter(out2, func() {
-		lg.FromContext(ctx).Debug("Notify once invoked")
+		lg.FromContext(ctx).Debug("Output stream is being written to; removing progress widget")
 		prog.Wait()
 	})
 	cmd.SetContext(progress.NewContext(ctx, prog))
diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go
index 429323989..5e8d23151 100644
--- a/libsq/core/ioz/ioz.go
+++ b/libsq/core/ioz/ioz.go
@@ -14,6 +14,9 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/a8m/tree"
+	"github.com/a8m/tree/ostree"
+
 	yaml "github.com/goccy/go-yaml"
 
 	"github.com/neilotoole/sq/libsq/core/errz"
@@ -264,24 +267,31 @@ func NotifyOnceWriter(w io.Writer, fn func()) io.Writer {
 	}
 
 	return &notifyOnceWriter{
-		fn: fn,
-		w:  w,
+		fn:     fn,
+		w:      w,
+		doneCh: make(chan struct{}),
 	}
 }
 
 var _ io.Writer = (*notifyOnceWriter)(nil)
 
 type notifyOnceWriter struct {
-	fn         func()
 	w          io.Writer
+	fn         func()
+	doneCh     chan struct{}
 	notifyOnce sync.Once
 }
 
 // Write implements [io.Writer]. On the first invocation of this
-// method, fn is invoked.
+// method, the notify function is invoked, blocking until it returns.
+// Subsequent invocations of Write do not trigger the notify function.
func (w *notifyOnceWriter) Write(p []byte) (n int, err error) { - w.notifyOnce.Do(w.fn) + w.notifyOnce.Do(func() { + close(w.doneCh) + w.fn() + }) + <-w.doneCh return w.w.Write(p) } @@ -386,6 +396,12 @@ func DirSize(path string) (int64, error) { return size, err } +// RequireDir ensures that dir exists and is a directory, creating +// it if necessary. +func RequireDir(dir string) error { + return errz.Err(os.MkdirAll(dir, 0o750)) +} + // DirExists returns true if dir exists and is a directory. func DirExists(dir string) bool { fi, err := os.Stat(dir) @@ -394,3 +410,45 @@ func DirExists(dir string) bool { } return fi.IsDir() } + +func PrintTree(w io.Writer, loc string, showSize bool) error { + opts := &tree.Options{ + Fs: new(ostree.FS), + OutFile: w, + All: false, + // DirsOnly: false, + // FullPath: false, + // IgnoreCase: false, + // FollowLink: false, + // DeepLevel: 0, + // Pattern: "", + // IPattern: "", + // MatchDirs: false, + // Prune: false, + // ByteSize: false, + UnitSize: showSize, + // FileMode: false, + // ShowUid: false, + // ShowGid: false, + // LastMod: false, + // Quotes: false, + // Inodes: false, + // Device: false, + // NoSort: false, + // VerSort: false, + // ModSort: false, + // DirSort: false, + // NameSort: false, + // SizeSort: false, + // CTimeSort: false, + // ReverSort: false, + // NoIndent: false, + Colorize: true, + // Color: nil, + } + + inf := tree.New(loc) + _, _ = inf.Visit(opts) + inf.Print(opts) + return nil +} diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 7e46fc13e..fb802c11e 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -60,7 +60,7 @@ const ( // The Progress is lazily initialized, and thus the delay clock doesn't // start ticking until the first call to one of the Progress.NewX methods. 
func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors) *Progress { - lg.FromContext(ctx).Error("New progress", "delay", delay) + lg.FromContext(ctx).Debug("New progress widget", "delay", delay) var cancelFn context.CancelFunc ogCtx := ctx @@ -72,15 +72,15 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors } p := &Progress{ - ctx: ctx, - mu: sync.Mutex{}, - colors: colors, - // cleanup: cleanup.New(), + ctx: ctx, + mu: &sync.Mutex{}, + colors: colors, cancelFn: cancelFn, bars: make([]*Bar, 0), } p.pcInit = func() { + lg.FromContext(ctx).Debug("Initializing progress widget") opts := []mpb.ContainerOption{ mpb.WithOutput(out), mpb.WithWidth(boxWidth), @@ -96,7 +96,6 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors close(delayCh) p.delayCh = delayCh } - lg.FromContext(ctx).Debug("Render delay", "delay", delay) p.pc = mpb.NewWithContext(ctx, opts...) p.pcInit = nil @@ -109,7 +108,7 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors // completion. type Progress struct { // mu guards ALL public methods. - mu sync.Mutex + mu *sync.Mutex ctx context.Context @@ -277,24 +276,16 @@ func (p *Progress) newBar(msg string, total int64, return b } -func (p *Progress) barStopped(b *Bar) { +// barStopped is called by a Bar when it is stopped. +// This was supposed to do something, but it's a no-op for now. +func (p *Progress) barStopped(_ *Bar) { if p == nil { return } - - p.mu.Lock() - defer p.mu.Unlock() - - for i, bar := range p.bars { - if bar == b { - p.bars = append(p.bars[:i], p.bars[i+1:]...) - return - } - } } func spinnerStyle(c *color.Color) mpb.SpinnerStyleComposer { - // TODO: should use ascii chars? + // REVISIT: maybe use ascii chars only, in case it's a weird terminal? frames := []string{"∙∙∙", "●∙∙", "●∙∙", "∙●∙", "∙●∙", "∙∙●", "∙∙●", "∙∙∙"} style := mpb.SpinnerStyle(frames...) 
if c != nil { diff --git a/libsq/driver/driver.go b/libsq/driver/driver.go index 48e7cc934..e47b8b64e 100644 --- a/libsq/driver/driver.go +++ b/libsq/driver/driver.go @@ -346,6 +346,8 @@ type SQLDriver interface { // connections. It is conceptually equivalent to // stdlib sql.DB, and in fact encapsulates a sql.DB instance. The // realized sql.DB instance can be accessed via the DB method. +// +// REVISIT: Rename Pool to Grip or some such? type Pool interface { // DB returns the sql.DB object for this Pool. // This operation can take a long time if opening the DB requires diff --git a/libsq/driver/sources.go b/libsq/driver/sources.go index 340c7b0d0..5f5fe8ba1 100644 --- a/libsq/driver/sources.go +++ b/libsq/driver/sources.go @@ -4,7 +4,6 @@ import ( "context" "errors" "log/slog" - "os" "path/filepath" "strings" "sync" @@ -115,11 +114,15 @@ func (ss *Sources) doOpen(ctx context.Context, src *source.Source) (Pool, error) func (ss *Sources) OpenScratch(ctx context.Context, src *source.Source) (Pool, error) { const msgCloseScratch = "Close scratch db" - _, srcCacheDBFilepath, _, err := ss.getCachePaths(src) + cacheDir, srcCacheDBFilepath, _, err := ss.getCachePaths(src) if err != nil { return nil, err } + if err = ioz.RequireDir(cacheDir); err != nil { + return nil, err + } + scratchSrc, cleanFn, err := ss.scratchSrcFn(ctx, srcCacheDBFilepath) if err != nil { // if err is non-nil, cleanup is guaranteed to be nil @@ -212,6 +215,10 @@ func (ss *Sources) openIngestCache(ctx context.Context, src *source.Source, return nil, err } + if err = ioz.RequireDir(cacheDir); err != nil { + return nil, err + } + log.Debug("Using cache dir", lga.Path, cacheDir) ingestFilePath, err := ss.files.Filepath(ctx, src) @@ -223,7 +230,7 @@ func (ss *Sources) openIngestCache(ctx context.Context, src *source.Source, impl Pool foundCached bool ) - if impl, foundCached, err = ss.OpenCachedFor(ctx, src); err != nil { + if impl, foundCached, err = ss.openCachedFor(ctx, src); err != nil { return nil, 
err } if foundCached { @@ -322,15 +329,14 @@ func (ss *Sources) getLockfileFor(src *source.Source) (lockfile.Lockfile, error) return "", err } - if err = os.MkdirAll(srcCacheDir, 0o750); err != nil { - return "", errz.Err(err) + if err = ioz.RequireDir(srcCacheDir); err != nil { + return "", err } lockPath := filepath.Join(srcCacheDir, "pid.lock") return lockfile.New(lockPath) } -// OpenCachedFor implements ScratchPoolOpener. -func (ss *Sources) OpenCachedFor(ctx context.Context, src *source.Source) (Pool, bool, error) { +func (ss *Sources) openCachedFor(ctx context.Context, src *source.Source) (Pool, bool, error) { _, cacheDBPath, checksumsPath, err := ss.getCachePaths(src) if err != nil { return nil, false, err diff --git a/libsq/source/cache.go b/libsq/source/cache.go index f5455e100..65a77d4c3 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -3,37 +3,32 @@ package source import ( "os" "path/filepath" + "strings" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/stringz" ) -// CacheDirFor gets the cache dir for handle, creating it if necessary. -// If handle is empty or invalid, a random value is generated. +// CacheDirFor gets the cache dir for handle. It is not guaranteed +// that the returned dir exists or is accessible. func CacheDirFor(src *Source) (dir string, err error) { handle := src.Handle - switch handle { - case "": - // FIXME: This is surely an error? - return "", errz.Errorf("open cache dir: empty handle") - // handle = "@cache_" + stringz.UniqN(32) - case StdinHandle: + if err = ValidHandle(handle); err != nil { + return "", errz.Wrapf(err, "cache dir: invalid handle: %s", handle) + } + + if handle == StdinHandle { // stdin is different input every time, so we need a unique // cache dir. In practice, stdin probably isn't using this function. 
handle += "_" + stringz.UniqN(32) - default: - if err = ValidHandle(handle); err != nil { - return "", errz.Wrapf(err, "open cache dir: invalid handle: %s", handle) - } } - dir = CacheDirPath() - sanitized := Handle2SafePath(handle) - hash := src.Hash() - dir = filepath.Join(dir, "sources", sanitized, hash) - if err = os.MkdirAll(dir, 0o750); err != nil { - return "", errz.Wrapf(err, "open cache dir: %s", dir) - } + dir = filepath.Join( + CacheDirPath(), + "sources", + filepath.Join(strings.Split(strings.TrimPrefix(handle, "@"), "/")...), + src.Hash(), + ) return dir, nil } @@ -47,10 +42,17 @@ func CacheDirPath() (dir string) { if dir, err = os.UserCacheDir(); err != nil { // Some systems may not have a user cache dir, so we fall back // to the system temp dir. - dir = filepath.Join(os.TempDir(), "sq", "cache") + dir = filepath.Join(TempDirPath(), "cache") return dir } dir = filepath.Join(dir, "sq") return dir } + +// TempDirPath returns the sq temp dir. This is generally +// in TEMP_DIR/sq. It is not guaranteed that the returned dir exists +// or is accessible. 
+func TempDirPath() (dir string) {
+	return filepath.Join(os.TempDir(), "sq")
+}

From 72978666fdc8b1d3ceeb5ab16c3fe535ec91830e Mon Sep 17 00:00:00 2001
From: neilotoole
Date: Sun, 3 Dec 2023 14:49:32 -0700
Subject: [PATCH 046/195] Fine-tuning cache cmds

---
 cli/cmd_cache.go      |  2 +-
 libsq/core/ioz/ioz.go | 40 ++++++++--------------------------------
 2 files changed, 9 insertions(+), 33 deletions(-)

diff --git a/cli/cmd_cache.go b/cli/cmd_cache.go
index ba0015753..b3553531f 100644
--- a/cli/cmd_cache.go
+++ b/cli/cmd_cache.go
@@ -172,7 +172,7 @@ func execCacheTree(cmd *cobra.Command, _ []string) error {
 	}
 
 	showSize := cmdFlagBool(cmd, flag.CacheTreeSize)
-	return ioz.PrintTree(ru.Out, cacheDir, showSize)
+	return ioz.PrintTree(ru.Out, cacheDir, showSize, !ru.Writers.Printing.IsMonochrome())
 }
 
diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go
index 5e8d23151..5f3754a9f 100644
--- a/libsq/core/ioz/ioz.go
+++ b/libsq/core/ioz/ioz.go
@@ -411,40 +411,16 @@ func DirExists(dir string) bool {
 	return fi.IsDir()
 }
 
-func PrintTree(w io.Writer, loc string, showSize bool) error {
+// PrintTree prints the file tree structure at loc to w.
+// This function uses the github.com/a8m/tree library, which is
+// a Go implementation of the venerable "tree" command.
+func PrintTree(w io.Writer, loc string, showSize, colorize bool) error { opts := &tree.Options{ - Fs: new(ostree.FS), - OutFile: w, - All: false, - // DirsOnly: false, - // FullPath: false, - // IgnoreCase: false, - // FollowLink: false, - // DeepLevel: 0, - // Pattern: "", - // IPattern: "", - // MatchDirs: false, - // Prune: false, - // ByteSize: false, + Fs: new(ostree.FS), + OutFile: w, + All: true, UnitSize: showSize, - // FileMode: false, - // ShowUid: false, - // ShowGid: false, - // LastMod: false, - // Quotes: false, - // Inodes: false, - // Device: false, - // NoSort: false, - // VerSort: false, - // ModSort: false, - // DirSort: false, - // NameSort: false, - // SizeSort: false, - // CTimeSort: false, - // ReverSort: false, - // NoIndent: false, - Colorize: true, - // Color: nil, + Colorize: colorize, } inf := tree.New(loc) From 9288ada0492a5ea442b3d9c687b203b965007bd9 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 20:32:43 -0700 Subject: [PATCH 047/195] Right, I think progress stuff is actually working --- README.md | 9 +- cli/cmd_slq.go | 4 +- cli/cmd_sql.go | 4 +- cli/cmd_xtest.go | 2 +- cli/diff/record.go | 2 +- cli/output.go | 4 +- drivers/xlsx/ingest.go | 4 +- grammar/SLQ.g4 | 4 +- libsq/core/lg/devlog/tint/handler_test.go | 572 ---------------------- libsq/core/progress/progress.go | 263 +++++++--- libsq/core/progress/progress_test.go | 6 +- libsq/dbwriter.go | 4 +- libsq/pipeline.go | 2 +- testh/testh.go | 2 +- 14 files changed, 218 insertions(+), 664 deletions(-) delete mode 100644 libsq/core/lg/devlog/tint/handler_test.go diff --git a/README.md b/README.md index bd87be215..86e8ee3df 100644 --- a/README.md +++ b/README.md @@ -309,7 +309,7 @@ See [CHANGELOG.md](./CHANGELOG.md). 
- Thanks to [Diego Souza](https://github.com/diegosouza) for creating the [Arch Linux package](https://aur.archlinux.org/packages/sq-bin),
  and [`@icp`](https://github.com/icp1994) for creating the [Void Linux package](https://github.com/void-linux/void-packages/blob/master/srcpkgs/sq/template).
-- Much inspiration is owed to [jq](https://stedolan.github.io/jq/).
+- Much inspiration is owed to [jq](https://jqlang.github.io/jq/).
 - See [`go.mod`](https://github.com/neilotoole/sq/blob/master/go.mod) for a list of third-party packages.
 - Additionally, `sq` incorporates modified versions of:
@@ -319,8 +319,11 @@ See [CHANGELOG.md](./CHANGELOG.md).
   from [jOOQ](https://github.com/jooq/jooq), which in turn owe their heritage to earlier work on Sakila.
 - Date rendering via [`ncruces/go-strftime`](https://github.com/ncruces/go-strftime).
-- The [`dolmen-go/contextio`](https://github.com/dolmen-go/contextio) package is
-  incorporated into the codebase (with modifications).
+- A modified version of [`dolmen-go/contextio`](https://github.com/dolmen-go/contextio) is
+  incorporated into the codebase.
+- The [`log.devmode`](https://sq.io/docs/config#logdevmode) log format is
+  derived from [`lmittmann/tint`](https://github.com/lmittmann/tint).
+- [`djherbis/fscache`](https://github.com/djherbis/fscache) is used for caching.
## Similar, related, or noteworthy projects diff --git a/cli/cmd_slq.go b/cli/cmd_slq.go index abe14a1d0..9c47435e5 100644 --- a/cli/cmd_slq.go +++ b/cli/cmd_slq.go @@ -154,7 +154,7 @@ func execSLQInsert(ctx context.Context, ru *run.Run, mArgs map[string]string, ) execErr := libsq.ExecuteSLQ(ctx, qc, slq, inserter) - affected, waitErr := inserter.Wait() // Wait for the writer to finish processing + affected, waitErr := inserter.Wait() // Stop for the writer to finish processing if execErr != nil { return errz.Wrapf(execErr, "insert %s.%s failed", destSrc.Handle, destTbl) } @@ -416,7 +416,7 @@ func extractFlagArgsValues(cmd *cobra.Command) (map[string]string, error) { // preprocessFlagArgVars is a hack to support the predefined // variables "--arg" mechanism. We implement the mechanism in alignment // with how jq does it: "--arg name value". -// See: https://stedolan.github.io/jq/manual/v1.6/ +// See: https://jqlang.github.io/jq/manual/v1.6/ // // For example: // diff --git a/cli/cmd_sql.go b/cli/cmd_sql.go index c4933e2d5..e2bcf9a0e 100644 --- a/cli/cmd_sql.go +++ b/cli/cmd_sql.go @@ -130,7 +130,7 @@ func execSQLPrint(ctx context.Context, ru *run.Run, fromSrc *source.Source) erro if err != nil { return err } - _, err = recw.Wait() // Wait for the writer to finish processing + _, err = recw.Wait() // Stop for the writer to finish processing return err } @@ -170,7 +170,7 @@ func execSQLInsert(ctx context.Context, ru *run.Run, return errz.Wrapf(err, "insert to {%s} failed", source.Target(destSrc, destTbl)) } - affected, err := inserter.Wait() // Wait for the writer to finish processing + affected, err := inserter.Wait() // Stop for the writer to finish processing if err != nil { return errz.Wrapf(err, "insert %s.%s failed", destSrc.Handle, destTbl) } diff --git a/cli/cmd_xtest.go b/cli/cmd_xtest.go index 75048f96e..38254f076 100644 --- a/cli/cmd_xtest.go +++ b/cli/cmd_xtest.go @@ -76,6 +76,6 @@ LOOP: time.Sleep(time.Duration(rand.Intn(10)+1) * maxSleep / 10) 
//nolint:gosec } - pb.Wait() + pb.Stop() return nil } diff --git a/cli/diff/record.go b/cli/diff/record.go index 8fa3970c4..6501fc667 100644 --- a/cli/diff/record.go +++ b/cli/diff/record.go @@ -243,6 +243,6 @@ func (d *recWriter) Open(_ context.Context, _ context.CancelFunc, recMeta record // Wait implements libsq.RecordWriter. func (d *recWriter) Wait() (written int64, err error) { - // We don't actually use Wait(), so just return zero values. + // We don't actually use Stop(), so just return zero values. return 0, nil } diff --git a/cli/output.go b/cli/output.go index e5511cde6..4405ebcc2 100644 --- a/cli/output.go +++ b/cli/output.go @@ -424,7 +424,7 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer renderDelay := OptProgressDelay.Get(opts) prog := progress.New(ctx, errOut, renderDelay, progColors) // On first write to stdout, we remove the progress widget. - out2 = ioz.NotifyOnceWriter(out2, prog.Wait) + out2 = ioz.NotifyOnceWriter(out2, prog.Stop) cmd.SetContext(progress.NewContext(ctx, prog)) } @@ -465,7 +465,7 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer // On first write to stdout, we remove the progress widget. 
out2 = ioz.NotifyOnceWriter(out2, func() { lg.FromContext(ctx).Debug("Output stream is being written to; removing progress widget") - prog.Wait() + prog.Stop() }) cmd.SetContext(progress.NewContext(ctx, prog)) } diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index 7e1676bd3..8583ddbac 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -233,7 +233,7 @@ func ingestSheetToTable(ctx context.Context, destPool driver.Pool, sheetTbl *she close(bi.RecordCh) return err } - + select { case <-ctx.Done(): close(bi.RecordCh) @@ -252,7 +252,7 @@ func ingestSheetToTable(ctx context.Context, destPool driver.Pool, sheetTbl *she close(bi.RecordCh) // Indicate that we're finished writing records - err = <-bi.ErrCh // Wait for bi to complete + err = <-bi.ErrCh // Stop for bi to complete if err != nil { return err } diff --git a/grammar/SLQ.g4 b/grammar/SLQ.g4 index bf42e4f8a..830e4dce5 100644 --- a/grammar/SLQ.g4 +++ b/grammar/SLQ.g4 @@ -150,7 +150,7 @@ The 'group_by' construct implments the SQL "GROUP BY" clause. Syonyms: - 'group_by' for jq interoperability. - https://stedolan.github.io/jq/manual/v1.6/#group_by(path_expression) + https://jqlang.github.io/jq/manual/v1.6/#group_by(path_expression) - 'gb' for brevity. */ @@ -188,7 +188,7 @@ The optional plus/minus tokens specify ASC or DESC order. Synonyms: - 'sort_by' for jq interoperability. - https://stedolan.github.io/jq/manual/v1.6/#sort,sort_by(path_expression) + https://jqlang.github.io/jq/manual/v1.6/#sort,sort_by(path_expression) - 'ob' for brevity. 
We do not implement a 'sort' synonym for the jq 'sort' function, because SQL diff --git a/libsq/core/lg/devlog/tint/handler_test.go b/libsq/core/lg/devlog/tint/handler_test.go deleted file mode 100644 index 5000e4a1c..000000000 --- a/libsq/core/lg/devlog/tint/handler_test.go +++ /dev/null @@ -1,572 +0,0 @@ -package tint_test - -import ( - "bytes" - "context" - "errors" - "io" - "log/slog" - "os" - "slices" - "strconv" - "strings" - "testing" - "time" - - "github.com/lmittmann/tint" -) - -var faketime = time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) - -func Example() { - slog.SetDefault(slog.New(tint.NewHandler(os.Stderr, &tint.Options{ - Level: slog.LevelDebug, - TimeFormat: time.Kitchen, - }))) - - slog.Info("Starting server", "addr", ":8080", "env", "production") - slog.Debug("Connected to DB", "db", "myapp", "host", "localhost:5432") - slog.Warn("Slow request", "method", "GET", "path", "/users", "duration", 497*time.Millisecond) - slog.Error("DB connection lost", tint.Err(errors.New("connection reset")), "db", "myapp") - // Output: -} - -// Run test with "faketime" tag: -// -// TZ="" go test -tags=faketime -func TestHandler(t *testing.T) { - if !faketime.Equal(time.Now()) { - t.Skip(`skipping test; run with "-tags=faketime"`) - } - - tests := []struct { - Opts *tint.Options - F func(l *slog.Logger) - Want string - }{ - { - F: func(l *slog.Logger) { - l.Info("test", "key", "val") - }, - Want: `Nov 10 23:00:00.000 INF test key=val`, - }, - { - F: func(l *slog.Logger) { - l.Error("test", tint.Err(errors.New("fail"))) - }, - Want: `Nov 10 23:00:00.000 ERR test err=fail`, - }, - { - F: func(l *slog.Logger) { - l.Info("test", slog.Group("group", slog.String("key", "val"), tint.Err(errors.New("fail")))) - }, - Want: `Nov 10 23:00:00.000 INF test group.key=val group.err=fail`, - }, - { - F: func(l *slog.Logger) { - l.WithGroup("group").Info("test", "key", "val") - }, - Want: `Nov 10 23:00:00.000 INF test group.key=val`, - }, - { - F: func(l *slog.Logger) { 
- l.With("key", "val").Info("test", "key2", "val2") - }, - Want: `Nov 10 23:00:00.000 INF test key=val key2=val2`, - }, - { - F: func(l *slog.Logger) { - l.Info("test", "k e y", "v a l") - }, - Want: `Nov 10 23:00:00.000 INF test "k e y"="v a l"`, - }, - { - F: func(l *slog.Logger) { - l.WithGroup("g r o u p").Info("test", "key", "val") - }, - Want: `Nov 10 23:00:00.000 INF test "g r o u p.key"=val`, - }, - { - F: func(l *slog.Logger) { - l.Info("test", "slice", []string{"a", "b", "c"}, "map", map[string]int{"a": 1, "b": 2, "c": 3}) - }, - Want: `Nov 10 23:00:00.000 INF test slice="[a b c]" map="map[a:1 b:2 c:3]"`, - }, - { - Opts: &tint.Options{ - AddSource: true, - }, - F: func(l *slog.Logger) { - l.Info("test", "key", "val") - }, - Want: `Nov 10 23:00:00.000 INF tint/handler_test.go:100 test key=val`, - }, - { - Opts: &tint.Options{ - TimeFormat: time.Kitchen, - }, - F: func(l *slog.Logger) { - l.Info("test", "key", "val") - }, - Want: `11:00PM INF test key=val`, - }, - { - Opts: &tint.Options{ - ReplaceAttr: drop(slog.TimeKey), - }, - F: func(l *slog.Logger) { - l.Info("test", "key", "val") - }, - Want: `INF test key=val`, - }, - { - Opts: &tint.Options{ - ReplaceAttr: drop(slog.LevelKey), - }, - F: func(l *slog.Logger) { - l.Info("test", "key", "val") - }, - Want: `Nov 10 23:00:00.000 test key=val`, - }, - { - Opts: &tint.Options{ - ReplaceAttr: drop(slog.MessageKey), - }, - F: func(l *slog.Logger) { - l.Info("test", "key", "val") - }, - Want: `Nov 10 23:00:00.000 INF key=val`, - }, - { - Opts: &tint.Options{ - ReplaceAttr: drop(slog.TimeKey, slog.LevelKey, slog.MessageKey), - }, - F: func(l *slog.Logger) { - l.Info("test", "key", "val") - }, - Want: `key=val`, - }, - { - Opts: &tint.Options{ - ReplaceAttr: drop("key"), - }, - F: func(l *slog.Logger) { - l.Info("test", "key", "val") - }, - Want: `Nov 10 23:00:00.000 INF test`, - }, - { - Opts: &tint.Options{ - ReplaceAttr: drop("key"), - }, - F: func(l *slog.Logger) { - l.WithGroup("group").Info("test", "key", 
"val", "key2", "val2") - }, - Want: `Nov 10 23:00:00.000 INF test group.key=val group.key2=val2`, - }, - { - Opts: &tint.Options{ - ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { - if a.Key == "key" && len(groups) == 1 && groups[0] == "group" { - return slog.Attr{} - } - return a - }, - }, - F: func(l *slog.Logger) { - l.WithGroup("group").Info("test", "key", "val", "key2", "val2") - }, - Want: `Nov 10 23:00:00.000 INF test group.key2=val2`, - }, - { - Opts: &tint.Options{ - ReplaceAttr: replace(slog.IntValue(42), slog.TimeKey), - }, - F: func(l *slog.Logger) { - l.Info("test", "key", "val") - }, - Want: `42 INF test key=val`, - }, - { - Opts: &tint.Options{ - ReplaceAttr: replace(slog.StringValue("INFO"), slog.LevelKey), - }, - F: func(l *slog.Logger) { - l.Info("test", "key", "val") - }, - Want: `Nov 10 23:00:00.000 INFO test key=val`, - }, - { - Opts: &tint.Options{ - ReplaceAttr: replace(slog.IntValue(42), slog.MessageKey), - }, - F: func(l *slog.Logger) { - l.Info("test", "key", "val") - }, - Want: `Nov 10 23:00:00.000 INF 42 key=val`, - }, - { - Opts: &tint.Options{ - ReplaceAttr: replace(slog.IntValue(42), "key"), - }, - F: func(l *slog.Logger) { - l.With("key", "val").Info("test", "key2", "val2") - }, - Want: `Nov 10 23:00:00.000 INF test key=42 key2=val2`, - }, - { - Opts: &tint.Options{ - ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { - return slog.Attr{} - }, - }, - F: func(l *slog.Logger) { - l.Info("test", "key", "val") - }, - Want: ``, - }, - { - F: func(l *slog.Logger) { - l.Info("test", "key", "") - }, - Want: `Nov 10 23:00:00.000 INF test key=""`, - }, - { - F: func(l *slog.Logger) { - l.Info("test", "", "val") - }, - Want: `Nov 10 23:00:00.000 INF test ""=val`, - }, - { - F: func(l *slog.Logger) { - l.Info("test", "", "") - }, - Want: `Nov 10 23:00:00.000 INF test ""=""`, - }, - - { // https://github.com/lmittmann/tint/issues/8 - F: func(l *slog.Logger) { - l.Log(context.TODO(), slog.LevelInfo+1, "test") - }, - Want: 
`Nov 10 23:00:00.000 INF+1 test`, - }, - { - Opts: &tint.Options{ - Level: slog.LevelDebug - 1, - }, - F: func(l *slog.Logger) { - l.Log(context.TODO(), slog.LevelDebug-1, "test") - }, - Want: `Nov 10 23:00:00.000 DBG-1 test`, - }, - { // https://github.com/lmittmann/tint/issues/12 - F: func(l *slog.Logger) { - l.Error("test", slog.Any("error", errors.New("fail"))) - }, - Want: `Nov 10 23:00:00.000 ERR test error=fail`, - }, - { // https://github.com/lmittmann/tint/issues/15 - F: func(l *slog.Logger) { - l.Error("test", tint.Err(nil)) - }, - Want: `Nov 10 23:00:00.000 ERR test err=`, - }, - { // https://github.com/lmittmann/tint/pull/26 - Opts: &tint.Options{ - ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { - if a.Key == slog.TimeKey && len(groups) == 0 { - return slog.Time(slog.TimeKey, a.Value.Time().Add(24*time.Hour)) - } - return a - }, - }, - F: func(l *slog.Logger) { - l.Error("test") - }, - Want: `Nov 11 23:00:00.000 ERR test`, - }, - { // https://github.com/lmittmann/tint/pull/27 - F: func(l *slog.Logger) { - l.Info("test", "a", "b", slog.Group("", slog.String("c", "d")), "e", "f") - }, - Want: `Nov 10 23:00:00.000 INF test a=b c=d e=f`, - }, - { // https://github.com/lmittmann/tint/pull/30 - // drop built-in attributes in a grouped log - Opts: &tint.Options{ - ReplaceAttr: drop(slog.TimeKey, slog.LevelKey, slog.MessageKey, slog.SourceKey), - AddSource: true, - }, - F: func(l *slog.Logger) { - l.WithGroup("group").Info("test", "key", "val") - }, - Want: `group.key=val`, - }, - { // https://github.com/lmittmann/tint/issues/36 - Opts: &tint.Options{ - ReplaceAttr: func(g []string, a slog.Attr) slog.Attr { - if len(g) == 0 && a.Key == slog.LevelKey { - _ = a.Value.Any().(slog.Level) - } - return a - }, - }, - F: func(l *slog.Logger) { - l.Info("test") - }, - Want: `Nov 10 23:00:00.000 INF test`, - }, - { // https://github.com/lmittmann/tint/issues/37 - Opts: &tint.Options{ - AddSource: true, - ReplaceAttr: func(g []string, a 
slog.Attr) slog.Attr { - return a - }, - }, - F: func(l *slog.Logger) { - l.Info("test") - }, - Want: `Nov 10 23:00:00.000 INF tint/handler_test.go:327 test`, - }, - { // https://github.com/lmittmann/tint/issues/44 - F: func(l *slog.Logger) { - l = l.WithGroup("group") - l.Error("test", tint.Err(errTest)) - }, - Want: `Nov 10 23:00:00.000 ERR test group.err=fail`, - }, - } - - for i, test := range tests { - t.Run(strconv.Itoa(i), func(t *testing.T) { - var buf bytes.Buffer - if test.Opts == nil { - test.Opts = &tint.Options{} - } - test.Opts.NoColor = true - l := slog.New(tint.NewHandler(&buf, test.Opts)) - test.F(l) - - got := strings.TrimRight(buf.String(), "\n") - if test.Want != got { - t.Fatalf("(-want +got)\n- %s\n+ %s", test.Want, got) - } - }) - } -} - -// drop returns a ReplaceAttr that drops the given keys. -func drop(keys ...string) func([]string, slog.Attr) slog.Attr { - return func(groups []string, a slog.Attr) slog.Attr { - if len(groups) > 0 { - return a - } - - for _, key := range keys { - if a.Key == key { - a = slog.Attr{} - } - } - return a - } -} - -func replace(new slog.Value, keys ...string) func([]string, slog.Attr) slog.Attr { - return func(groups []string, a slog.Attr) slog.Attr { - if len(groups) > 0 { - return a - } - - for _, key := range keys { - if a.Key == key { - a.Value = new - } - } - return a - } -} - -func TestReplaceAttr(t *testing.T) { - tests := [][]any{ - {}, - {"key", "val"}, - {"key", "val", slog.Group("group", "key2", "val2")}, - {"key", "val", slog.Group("group", "key2", "val2", slog.Group("group2", "key3", "val3"))}, - } - - type replaceAttrParams struct { - Groups []string - Attr slog.Attr - } - - replaceAttrRecorder := func(record *[]replaceAttrParams) func([]string, slog.Attr) slog.Attr { - return func(groups []string, a slog.Attr) slog.Attr { - *record = append(*record, replaceAttrParams{groups, a}) - return a - } - } - - for i, test := range tests { - t.Run(strconv.Itoa(i), func(t *testing.T) { - slogRecord := 
make([]replaceAttrParams, 0) - slogLogger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{ - ReplaceAttr: replaceAttrRecorder(&slogRecord), - })) - slogLogger.Log(context.TODO(), slog.LevelInfo, "", test...) - - tintRecord := make([]replaceAttrParams, 0) - tintLogger := slog.New(tint.NewHandler(io.Discard, &tint.Options{ - ReplaceAttr: replaceAttrRecorder(&tintRecord), - })) - tintLogger.Log(context.TODO(), slog.LevelInfo, "", test...) - - if !slices.EqualFunc(slogRecord, tintRecord, func(a, b replaceAttrParams) bool { - return slices.Equal(a.Groups, b.Groups) && a.Attr.Equal(b.Attr) - }) { - t.Fatalf("(-want +got)\n- %v\n+ %v", slogRecord, tintRecord) - } - }) - } -} - -// See https://github.com/golang/exp/blob/master/slog/benchmarks/benchmarks_test.go#L25 -// -// Run e.g.: -// -// go test -bench=. -count=10 | benchstat -col /h /dev/stdin -func BenchmarkLogAttrs(b *testing.B) { - handler := []struct { - Name string - H slog.Handler - }{ - {"tint", tint.NewHandler(io.Discard, nil)}, - {"text", slog.NewTextHandler(io.Discard, nil)}, - {"json", slog.NewJSONHandler(io.Discard, nil)}, - {"discard", new(discarder)}, - } - - benchmarks := []struct { - Name string - F func(*slog.Logger) - }{ - { - "5 args", - func(logger *slog.Logger) { - logger.LogAttrs(context.TODO(), slog.LevelInfo, testMessage, - slog.String("string", testString), - slog.Int("status", testInt), - slog.Duration("duration", testDuration), - slog.Time("time", testTime), - slog.Any("error", errTest), - ) - }, - }, - { - "5 args custom level", - func(logger *slog.Logger) { - logger.LogAttrs(context.TODO(), slog.LevelInfo+1, testMessage, - slog.String("string", testString), - slog.Int("status", testInt), - slog.Duration("duration", testDuration), - slog.Time("time", testTime), - slog.Any("error", errTest), - ) - }, - }, - { - "10 args", - func(logger *slog.Logger) { - logger.LogAttrs(context.TODO(), slog.LevelInfo, testMessage, - slog.String("string", testString), - slog.Int("status", 
testInt), - slog.Duration("duration", testDuration), - slog.Time("time", testTime), - slog.Any("error", errTest), - slog.String("string", testString), - slog.Int("status", testInt), - slog.Duration("duration", testDuration), - slog.Time("time", testTime), - slog.Any("error", errTest), - ) - }, - }, - { - "40 args", - func(logger *slog.Logger) { - logger.LogAttrs(context.TODO(), slog.LevelInfo, testMessage, - slog.String("string", testString), - slog.Int("status", testInt), - slog.Duration("duration", testDuration), - slog.Time("time", testTime), - slog.Any("error", errTest), - slog.String("string", testString), - slog.Int("status", testInt), - slog.Duration("duration", testDuration), - slog.Time("time", testTime), - slog.Any("error", errTest), - slog.String("string", testString), - slog.Int("status", testInt), - slog.Duration("duration", testDuration), - slog.Time("time", testTime), - slog.Any("error", errTest), - slog.String("string", testString), - slog.Int("status", testInt), - slog.Duration("duration", testDuration), - slog.Time("time", testTime), - slog.Any("error", errTest), - slog.String("string", testString), - slog.Int("status", testInt), - slog.Duration("duration", testDuration), - slog.Time("time", testTime), - slog.Any("error", errTest), - slog.String("string", testString), - slog.Int("status", testInt), - slog.Duration("duration", testDuration), - slog.Time("time", testTime), - slog.Any("error", errTest), - slog.String("string", testString), - slog.Int("status", testInt), - slog.Duration("duration", testDuration), - slog.Time("time", testTime), - slog.Any("error", errTest), - slog.String("string", testString), - slog.Int("status", testInt), - slog.Duration("duration", testDuration), - slog.Time("time", testTime), - slog.Any("error", errTest), - ) - }, - }, - } - - for _, h := range handler { - b.Run("h="+h.Name, func(b *testing.B) { - for _, bench := range benchmarks { - b.Run(bench.Name, func(b *testing.B) { - b.ReportAllocs() - logger := 
slog.New(h.H) - for i := 0; i < b.N; i++ { - bench.F(logger) - } - }) - } - }) - } -} - -// discarder is a slog.Handler that discards all records. -type discarder struct{} - -func (*discarder) Enabled(context.Context, slog.Level) bool { return true } -func (*discarder) Handle(context.Context, slog.Record) error { return nil } -func (d *discarder) WithAttrs(attrs []slog.Attr) slog.Handler { return d } -func (d *discarder) WithGroup(name string) slog.Handler { return d } - -var ( - testMessage = "Test logging, but use a somewhat realistic message length." - testTime = time.Date(2022, time.May, 1, 0, 0, 0, 0, time.UTC) - testString = "7e3b3b2aaeff56a7108fe11e154200dd/7819479873059528190" - testInt = 32768 - testDuration = 23 * time.Second - errTest = errors.New("fail") -) diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index fb802c11e..127315fde 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -1,9 +1,24 @@ +// Package progress contains progress bar widget functionality. +// Use progress.New to create a new progress widget container. +// That widget should be added to a context using progress.NewContext, +// and retrieved via progress.FromContext. Invoke one of the Progress.NewX +// methods to create a new progress.Bar. Invoke Bar.IncrBy to increment +// the bar's progress, and invoke Bar.Stop to stop the bar. Be sure +// to invoke Progress.Stop when the progress widget is no longer needed. +// +// You can use the progress.NewReader and progress.NewWriter functions +// to wrap an io.Reader or io.Writer, respectively, with a progress bar. +// Both functions expect the supplied ctx arg to contain a *progress.Progress. +// Note also that both wrappers are context-aware; that is, they will stop +// the reading/writing process when the context is canceled. Be sure to +// call Close on the wrappers when done. 
package progress import ( "context" "io" "sync" + "sync/atomic" "time" humanize "github.com/dustin/go-humanize" @@ -15,15 +30,15 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" ) -type runKey struct{} +type ctxKey struct{} -// NewContext returns ctx with prog added as a value. -func NewContext(ctx context.Context, prog *Progress) context.Context { +// NewContext returns ctx with p added as a value. +func NewContext(ctx context.Context, p *Progress) context.Context { if ctx == nil { ctx = context.Background() } - return context.WithValue(ctx, runKey{}, prog) + return context.WithValue(ctx, ctxKey{}, p) } // FromContext returns the [Progress] added to ctx via NewContext, @@ -34,7 +49,7 @@ func FromContext(ctx context.Context) *Progress { return nil } - val := ctx.Value(runKey{}) + val := ctx.Value(ctxKey{}) if val == nil { return nil } @@ -52,10 +67,23 @@ const ( refreshRate = 150 * time.Millisecond ) +// NOTE: The implementation below is wildly more complicated than it should be. +// This is due to a bug in the mpb package, wherein it doesn't fully +// respect the render delay. +// +// https://github.com/vbauerster/mpb/issues/136 +// +// Until that bug is fixed, we have a messy workaround. The gist of it +// is that both the Progress.pc and Bar.bar are lazily initialized. +// The Progress.pc (progress container) is initialized on the first +// call to one of the Progress.NewX methods. The Bar.bar is initialized +// only after the render delay has expired. The details are ugly. Hopefully +// this can all be simplified once the mpb bug is fixed. + // New returns a new Progress instance, which is a container for progress bars. // The returned Progress instance is safe for concurrent use, and all of its // public methods can be safely invoked on a nil Progress. The caller is -// responsible for calling [Progress.Wait] on the returned Progress. +// responsible for calling [Progress.Stop] on the returned Progress. 
// Arg delay specifies a duration to wait before rendering the progress bar. // The Progress is lazily initialized, and thus the delay clock doesn't // start ticking until the first call to one of the Progress.NewX methods. @@ -88,7 +116,7 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors mpb.WithAutoRefresh(), // Needed for color in Windows, apparently } if delay > 0 { - delayCh := renderDelay(ctx, delay) + delayCh := renderDelay(ctx, p, delay) opts = append(opts, mpb.WithRenderDelay(delayCh)) p.delayCh = delayCh } else { @@ -104,13 +132,14 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors } // Progress represents a container that renders one or more progress bars. -// The caller is responsible for calling [Progress.Wait] to indicate +// The caller is responsible for calling [Progress.Stop] to indicate // completion. type Progress struct { // mu guards ALL public methods. mu *sync.Mutex - ctx context.Context + ctx context.Context + cancelFn context.CancelFunc // pc is the underlying progress container. It is lazily initialized // by pcInit. Any method that accesses pc must be certain that @@ -124,17 +153,20 @@ type Progress struct { // start as soon as delayCh is closed. delayCh <-chan struct{} + // stopped is set to true when Stop is called. + // REVISIT: Do we really need stopped, or can we rely on ctx.Done()? + stopped bool + colors *Colors - // cleanup *cleanup.Cleanup - bars []*Bar - cancelFn context.CancelFunc + // bars contains all bars that have been created on this Progress. + bars []*Bar } -// Wait waits for all bars to complete and finally shuts down the +// Stop waits for all bars to complete and finally shuts down the // container. After this method has been called, there is no way // to reuse the Progress instance. 
-func (p *Progress) Wait() { +func (p *Progress) Stop() { if p == nil { return } @@ -142,6 +174,13 @@ func (p *Progress) Wait() { p.mu.Lock() defer p.mu.Unlock() + if p.stopped { + return + } + + p.stopped = true + p.cancelFn() + if p.pc == nil { return } @@ -150,19 +189,43 @@ func (p *Progress) Wait() { return } - p.cancelFn() - - for _, bar := range p.bars { - bar.bar.Abort(true) + for _, b := range p.bars { + if b.bar != nil { + b.bar.Abort(true) + } } - for _, bar := range p.bars { - bar.bar.Wait() + for _, b := range p.bars { + if b.bar != nil { + b.bar.Wait() + } } p.pc.Wait() } +// initBars lazily initializes all bars in p.bars. +func (p *Progress) initBars() { + p.mu.Lock() + defer p.mu.Unlock() + + select { + case <-p.ctx.Done(): + return + default: + } + + if p.stopped { + return + } + + for _, b := range p.bars { + if !b.stopped { + b.initBarOnce.Do(b.initBar) + } + } +} + // NewUnitCounter returns a new indeterminate bar whose label // metric is the plural of the provided unit. The caller is ultimately // responsible for calling [Bar.Stop] on the returned Bar. However, @@ -261,53 +324,35 @@ func (p *Progress) newBar(msg string, total int64, total = 0 } - bar := p.pc.New(total, - style, - mpb.BarWidth(barWidth), - mpb.PrependDecorators( - colorize(decor.Name(msg, decor.WCSyncWidthR), p.colors.Message), - ), - mpb.AppendDecorators(decorators...), - mpb.BarRemoveOnComplete(), - ) - - b := &Bar{p: p, bar: bar} - p.bars = append(p.bars, b) - return b -} - -// barStopped is called by a Bar when it is stopped. -// This was supposed to do something, but it's a no-op for now. -func (p *Progress) barStopped(_ *Bar) { - if p == nil { - return + b := &Bar{ + p: p, + incrStash: &atomic.Int64{}, + initBarOnce: &sync.Once{}, } -} - -func spinnerStyle(c *color.Color) mpb.SpinnerStyleComposer { - // REVISIT: maybe use ascii chars only, in case it's a weird terminal? 
- frames := []string{"∙∙∙", "●∙∙", "●∙∙", "∙●∙", "∙●∙", "∙∙●", "∙∙●", "∙∙∙"} - style := mpb.SpinnerStyle(frames...) - if c != nil { - style = style.Meta(func(s string) string { - return c.Sprint(s) - }) + b.initBar = func() { + if b.stopped || p.stopped { + return + } + b.bar = p.pc.New(total, + style, + mpb.BarWidth(barWidth), + mpb.PrependDecorators( + colorize(decor.Name(msg, decor.WCSyncWidthR), p.colors.Message), + ), + mpb.AppendDecorators(decorators...), + mpb.BarRemoveOnComplete(), + ) + b.bar.IncrBy(int(b.incrStash.Load())) + b.incrStash.Store(0) } - return style -} -func barStyle(c *color.Color) mpb.BarStyleComposer { - clr := func(s string) string { - return c.Sprint(s) + p.bars = append(p.bars, b) + select { + case <-p.delayCh: + b.initBarOnce.Do(b.initBar) + default: } - - frames := []string{"∙", "●", "●", "●", "∙"} - - return mpb.BarStyle(). - Lbound(" ").Rbound(" "). - Filler("∙").FillerMeta(clr). - Padding(" "). - Tip(frames...).TipMeta(clr) + return b } // Bar represents a single progress bar. The caller should invoke @@ -315,8 +360,28 @@ func barStyle(c *color.Color) mpb.BarStyleComposer { // the bar is complete, the caller should invoke [Bar.Stop]. All // methods are safe to call on a nil Bar. type Bar struct { - p *Progress + + // bar is nil until barInitOnce.Do(initBar) is called bar *mpb.Bar + // p is never nil + p *Progress + + // There's a bug in the mpb package, wherein it doesn't fully + // respect the render delay. + // + // https://github.com/vbauerster/mpb/issues/136 + // + // Until that bug is fixed, the Bar is lazily initialized + // after the render delay expires. + + initBarOnce *sync.Once + initBar func() + + // incrStash holds the increment count until the + // bar is fully initialized. + incrStash *atomic.Int64 + + stopped bool } // IncrBy increments progress by amount of n. 
It is safe to @@ -325,7 +390,26 @@ func (b *Bar) IncrBy(n int) { if b == nil { return } - b.bar.IncrBy(n) + + b.p.mu.Lock() + defer b.p.mu.Unlock() + + if b.stopped || b.p.stopped { + return + } + + select { + case <-b.p.ctx.Done(): + return + case <-b.p.delayCh: + b.initBarOnce.Do(b.initBar) + if b.bar != nil { + b.bar.IncrBy(n) + } + return + default: + b.incrStash.Add(int64(n)) + } } // Stop stops and removes the bar. It is safe to call Stop on a nil Bar, @@ -334,15 +418,27 @@ func (b *Bar) Stop() { if b == nil { return } - b.bar.SetTotal(-1, true) - b.bar.Abort(true) + + b.p.mu.Lock() + defer b.p.mu.Unlock() + + if b.bar == nil { + b.stopped = true + return + } + + if !b.stopped { + b.bar.SetTotal(-1, true) + b.bar.Abort(true) + } + b.stopped = true + b.bar.Wait() - b.p.barStopped(b) } // renderDelay returns a channel that will be closed after d, -// or if ctx is done. -func renderDelay(ctx context.Context, d time.Duration) <-chan struct{} { +// or if ctx is done. Arg callback is invoked after the delay. +func renderDelay(ctx context.Context, p *Progress, d time.Duration) <-chan struct{} { ch := make(chan struct{}) t := time.NewTimer(d) go func() { @@ -353,6 +449,7 @@ func renderDelay(ctx context.Context, d time.Duration) <-chan struct{} { lg.FromContext(ctx).Debug("Render delay via ctx.Done") case <-t.C: lg.FromContext(ctx).Debug("Render delay via timer") + p.initBars() } }() return ch @@ -401,3 +498,29 @@ func (c *Colors) EnableColor(enable bool) { c.Size.DisableColor() c.Percent.DisableColor() } + +func spinnerStyle(c *color.Color) mpb.SpinnerStyleComposer { + // REVISIT: maybe use ascii chars only, in case it's a weird terminal? + frames := []string{"∙∙∙", "●∙∙", "●∙∙", "∙●∙", "∙●∙", "∙∙●", "∙∙●", "∙∙∙"} + style := mpb.SpinnerStyle(frames...) 
+ if c != nil { + style = style.Meta(func(s string) string { + return c.Sprint(s) + }) + } + return style +} + +func barStyle(c *color.Color) mpb.BarStyleComposer { + clr := func(s string) string { + return c.Sprint(s) + } + + frames := []string{"∙", "●", "●", "●", "∙"} + + return mpb.BarStyle(). + Lbound(" ").Rbound(" "). + Filler("∙").FillerMeta(clr). + Padding(" "). + Tip(frames...).TipMeta(clr) +} diff --git a/libsq/core/progress/progress_test.go b/libsq/core/progress/progress_test.go index 7284f5ad0..37aaf19e1 100644 --- a/libsq/core/progress/progress_test.go +++ b/libsq/core/progress/progress_test.go @@ -33,7 +33,7 @@ func TestNewWriter(t *testing.T) { written, err := io.Copy(w, src) require.NoError(t, err) require.Equal(t, int64(limit), written) - pb.Wait() + pb.Stop() } // TestNewWriter_Closer tests that the returned writer @@ -45,7 +45,7 @@ func TestNewWriter_Closer(t *testing.T) { ctx := context.Background() pb := progress.New(ctx, os.Stdout, time.Millisecond, progress.DefaultColors()) ctx = progress.NewContext(ctx, pb) - defer pb.Wait() + defer pb.Stop() // bytes.Buffer doesn't implement io.Closer buf := &bytes.Buffer{} @@ -72,7 +72,7 @@ func TestNewReader_Closer(t *testing.T) { ctx := context.Background() pb := progress.New(ctx, os.Stdout, time.Millisecond, progress.DefaultColors()) ctx = progress.NewContext(ctx, pb) - defer pb.Wait() + defer pb.Stop() // bytes.Buffer doesn't implement io.Closer buf := &bytes.Buffer{} diff --git a/libsq/dbwriter.go b/libsq/dbwriter.go index c01f9c473..7fe60ed85 100644 --- a/libsq/dbwriter.go +++ b/libsq/dbwriter.go @@ -142,7 +142,7 @@ func (w *DBWriter) Open(ctx context.Context, cancelFn context.CancelFunc, recMet defer func() { // When the inserter goroutine finishes: // - we close errCh (indicates that the DBWriter is done) - // - and mark wg as done, which the Wait method depends upon. + // - and mark wg as done, which the Stop method depends upon. 
close(w.errCh) w.wg.Done() }() @@ -164,7 +164,7 @@ func (w *DBWriter) Open(ctx context.Context, cancelFn context.CancelFunc, recMet // Tell batch inserter that we're done sending records close(w.bi.RecordCh) - err = <-w.bi.ErrCh // Wait for batch inserter to complete + err = <-w.bi.ErrCh // Stop for batch inserter to complete if err != nil { lg.FromContext(ctx).Error(err.Error()) w.addErrs(err) diff --git a/libsq/pipeline.go b/libsq/pipeline.go index df9a1e4d6..e072b07a4 100644 --- a/libsq/pipeline.go +++ b/libsq/pipeline.go @@ -481,7 +481,7 @@ func execCopyTable(ctx context.Context, fromDB driver.Pool, fromTbl tablefq.T, return errz.Wrapf(err, "insert %s.%s failed", destPool.Source().Handle, destTbl) } - affected, err := inserter.Wait() // Wait for the writer to finish processing + affected, err := inserter.Wait() // Stop for the writer to finish processing if err != nil { return errz.Wrapf(err, "insert %s.%s failed", destPool.Source().Handle, destTbl) } diff --git a/testh/testh.go b/testh/testh.go index 3caaf458c..ee75fb758 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -527,7 +527,7 @@ func (h *Helper) Insert(src *source.Source, tbl string, cols []string, records . 
close(bi.RecordCh) // Indicate that we're finished writing records - err = <-bi.ErrCh // Wait for bi to complete + err = <-bi.ErrCh // Stop for bi to complete require.NoError(h.T, err) h.T.Logf("Inserted %d rows to %s.%s", bi.Written(), src.Handle, tbl) From e65cbc825e253af74c2dcbf769d715efe8a3379d Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 21:46:43 -0700 Subject: [PATCH 048/195] Refactoring --- cli/cli_test.go | 4 +- cli/cmd_add_test.go | 4 +- cli/cmd_cache.go | 18 +- cli/cmd_inspect_test.go | 14 +- cli/cmd_mv_test.go | 4 +- cli/cmd_sql_test.go | 6 +- cli/cobraz/cobraz_test.go | 4 +- cli/complete_location_test.go | 52 ++-- cli/complete_test.go | 6 +- cli/config/yamlstore/internal_test.go | 4 +- .../upgrades/v0.34.0/upgrade_test.go | 4 +- cli/config/yamlstore/yamlstore_test.go | 6 +- cli/output/adapter_test.go | 4 +- cli/output/tablew/configwriter.go | 5 +- cli/output/xlsxw/xlsxw_test.go | 6 +- cli/run.go | 2 +- cli/run/run.go | 8 +- drivers/csv/csv_test.go | 2 +- drivers/csv/detect_header_test.go | 4 +- drivers/csv/internal_test.go | 4 +- drivers/json/ingest_test.go | 10 +- drivers/json/json_test.go | 4 +- drivers/mysql/internal_test.go | 4 +- drivers/sqlite3/extension_test.go | 4 +- .../internal/sqlparser/sqlparser_test.go | 4 +- drivers/sqlite3/sqlite3_test.go | 4 +- drivers/xlsx/ingest.go | 2 +- drivers/xlsx/xlsx_test.go | 33 ++- libsq/ast/ast_test.go | 4 +- libsq/ast/parser_test.go | 4 +- libsq/ast/range_test.go | 4 +- libsq/ast/selector_test.go | 4 +- libsq/core/ioz/ioz.go | 1 - libsq/core/kind/internal_test.go | 12 +- libsq/core/kind/kind_test.go | 4 +- libsq/core/lg/devlog/devlog.go | 4 +- libsq/core/lg/devlog/tint/buffer.go | 4 +- libsq/core/loz/loz_test.go | 4 +- libsq/core/options/options_test.go | 6 +- libsq/core/progress/progress.go | 1 - libsq/core/record/record_test.go | 4 +- libsq/core/stringz/stringz_test.go | 20 +- libsq/core/timez/timez_test.go | 4 +- libsq/core/urlz/urlz_test.go | 16 +- libsq/driver/driver.go | 14 - 
libsq/driver/driver_test.go | 12 +- libsq/driver/sources.go | 39 ++- libsq/libsq.go | 13 +- libsq/libsq_test.go | 4 +- libsq/pipeline.go | 22 +- libsq/query_expr_test.go | 4 +- libsq/query_join_test.go | 4 +- libsq/query_no_src_test.go | 10 +- libsq/query_test.go | 12 +- libsq/source/cache.go | 19 +- libsq/source/detect.go | 219 ++++++++++++++ libsq/source/download.go | 46 +++ libsq/source/files.go | 273 ++---------------- libsq/source/files_test.go | 12 +- libsq/source/handle_test.go | 8 +- libsq/source/internal_test.go | 8 +- libsq/source/location_test.go | 4 +- libsq/source/source_test.go | 4 +- testh/sakila/sakila_test.go | 4 +- testh/testh.go | 11 +- testh/testh_test.go | 4 +- testh/{tutil => tu}/tutil.go | 20 +- testh/{tutil => tu}/tutil_test.go | 2 +- 68 files changed, 568 insertions(+), 522 deletions(-) create mode 100644 libsq/source/detect.go create mode 100644 libsq/source/download.go rename testh/{tutil => tu}/tutil.go (94%) rename testh/{tutil => tu}/tutil_test.go (99%) diff --git a/cli/cli_test.go b/cli/cli_test.go index 6f7c4a248..543471a17 100644 --- a/cli/cli_test.go +++ b/cli/cli_test.go @@ -23,7 +23,7 @@ import ( "github.com/neilotoole/sq/testh/fixt" "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/sakila" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestSmoke(t *testing.T) { @@ -181,7 +181,7 @@ func TestExprNoSource(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { tr := testrun.New(context.Background(), t, nil).Hush() err := tr.Exec("--csv", "--no-header", tc.in) require.NoError(t, err) diff --git a/cli/cmd_add_test.go b/cli/cmd_add_test.go index d959b83d0..ce2998dd8 100644 --- a/cli/cmd_add_test.go +++ b/cli/cmd_add_test.go @@ -19,7 +19,7 @@ import ( "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/proj" 
"github.com/neilotoole/sq/testh/sakila" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestCmdAdd(t *testing.T) { @@ -190,7 +190,7 @@ func TestCmdAdd(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.wantHandle, tc.loc, tc.driver), func(t *testing.T) { + t.Run(tu.Name(i, tc.wantHandle, tc.loc, tc.driver), func(t *testing.T) { args := []string{"add", tc.loc} if tc.handle != "" { args = append(args, "--handle="+tc.handle) diff --git a/cli/cmd_cache.go b/cli/cmd_cache.go index b3553531f..76a22b6cb 100644 --- a/cli/cmd_cache.go +++ b/cli/cmd_cache.go @@ -4,6 +4,10 @@ import ( "os" "path/filepath" + "github.com/spf13/cobra" + + "github.com/neilotoole/sq/cli/flag" + "github.com/neilotoole/sq/cli/run" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" @@ -11,10 +15,6 @@ import ( "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" - "github.com/spf13/cobra" - - "github.com/neilotoole/sq/cli/flag" - "github.com/neilotoole/sq/cli/run" ) func newCacheCmd() *cobra.Command { @@ -64,7 +64,7 @@ func newCacheLocationCmd() *cobra.Command { } func execCacheLocation(cmd *cobra.Command, _ []string) error { - dir := source.CacheDirPath() + dir := source.DefaultCacheDir() ru := run.FromContext(cmd.Context()) return ru.Writers.Config.CacheLocation(dir) } @@ -87,7 +87,7 @@ func newCacheInfoCmd() *cobra.Command { } func execCacheInfo(cmd *cobra.Command, _ []string) error { - dir := source.CacheDirPath() + dir := source.DefaultCacheDir() ru := run.FromContext(cmd.Context()) size, err := ioz.DirSize(dir) if err != nil { @@ -115,7 +115,7 @@ func newCacheClearCmd() *cobra.Command { func execCacheClear(cmd *cobra.Command, _ []string) error { log := lg.FromContext(cmd.Context()) - cacheDir := source.CacheDirPath() + cacheDir := 
source.DefaultCacheDir() if !ioz.DirExists(cacheDir) { return nil } @@ -125,7 +125,7 @@ func execCacheClear(cmd *cobra.Command, _ []string) error { // help with the situation where another sq instance has an open pid // lock in the cache dir. - tmpDir := source.TempDirPath() + tmpDir := source.DefaultTempDir() if err := ioz.RequireDir(tmpDir); err != nil { return errz.Wrap(err, "cache clear") } @@ -166,7 +166,7 @@ func newCacheTreeCmd() *cobra.Command { func execCacheTree(cmd *cobra.Command, _ []string) error { ru := run.FromContext(cmd.Context()) - cacheDir := source.CacheDirPath() + cacheDir := source.DefaultCacheDir() if !ioz.DirExists(cacheDir) { return nil } diff --git a/cli/cmd_inspect_test.go b/cli/cmd_inspect_test.go index 8866d8521..dbbeea2d8 100644 --- a/cli/cmd_inspect_test.go +++ b/cli/cmd_inspect_test.go @@ -23,13 +23,13 @@ import ( "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/sakila" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) // TestCmdInspect_json_yaml tests "sq inspect" for // the JSON and YAML formats. 
func TestCmdInspect_json_yaml(t *testing.T) { - tutil.SkipShort(t, true) + tu.SkipShort(t, true) possibleTbls := append(sakila.AllTbls(), source.MonotableName) testCases := []struct { @@ -60,7 +60,7 @@ func TestCmdInspect_json_yaml(t *testing.T) { tc := tc t.Run(tc.handle, func(t *testing.T) { - tutil.SkipWindowsIf(t, tc.handle == sakila.XLSX, "XLSX too slow on windows workflow") + tu.SkipWindowsIf(t, tc.handle == sakila.XLSX, "XLSX too slow on windows workflow") th := testh.New(t) src := th.Source(tc.handle) @@ -90,7 +90,7 @@ func TestCmdInspect_json_yaml(t *testing.T) { for _, tblName := range gotTableNames { tblName := tblName t.Run(tblName, func(t *testing.T) { - tutil.SkipShort(t, true) + tu.SkipShort(t, true) tr2 := testrun.New(th.Context, t, tr) err := tr2.Exec("inspect", "."+tblName, fmt.Sprintf("--%s", tf.format)) require.NoError(t, err) @@ -172,7 +172,7 @@ func TestCmdInspect_text(t *testing.T) { //nolint:tparallel t.Run(tc.handle, func(t *testing.T) { t.Parallel() - tutil.SkipWindowsIf(t, tc.handle == sakila.XLSX, "XLSX too slow on windows workflow") + tu.SkipWindowsIf(t, tc.handle == sakila.XLSX, "XLSX too slow on windows workflow") th := testh.New(t) src := th.Source(tc.handle) @@ -198,7 +198,7 @@ func TestCmdInspect_text(t *testing.T) { //nolint:tparallel for _, tblName := range tc.wantTbls { tblName := tblName t.Run(tblName, func(t *testing.T) { - tutil.SkipShort(t, true) + tu.SkipShort(t, true) t.Logf("Test: sq inspect .tbl") tr2 := testrun.New(th.Context, t, tr) err := tr2.Exec("inspect", "."+tblName, fmt.Sprintf("--%s", format.Text)) @@ -302,7 +302,7 @@ func TestCmdInspect_stdin(t *testing.T) { for _, tc := range testCases { tc := tc - t.Run(tutil.Name(tc.fpath), func(t *testing.T) { + t.Run(tu.Name(tc.fpath), func(t *testing.T) { ctx := context.Background() f, err := os.Open(tc.fpath) // No need to close f require.NoError(t, err) diff --git a/cli/cmd_mv_test.go b/cli/cmd_mv_test.go index fc08c54a4..3d6ec0536 100644 --- a/cli/cmd_mv_test.go +++ 
b/cli/cmd_mv_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/require" "github.com/neilotoole/sq/cli" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestLastHandlePart(t *testing.T) { @@ -21,7 +21,7 @@ func TestLastHandlePart(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { got := cli.LastHandlePart(tc.in) require.Equal(t, tc.want, got) }) diff --git a/cli/cmd_sql_test.go b/cli/cmd_sql_test.go index 91b21e0e4..3cff93231 100644 --- a/cli/cmd_sql_test.go +++ b/cli/cmd_sql_test.go @@ -19,7 +19,7 @@ import ( "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/sakila" "github.com/neilotoole/sq/testh/testsrc" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) // TestCmdSQL_Insert tests "sq sql QUERY --insert=dest.tbl". @@ -28,7 +28,7 @@ func TestCmdSQL_Insert(t *testing.T) { origin := origin t.Run("origin_"+origin, func(t *testing.T) { - tutil.SkipShort(t, origin == sakila.XLSX) + tu.SkipShort(t, origin == sakila.XLSX) for _, dest := range sakila.SQLLatest() { dest := dest @@ -161,7 +161,7 @@ func TestCmdSQL_StdinQuery(t *testing.T) { for i, tc := range testCases { tc := tc - name := tutil.Name(i, filepath.Base(filepath.Dir(tc.fpath)), filepath.Base(tc.fpath)) + name := tu.Name(i, filepath.Base(filepath.Dir(tc.fpath)), filepath.Base(tc.fpath)) t.Run(name, func(t *testing.T) { t.Parallel() diff --git a/cli/cobraz/cobraz_test.go b/cli/cobraz/cobraz_test.go index c18d6859e..97e94f05b 100644 --- a/cli/cobraz/cobraz_test.go +++ b/cli/cobraz/cobraz_test.go @@ -6,7 +6,7 @@ import ( "github.com/spf13/cobra" "github.com/stretchr/testify/require" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestExtractDirectives(t *testing.T) { @@ -33,7 +33,7 @@ func TestExtractDirectives(t *testing.T) { } for i, tc 
:= range testCases { - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { gotDirectives := ExtractDirectives(tc.in) require.Equal(t, tc.want, gotDirectives) gotStrings := MarshalDirective(tc.in) diff --git a/cli/complete_location_test.go b/cli/complete_location_test.go index 734b3f064..4c510e571 100644 --- a/cli/complete_location_test.go +++ b/cli/complete_location_test.go @@ -22,7 +22,7 @@ import ( "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/testh" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) var locSchemes = []string{ @@ -35,9 +35,9 @@ var locSchemes = []string{ const stdDirective = cobra.ShellCompDirectiveNoSpace | cobra.ShellCompDirectiveKeepOrder func TestCompleteAddLocation_Postgres(t *testing.T) { - tutil.SkipWindows(t, "Shell completion not fully implemented for windows") + tu.SkipWindows(t, "Shell completion not fully implemented for windows") - wd := tutil.Chdir(t, filepath.Join("testdata", "add_location")) + wd := tu.Chdir(t, filepath.Join("testdata", "add_location")) t.Logf("Working dir: %s", wd) testCases := []struct { @@ -261,7 +261,7 @@ func TestCompleteAddLocation_Postgres(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { + t.Run(tu.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { args := append([]string{"add"}, tc.args...) got := testComplete(t, nil, args...) 
assert.Equal(t, tc.wantResult, got.result, got.directives) @@ -271,9 +271,9 @@ func TestCompleteAddLocation_Postgres(t *testing.T) { } func TestCompleteAddLocation_SQLServer(t *testing.T) { - tutil.SkipWindows(t, "Shell completion not fully implemented for windows") + tu.SkipWindows(t, "Shell completion not fully implemented for windows") - wd := tutil.Chdir(t, filepath.Join("testdata", "add_location")) + wd := tu.Chdir(t, filepath.Join("testdata", "add_location")) t.Logf("Working dir: %s", wd) testCases := []struct { @@ -391,7 +391,7 @@ func TestCompleteAddLocation_SQLServer(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { + t.Run(tu.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { args := append([]string{"add"}, tc.args...) got := testComplete(t, nil, args...) assert.Equal(t, tc.wantResult, got.result, got.directives) @@ -401,9 +401,9 @@ func TestCompleteAddLocation_SQLServer(t *testing.T) { } func TestCompleteAddLocation_MySQL(t *testing.T) { - tutil.SkipWindows(t, "Shell completion not fully implemented for windows") + tu.SkipWindows(t, "Shell completion not fully implemented for windows") - wd := tutil.Chdir(t, filepath.Join("testdata", "add_location")) + wd := tu.Chdir(t, filepath.Join("testdata", "add_location")) t.Logf("Working dir: %s", wd) testCases := []struct { @@ -621,7 +621,7 @@ func TestCompleteAddLocation_MySQL(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { + t.Run(tu.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { args := append([]string{"add"}, tc.args...) got := testComplete(t, nil, args...) 
assert.Equal(t, tc.wantResult, got.result, got.directives) @@ -631,9 +631,9 @@ func TestCompleteAddLocation_MySQL(t *testing.T) { } func TestCompleteAddLocation_SQLite3(t *testing.T) { - tutil.SkipWindows(t, "Shell completion not fully implemented for windows") + tu.SkipWindows(t, "Shell completion not fully implemented for windows") - wd := tutil.Chdir(t, filepath.Join("testdata", "add_location")) + wd := tu.Chdir(t, filepath.Join("testdata", "add_location")) t.Logf("Working dir: %s", wd) testCases := []struct { @@ -782,7 +782,7 @@ func TestCompleteAddLocation_SQLite3(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { + t.Run(tu.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { args := append([]string{"add"}, tc.args...) got := testComplete(t, nil, args...) assert.Equal(t, tc.wantResult, got.result, got.directives) @@ -792,8 +792,8 @@ func TestCompleteAddLocation_SQLite3(t *testing.T) { } func TestCompleteAddLocation_History_Postgres(t *testing.T) { - tutil.SkipWindows(t, "Shell completion not fully implemented for windows") - wd := tutil.Chdir(t, filepath.Join("testdata", "add_location")) + tu.SkipWindows(t, "Shell completion not fully implemented for windows") + wd := tu.Chdir(t, filepath.Join("testdata", "add_location")) t.Logf("Working dir: %s", wd) th := testh.New(t) @@ -919,7 +919,7 @@ func TestCompleteAddLocation_History_Postgres(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { + t.Run(tu.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { args := append([]string{"add"}, tc.args...) got := testComplete(t, tr, args...) 
assert.Equal(t, tc.wantResult, got.result, got.directives) @@ -929,8 +929,8 @@ func TestCompleteAddLocation_History_Postgres(t *testing.T) { } func TestCompleteAddLocation_History_SQLServer(t *testing.T) { - tutil.SkipWindows(t, "Shell completion not fully implemented for windows") - wd := tutil.Chdir(t, filepath.Join("testdata", "add_location")) + tu.SkipWindows(t, "Shell completion not fully implemented for windows") + wd := tu.Chdir(t, filepath.Join("testdata", "add_location")) t.Logf("Working dir: %s", wd) th := testh.New(t) @@ -1120,7 +1120,7 @@ func TestCompleteAddLocation_History_SQLServer(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { + t.Run(tu.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { args := append([]string{"add"}, tc.args...) got := testComplete(t, tr, args...) assert.Equal(t, tc.wantResult, got.result, got.directives) @@ -1130,8 +1130,8 @@ func TestCompleteAddLocation_History_SQLServer(t *testing.T) { } func TestCompleteAddLocation_History_SQLite3(t *testing.T) { - tutil.SkipWindows(t, "Shell completion not fully implemented for windows") - wd := tutil.Chdir(t, filepath.Join("testdata", "add_location")) + tu.SkipWindows(t, "Shell completion not fully implemented for windows") + wd := tu.Chdir(t, filepath.Join("testdata", "add_location")) t.Logf("Working dir: %s", wd) src3Loc := "sqlite3://" + wd + "/my.db?cache=FAST" @@ -1250,7 +1250,7 @@ func TestCompleteAddLocation_History_SQLite3(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { + t.Run(tu.Name(i, strings.Join(tc.args, "_")), func(t *testing.T) { args := append([]string{"add"}, tc.args...) got := testComplete(t, tr, args...) 
assert.Equal(t, tc.wantResult, got.result, got.directives) @@ -1293,7 +1293,7 @@ func TestParseLoc_stage(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.loc), func(t *testing.T) { + t.Run(tu.Name(i, tc.loc), func(t *testing.T) { th := testh.New(t) ru := th.Run() @@ -1305,9 +1305,9 @@ func TestParseLoc_stage(t *testing.T) { } func TestDoCompleteAddLocationFile(t *testing.T) { - tutil.SkipWindows(t, "Shell completion not fully implemented for windows") + tu.SkipWindows(t, "Shell completion not fully implemented for windows") - absDir := tutil.Chdir(t, filepath.Join("testdata", "add_location")) + absDir := tu.Chdir(t, filepath.Join("testdata", "add_location")) t.Logf("Working dir: %s", absDir) testCases := []struct { @@ -1336,7 +1336,7 @@ func TestDoCompleteAddLocationFile(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { ctx := lg.NewContext(context.Background(), slogt.New(t)) t.Logf("input: %s", tc.in) t.Logf("want: %s", tc.want) diff --git a/cli/complete_test.go b/cli/complete_test.go index a9ffb1a6a..ff0090303 100644 --- a/cli/complete_test.go +++ b/cli/complete_test.go @@ -18,7 +18,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/sakila" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) // testComplete is a helper for testing cobra completion. 
@@ -152,7 +152,7 @@ func TestCompleteFlagActiveSchema_query_cmds(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.handles, tc.withFlagActiveSrc, tc.arg), func(t *testing.T) { + t.Run(tu.Name(i, tc.handles, tc.withFlagActiveSrc, tc.arg), func(t *testing.T) { t.Parallel() th := testh.New(t) @@ -262,7 +262,7 @@ func TestCompleteFlagActiveSchema_inspect(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.handles, tc.withArgActiveSrc, tc.arg), func(t *testing.T) { + t.Run(tu.Name(i, tc.handles, tc.withArgActiveSrc, tc.arg), func(t *testing.T) { t.Parallel() th := testh.New(t) diff --git a/cli/config/yamlstore/internal_test.go b/cli/config/yamlstore/internal_test.go index df71d0539..b16ee0874 100644 --- a/cli/config/yamlstore/internal_test.go +++ b/cli/config/yamlstore/internal_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" "github.com/neilotoole/sq/cli/flag" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func Test_getConfigDirFromFlag(t *testing.T) { @@ -27,7 +27,7 @@ func Test_getConfigDirFromFlag(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, strings.Join(tc.in, " ")), func(t *testing.T) { + t.Run(tu.Name(i, strings.Join(tc.in, " ")), func(t *testing.T) { got, gotOK, gotErr := getConfigDirFromFlag(tc.in) if tc.wantErr { require.Error(t, gotErr) diff --git a/cli/config/yamlstore/upgrades/v0.34.0/upgrade_test.go b/cli/config/yamlstore/upgrades/v0.34.0/upgrade_test.go index d1ac4cb71..d4c168b56 100644 --- a/cli/config/yamlstore/upgrades/v0.34.0/upgrade_test.go +++ b/cli/config/yamlstore/upgrades/v0.34.0/upgrade_test.go @@ -24,7 +24,7 @@ import ( "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/testh" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestUpgrade(t *testing.T) { @@ -42,7 +42,7 @@ func 
TestUpgrade(t *testing.T) { testh.SetBuildVersion(t, nextVers) // The sq.yml file in cfgDir is on v0.33.0 - cfgDir := tutil.DirCopy(t, "testdata", true) + cfgDir := tu.DirCopy(t, "testdata", true) t.Setenv(config.EnvarConfig, cfgDir) cfgFilePath := filepath.Join(cfgDir, "sq.yml") diff --git a/cli/config/yamlstore/yamlstore_test.go b/cli/config/yamlstore/yamlstore_test.go index da86a9ab9..20cf96dc8 100644 --- a/cli/config/yamlstore/yamlstore_test.go +++ b/cli/config/yamlstore/yamlstore_test.go @@ -14,7 +14,7 @@ import ( "github.com/neilotoole/sq/cli/config/yamlstore" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/testh/proj" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestFileStore_Nil_Save(t *testing.T) { @@ -85,7 +85,7 @@ func TestFileStore_Load(t *testing.T) { for _, match := range good { match := match - t.Run(tutil.Name(match), func(t *testing.T) { + t.Run(tu.Name(match), func(t *testing.T) { fs.Path = match cfg, err := fs.Load(context.Background()) require.NoError(t, err, match) @@ -95,7 +95,7 @@ func TestFileStore_Load(t *testing.T) { for _, match := range bad { match := match - t.Run(tutil.Name(match), func(t *testing.T) { + t.Run(tu.Name(match), func(t *testing.T) { fs.Path = match cfg, err := fs.Load(context.Background()) t.Log(err) diff --git a/cli/output/adapter_test.go b/cli/output/adapter_test.go index b2ba9ce20..dcee9c3cb 100644 --- a/cli/output/adapter_test.go +++ b/cli/output/adapter_test.go @@ -14,7 +14,7 @@ import ( "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/sakila" "github.com/neilotoole/sq/testh/testsrc" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) var _ libsq.RecordWriter = (*output.RecordWriterAdapter)(nil) @@ -127,7 +127,7 @@ func TestRecordWriterAdapter_FlushAfterDuration(t *testing.T) { testCases := []struct { flushAfter time.Duration wantFlushed int - assertFn tutil.AssertCompareFunc + 
assertFn tu.AssertCompareFunc }{ {flushAfter: -1, wantFlushed: 0, assertFn: require.Equal}, {flushAfter: 0, wantFlushed: 0, assertFn: require.Equal}, diff --git a/cli/output/tablew/configwriter.go b/cli/output/tablew/configwriter.go index cc432e252..6dd7ae3bd 100644 --- a/cli/output/tablew/configwriter.go +++ b/cli/output/tablew/configwriter.go @@ -4,14 +4,13 @@ import ( "fmt" "io" - "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/stringz" - "github.com/fatih/color" "github.com/samber/lo" "github.com/neilotoole/sq/cli/output" + "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/stringz" ) var _ output.ConfigWriter = (*configWriter)(nil) diff --git a/cli/output/xlsxw/xlsxw_test.go b/cli/output/xlsxw/xlsxw_test.go index 3ccadba69..ac96ac5d1 100644 --- a/cli/output/xlsxw/xlsxw_test.go +++ b/cli/output/xlsxw/xlsxw_test.go @@ -21,7 +21,7 @@ import ( "github.com/neilotoole/sq/testh/fixt" "github.com/neilotoole/sq/testh/sakila" "github.com/neilotoole/sq/testh/testsrc" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestRecordWriter(t *testing.T) { @@ -75,7 +75,7 @@ func TestRecordWriter(t *testing.T) { require.NoError(t, w.WriteRecords(recs)) require.NoError(t, w.Close()) - _ = tutil.WriteTemp(t, fmt.Sprintf("*.%s.test.xlsx", tc.name), buf.Bytes(), false) + _ = tu.WriteTemp(t, fmt.Sprintf("*.%s.test.xlsx", tc.name), buf.Bytes(), false) want, err := os.ReadFile(tc.fixtPath) require.NoError(t, err) @@ -170,7 +170,7 @@ func TestOptDatetimeFormats(t *testing.T) { tr = testrun.New(th.Context, t, tr) require.NoError(t, tr.Exec("sql", "--xlsx", query)) - fpath := tutil.WriteTemp(t, "*.xlsx", tr.Out.Bytes(), true) + fpath := tu.WriteTemp(t, "*.xlsx", tr.Out.Bytes(), true) gotDatetime := readCellValue(t, fpath, source.MonotableName, "A2") gotDate := readCellValue(t, fpath, source.MonotableName, 
"B2") diff --git a/cli/run.go b/cli/run.go index 1eaf8e3b1..033d5a816 100644 --- a/cli/run.go +++ b/cli/run.go @@ -140,7 +140,7 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { var err error if ru.Files == nil { - ru.Files, err = source.NewFiles(ctx) + ru.Files, err = source.NewFiles(ctx, source.DefaultTempDir(), source.DefaultCacheDir()) if err != nil { lg.WarnIfFuncError(log, lga.Cleanup, ru.Cleanup.Run) return err diff --git a/cli/run/run.go b/cli/run/run.go index 5b5b01812..eed1fcc7c 100644 --- a/cli/run/run.go +++ b/cli/run/run.go @@ -99,10 +99,8 @@ func (ru *Run) Close() error { // NewQueryContext returns a *libsq.QueryContext constructed from ru. func NewQueryContext(ru *Run, args map[string]string) *libsq.QueryContext { return &libsq.QueryContext{ - Collection: ru.Config.Collection, - PoolOpener: ru.Sources, - JoinPoolOpener: ru.Sources, - ScratchPoolOpener: ru.Sources, - Args: args, + Collection: ru.Config.Collection, + Sources: ru.Sources, + Args: args, } } diff --git a/drivers/csv/csv_test.go b/drivers/csv/csv_test.go index 61c3b9dc2..924b9e538 100644 --- a/drivers/csv/csv_test.go +++ b/drivers/csv/csv_test.go @@ -342,7 +342,7 @@ func TestDatetime(t *testing.T) { // TestIngestLargeCSV generates a large CSV file. // At count = 5000000, the generated file is ~500MB. 
func TestGenerateLargeCSV(t *testing.T) { - // t.Skip() + t.Skip() const count = 5000000 // Generates ~500MB file start := time.Now() header := []string{ diff --git a/drivers/csv/detect_header_test.go b/drivers/csv/detect_header_test.go index 8f088c66a..a7a7a98bc 100644 --- a/drivers/csv/detect_header_test.go +++ b/drivers/csv/detect_header_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/require" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func Test_detectHeaderRow(t *testing.T) { @@ -24,7 +24,7 @@ func Test_detectHeaderRow(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.fp), func(t *testing.T) { + t.Run(tu.Name(i, tc.fp), func(t *testing.T) { recs := readAllRecs(t, tc.comma, tc.fp) gotHasHeader, err := detectHeaderRow(recs) diff --git a/drivers/csv/internal_test.go b/drivers/csv/internal_test.go index 4026216f1..0004589dc 100644 --- a/drivers/csv/internal_test.go +++ b/drivers/csv/internal_test.go @@ -13,7 +13,7 @@ import ( "github.com/neilotoole/sq/libsq/core/kind" "github.com/neilotoole/sq/libsq/core/stringz" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func Test_isCSV(t *testing.T) { @@ -86,7 +86,7 @@ func Test_detectColKinds(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.name), func(t *testing.T) { + t.Run(tu.Name(i, tc.name), func(t *testing.T) { gotKinds, _, gotErr := detectColKinds(tc.recs) if tc.wantErr { require.Error(t, gotErr) diff --git a/drivers/json/ingest_test.go b/drivers/json/ingest_test.go index 20307778f..42348631b 100644 --- a/drivers/json/ingest_test.go +++ b/drivers/json/ingest_test.go @@ -17,7 +17,7 @@ import ( "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/sakila" "github.com/neilotoole/sq/testh/testsrc" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestImportJSONL_Flat(t *testing.T) { @@ -74,7 
+74,7 @@ func TestImportJSONL_Flat(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.fpath, tc.input), func(t *testing.T) { + t.Run(tu.Name(i, tc.fpath, tc.input), func(t *testing.T) { openFn := func(ctx context.Context) (io.ReadCloser, error) { return io.NopCloser(strings.NewReader(tc.input)), nil } @@ -187,7 +187,7 @@ func TestScanObjectsInArray(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { r := bytes.NewReader([]byte(tc.in)) gotObjs, gotChunks, err := json.ScanObjectsInArray(r) if tc.wantErr { @@ -219,7 +219,7 @@ func TestScanObjectsInArray_Files(t *testing.T) { for _, tc := range testCases { tc := tc - t.Run(tutil.Name(tc.fname), func(t *testing.T) { + t.Run(tu.Name(tc.fname), func(t *testing.T) { f, err := os.Open(tc.fname) require.NoError(t, err) defer f.Close() @@ -260,7 +260,7 @@ func TestColumnOrderFlat(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { require.True(t, stdj.Valid([]byte(tc.in))) gotCols, err := json.ColumnOrderFlat([]byte(tc.in)) diff --git a/drivers/json/json_test.go b/drivers/json/json_test.go index 7814bb742..da4a74eb6 100644 --- a/drivers/json/json_test.go +++ b/drivers/json/json_test.go @@ -15,7 +15,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestDriverDetectorFuncs(t *testing.T) { @@ -92,7 +92,7 @@ func TestDriverDetectorFuncs(t *testing.T) { for _, tc := range testCases { tc := tc - t.Run(tutil.Name(tc.fn, tc.fname), func(t *testing.T) { + t.Run(tu.Name(tc.fn, tc.fname), func(t *testing.T) { openFn := func(ctx context.Context) (io.ReadCloser, error) { return os.Open(filepath.Join("testdata", 
tc.fname)) } detectFn := detectFns[tc.fn] diff --git a/drivers/mysql/internal_test.go b/drivers/mysql/internal_test.go index 20e56eb8c..b01de1a44 100644 --- a/drivers/mysql/internal_test.go +++ b/drivers/mysql/internal_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/require" "github.com/neilotoole/sq/libsq/source" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) // Export for testing. @@ -98,7 +98,7 @@ func TestDSNFromLocation(t *testing.T) { for _, tc := range testCases { tc := tc - t.Run(tutil.Name(tc.loc, tc.parseTime), func(t *testing.T) { + t.Run(tu.Name(tc.loc, tc.parseTime), func(t *testing.T) { src := &source.Source{ Handle: "@testhandle", Type: Type, diff --git a/drivers/sqlite3/extension_test.go b/drivers/sqlite3/extension_test.go index 785ee3d1a..6b56f4b18 100644 --- a/drivers/sqlite3/extension_test.go +++ b/drivers/sqlite3/extension_test.go @@ -13,7 +13,7 @@ import ( "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/sakila" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestExtension_fts5(t *testing.T) { @@ -23,7 +23,7 @@ func TestExtension_fts5(t *testing.T) { src := th.Add(&source.Source{ Handle: "@fts", Type: sqlite3.Type, - Location: "sqlite3://" + tutil.MustAbsFilepath("testdata", "sakila_fts5.db"), + Location: "sqlite3://" + tu.MustAbsFilepath("testdata", "sakila_fts5.db"), }) srcMeta, err := th.SourceMetadata(src) diff --git a/drivers/sqlite3/internal/sqlparser/sqlparser_test.go b/drivers/sqlite3/internal/sqlparser/sqlparser_test.go index 6ffacd09a..fa7cd4660 100644 --- a/drivers/sqlite3/internal/sqlparser/sqlparser_test.go +++ b/drivers/sqlite3/internal/sqlparser/sqlparser_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/require" "github.com/neilotoole/sq/drivers/sqlite3/internal/sqlparser" - "github.com/neilotoole/sq/testh/tutil" + 
"github.com/neilotoole/sq/testh/tu" ) func TestExtractTableNameFromCreateTableStmt(t *testing.T) { @@ -55,7 +55,7 @@ func TestExtractTableNameFromCreateTableStmt(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { schema, table, err := sqlparser.ExtractTableIdentFromCreateTableStmt(tc.in, tc.unescape) if tc.wantErr { require.Error(t, err) diff --git a/drivers/sqlite3/sqlite3_test.go b/drivers/sqlite3/sqlite3_test.go index f81a7f2b5..4195ced47 100644 --- a/drivers/sqlite3/sqlite3_test.go +++ b/drivers/sqlite3/sqlite3_test.go @@ -18,7 +18,7 @@ import ( "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/fixt" "github.com/neilotoole/sq/testh/sakila" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestSmoke(t *testing.T) { @@ -290,7 +290,7 @@ func TestMungeLocation(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { if tc.onlyForOS != "" && tc.onlyForOS != runtime.GOOS { t.Skipf("Skipping because this test is only for OS {%s}, but have {%s}", tc.onlyForOS, runtime.GOOS) diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index 8583ddbac..83143c9d9 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -233,7 +233,7 @@ func ingestSheetToTable(ctx context.Context, destPool driver.Pool, sheetTbl *she close(bi.RecordCh) return err } - + select { case <-ctx.Done(): close(bi.RecordCh) diff --git a/drivers/xlsx/xlsx_test.go b/drivers/xlsx/xlsx_test.go index af4088158..5c0d393c5 100644 --- a/drivers/xlsx/xlsx_test.go +++ b/drivers/xlsx/xlsx_test.go @@ -26,7 +26,7 @@ import ( "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/sakila" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) var 
sakilaSheets = []string{ @@ -50,8 +50,8 @@ var sakilaSheets = []string{ func TestSakilaInspectSource(t *testing.T) { t.Parallel() - tutil.SkipWindows(t, "Skipping because of slow workflow perf on windows") - tutil.SkipShort(t, true) + tu.SkipWindows(t, "Skipping because of slow workflow perf on windows") + tu.SkipShort(t, true) th := testh.New(t, testh.OptLongOpen()) src := th.Source(sakila.XLSX) @@ -64,8 +64,8 @@ func TestSakilaInspectSource(t *testing.T) { func TestSakilaInspectSheets(t *testing.T) { t.Parallel() - tutil.SkipWindows(t, "Skipping because of slow workflow perf on windows") - tutil.SkipShort(t, true) + tu.SkipWindows(t, "Skipping because of slow workflow perf on windows") + tu.SkipShort(t, true) for _, sheet := range sakilaSheets { sheet := sheet @@ -84,8 +84,8 @@ func TestSakilaInspectSheets(t *testing.T) { } func BenchmarkInspectSheets(b *testing.B) { - tutil.SkipWindows(b, "Skipping because of slow workflow perf on windows") - tutil.SkipShort(b, true) + tu.SkipWindows(b, "Skipping because of slow workflow perf on windows") + tu.SkipShort(b, true) for _, sheet := range sakilaSheets { sheet := sheet @@ -108,8 +108,8 @@ func BenchmarkInspectSheets(b *testing.B) { func TestSakila_query_cmd(t *testing.T) { t.Parallel() - tutil.SkipWindows(t, "Skipping because of slow workflow perf on windows") - tutil.SkipShort(t, true) + tu.SkipWindows(t, "Skipping because of slow workflow perf on windows") + tu.SkipShort(t, true) for _, sheet := range sakilaSheets { sheet := sheet @@ -130,8 +130,8 @@ func TestSakila_query_cmd(t *testing.T) { func TestOpenFileFormats(t *testing.T) { t.Parallel() - tutil.SkipWindows(t, "Skipping because of slow workflow perf on windows") - tutil.SkipShort(t, true) + tu.SkipWindows(t, "Skipping because of slow workflow perf on windows") + tu.SkipShort(t, true) testCases := []struct { filename string @@ -163,14 +163,15 @@ func TestOpenFileFormats(t *testing.T) { }) pool, err := th.Sources().Open(th.Context, src) - require.NoError(t, 
err) - db, err := pool.DB(th.Context) if tc.wantErr { require.Error(t, err) return } require.NoError(t, err) - require.NoError(t, db.PingContext(th.Context)) + db, err := pool.DB(th.Context) + require.NoError(t, err) + err = db.PingContext(th.Context) + require.NoError(t, err) sink, err := th.QuerySQL(src, nil, "SELECT * FROM actor") @@ -189,8 +190,8 @@ func TestOpenFileFormats(t *testing.T) { func TestSakila_query(t *testing.T) { t.Parallel() - tutil.SkipWindows(t, "Skipping because of slow workflow perf on windows") - tutil.SkipShort(t, true) + tu.SkipWindows(t, "Skipping because of slow workflow perf on windows") + tu.SkipShort(t, true) testCases := []struct { sheet string diff --git a/libsq/ast/ast_test.go b/libsq/ast/ast_test.go index 3d692c6d1..e58da1f4b 100644 --- a/libsq/ast/ast_test.go +++ b/libsq/ast/ast_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/require" "github.com/neilotoole/sq/libsq/ast" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestParseCatalogSchema(t *testing.T) { @@ -26,7 +26,7 @@ func TestParseCatalogSchema(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { gotCatalog, gotSchema, gotErr := ast.ParseCatalogSchema(tc.in) if tc.wantErr { require.Error(t, gotErr) diff --git a/libsq/ast/parser_test.go b/libsq/ast/parser_test.go index f0110464f..86f3af84e 100644 --- a/libsq/ast/parser_test.go +++ b/libsq/ast/parser_test.go @@ -9,7 +9,7 @@ import ( "github.com/neilotoole/slogt" "github.com/neilotoole/sq/libsq/ast/internal/slq" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) // getSLQParser returns a parser for the given SQL input. 
@@ -81,7 +81,7 @@ func TestParseBuild(t *testing.T) { } for i, tc := range testCases { - t.Run(tutil.Name(i, tc.name), func(t *testing.T) { + t.Run(tu.Name(i, tc.name), func(t *testing.T) { t.Logf(tc.in) log := slogt.New(t) diff --git a/libsq/ast/range_test.go b/libsq/ast/range_test.go index 8aad8685a..5ab7496ed 100644 --- a/libsq/ast/range_test.go +++ b/libsq/ast/range_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/require" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) // TestRowRange tests the row range mechanism. @@ -33,7 +33,7 @@ func TestRowRange(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { ast := mustParse(t, tc.in) insp := NewInspector(ast) nodes := insp.FindNodes(typeRowRangeNode) diff --git a/libsq/ast/selector_test.go b/libsq/ast/selector_test.go index c0c334cde..a98e5fc47 100644 --- a/libsq/ast/selector_test.go +++ b/libsq/ast/selector_test.go @@ -7,7 +7,7 @@ import ( "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestColumnAlias(t *testing.T) { @@ -28,7 +28,7 @@ func TestColumnAlias(t *testing.T) { for _, tc := range testCases { tc := tc - t.Run(tutil.Name(tc.in), func(t *testing.T) { + t.Run(tu.Name(tc.in), func(t *testing.T) { t.Parallel() log := slogt.New(t) diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 5f3754a9f..21cf50cbe 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -16,7 +16,6 @@ import ( "github.com/a8m/tree" "github.com/a8m/tree/ostree" - yaml "github.com/goccy/go-yaml" "github.com/neilotoole/sq/libsq/core/errz" diff --git a/libsq/core/kind/internal_test.go b/libsq/core/kind/internal_test.go index 1c8d0bbcb..cc17b7856 100644 --- a/libsq/core/kind/internal_test.go +++ b/libsq/core/kind/internal_test.go @@ -7,7 +7,7 @@ import ( 
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestDetectKindDatetime(t *testing.T) { @@ -17,7 +17,7 @@ func TestDetectKindDatetime(t *testing.T) { for _, f := range datetimeFormats { f := f - t.Run(tutil.Name(f), func(t *testing.T) { + t.Run(tu.Name(f), func(t *testing.T) { s := tm.Format(f) ok, gotF := detectKindDatetime(s) @@ -51,7 +51,7 @@ func TestDetectKindDate(t *testing.T) { for i, input := range valid { input := input - t.Run(tutil.Name("valid", i, input), func(t *testing.T) { + t.Run(tu.Name("valid", i, input), func(t *testing.T) { t.Log(input) ok, gotF := detectKindDate(input) require.True(t, ok) @@ -72,7 +72,7 @@ func TestDetectKindDate(t *testing.T) { for i, input := range invalid { input := input - t.Run(tutil.Name("invalid", i, input), func(t *testing.T) { + t.Run(tu.Name("invalid", i, input), func(t *testing.T) { t.Log(input) ok, gotF := detectKindDate(input) require.False(t, ok) @@ -95,7 +95,7 @@ func TestDetectKindTime(t *testing.T) { for i, input := range valid { input := input - t.Run(tutil.Name("valid", i, input), func(t *testing.T) { + t.Run(tu.Name("valid", i, input), func(t *testing.T) { t.Log(input) ok, gotF := detectKindTime(input) require.True(t, ok) @@ -117,7 +117,7 @@ func TestDetectKindTime(t *testing.T) { for i, input := range invalid { input := input - t.Run(tutil.Name("invalid", i, input), func(t *testing.T) { + t.Run(tu.Name("invalid", i, input), func(t *testing.T) { t.Log(input) ok, gotF := detectKindTime(input) require.False(t, ok) diff --git a/libsq/core/kind/kind_test.go b/libsq/core/kind/kind_test.go index a397d668c..f8e60b02f 100644 --- a/libsq/core/kind/kind_test.go +++ b/libsq/core/kind/kind_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/neilotoole/sq/libsq/core/kind" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func 
TestKind(t *testing.T) { @@ -148,7 +148,7 @@ func TestDetector(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { kd := kind.NewDetector() for _, val := range tc.in { diff --git a/libsq/core/lg/devlog/devlog.go b/libsq/core/lg/devlog/devlog.go index 810edea60..4bba3c060 100644 --- a/libsq/core/lg/devlog/devlog.go +++ b/libsq/core/lg/devlog/devlog.go @@ -1,3 +1,5 @@ +// Package devlog contains a custom slog.Handler for +// developer-friendly log output. package devlog import ( @@ -7,7 +9,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/devlog/tint" ) -const shortTimeFormat = `15:04:05.000000` +const shortTimeFormat = "15:04:05.000000" // NewHandler returns a developer-friendly slog.Handler that // logs to w. diff --git a/libsq/core/lg/devlog/tint/buffer.go b/libsq/core/lg/devlog/tint/buffer.go index 178aea7a8..884febed0 100644 --- a/libsq/core/lg/devlog/tint/buffer.go +++ b/libsq/core/lg/devlog/tint/buffer.go @@ -1,6 +1,8 @@ package tint -import "sync" +import ( + "sync" +) type buffer []byte diff --git a/libsq/core/loz/loz_test.go b/libsq/core/loz/loz_test.go index bd4500f29..18d595e50 100644 --- a/libsq/core/loz/loz_test.go +++ b/libsq/core/loz/loz_test.go @@ -7,7 +7,7 @@ import ( "github.com/neilotoole/sq/libsq/core/loz" "github.com/neilotoole/sq/libsq/core/stringz" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestAll(t *testing.T) { @@ -82,7 +82,7 @@ func TestAlignMatrixWidth(t *testing.T) { } for i, tc := range testCases { - t.Run(tutil.Name(i), func(t *testing.T) { + t.Run(tu.Name(i), func(t *testing.T) { loz.AlignMatrixWidth(tc.in, defaultVal) require.EqualValues(t, tc.want, tc.in) }) diff --git a/libsq/core/options/options_test.go b/libsq/core/options/options_test.go index 9dc3b3c2f..77fac82b5 100644 --- a/libsq/core/options/options_test.go +++ b/libsq/core/options/options_test.go @@ -16,7 +16,7 
@@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/driver" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) type config struct { @@ -67,7 +67,7 @@ func TestInt(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.key), func(t *testing.T) { + t.Run(tu.Name(i, tc.key), func(t *testing.T) { reg := &options.Registry{} opt := options.NewInt(tc.key, "", 0, tc.defaultVal, "", "") @@ -113,7 +113,7 @@ func TestBool(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.key), func(t *testing.T) { + t.Run(tu.Name(i, tc.key), func(t *testing.T) { reg := &options.Registry{} opt := options.NewBool(tc.key, "", false, 0, tc.defaultVal, "", "") diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 127315fde..85d0d1d6a 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -360,7 +360,6 @@ func (p *Progress) newBar(msg string, total int64, // the bar is complete, the caller should invoke [Bar.Stop]. All // methods are safe to call on a nil Bar. 
type Bar struct { - // bar is nil until barInitOnce.Do(initBar) is called bar *mpb.Bar // p is never nil diff --git a/libsq/core/record/record_test.go b/libsq/core/record/record_test.go index 89d2d2ef0..f73cebf9c 100644 --- a/libsq/core/record/record_test.go +++ b/libsq/core/record/record_test.go @@ -8,7 +8,7 @@ import ( "github.com/neilotoole/sq/libsq/core/record" "github.com/neilotoole/sq/libsq/core/timez" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestEqual(t *testing.T) { @@ -53,7 +53,7 @@ func TestEqual(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.a, tc.b), func(t *testing.T) { + t.Run(tu.Name(i, tc.a, tc.b), func(t *testing.T) { _, err := record.Valid(tc.a) require.NoError(t, err) _, err = record.Valid(tc.b) diff --git a/libsq/core/stringz/stringz_test.go b/libsq/core/stringz/stringz_test.go index 7629764a9..8e9aabf5d 100644 --- a/libsq/core/stringz/stringz_test.go +++ b/libsq/core/stringz/stringz_test.go @@ -11,7 +11,7 @@ import ( "github.com/stretchr/testify/require" "github.com/neilotoole/sq/libsq/core/stringz" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestGenerateAlphaColName(t *testing.T) { @@ -297,7 +297,7 @@ func TestLineCount(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { count := stringz.LineCount(strings.NewReader(tc.in), false) require.Equal(t, tc.withEmpty, count) count = stringz.LineCount(strings.NewReader(tc.in), true) @@ -340,7 +340,7 @@ func TestStripDoubleQuote(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { got := stringz.StripDoubleQuote(tc.in) require.Equal(t, tc.want, got) }) @@ -401,7 +401,7 @@ func TestValidIdent(t *testing.T) { } for _, tc := range testCases { tc := tc - 
t.Run(tutil.Name(tc.in), func(t *testing.T) { + t.Run(tu.Name(tc.in), func(t *testing.T) { gotErr := stringz.ValidIdent(tc.in) if tc.wantErr { require.Error(t, gotErr) @@ -429,7 +429,7 @@ func TestStrings(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { got := stringz.Strings(tc.in) require.Len(t, got, len(tc.in)) @@ -457,7 +457,7 @@ func TestStringsD(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { got := stringz.StringsD(tc.in) require.Len(t, got, len(tc.in)) @@ -517,7 +517,7 @@ func TestTemplate(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.tpl), func(t *testing.T) { + t.Run(tu.Name(i, tc.tpl), func(t *testing.T) { got, gotErr := stringz.ExecuteTemplate(t.Name(), tc.tpl, tc.data) t.Logf("\nTPL: %s\nGOT: %s\nERR: %v", tc.tpl, got, gotErr) if tc.wantErr { @@ -550,7 +550,7 @@ func TestShellEscape(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc), func(t *testing.T) { + t.Run(tu.Name(i, tc), func(t *testing.T) { got := stringz.ShellEscape(tc.in) require.Equal(t, tc.want, got) }) @@ -581,7 +581,7 @@ func TestTrimLenMiddle(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.input, tc.maxLen), func(t *testing.T) { + t.Run(tu.Name(i, tc.input, tc.maxLen), func(t *testing.T) { got := stringz.TrimLenMiddle(tc.input, tc.maxLen) require.True(t, len(got) <= tc.maxLen) require.Equal(t, tc.want, got) @@ -611,7 +611,7 @@ func TestDecimal(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in, tc.wantStr), func(t *testing.T) { + t.Run(tu.Name(i, tc.in, tc.wantStr), func(t *testing.T) { gotStr := stringz.FormatDecimal(tc.in) require.Equal(t, tc.wantStr, gotStr) gotPlaces := stringz.DecimalPlaces(tc.in) diff --git a/libsq/core/timez/timez_test.go 
b/libsq/core/timez/timez_test.go index f7e0caa89..5bb955bbe 100644 --- a/libsq/core/timez/timez_test.go +++ b/libsq/core/timez/timez_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" "github.com/neilotoole/sq/libsq/core/timez" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) var ( @@ -53,7 +53,7 @@ func TestParseDateOrTimestampUTC(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { tm, err := timez.ParseDateOrTimestampUTC(tc.in) if tc.wantErr { require.Error(t, err) diff --git a/libsq/core/urlz/urlz_test.go b/libsq/core/urlz/urlz_test.go index 394a566a5..6c89bd2e3 100644 --- a/libsq/core/urlz/urlz_test.go +++ b/libsq/core/urlz/urlz_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" "github.com/neilotoole/sq/libsq/core/urlz" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestQueryParamKeys(t *testing.T) { @@ -25,7 +25,7 @@ func TestQueryParamKeys(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.q), func(t *testing.T) { + t.Run(tu.Name(i, tc.q), func(t *testing.T) { got, gotErr := urlz.QueryParamKeys(tc.q) if tc.wantErr { require.Error(t, gotErr) @@ -52,7 +52,7 @@ func TestStripQuery(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc), func(t *testing.T) { + t.Run(tu.Name(i, tc), func(t *testing.T) { u, err := url.Parse(tc.in) require.NoError(t, err) got := urlz.StripQuery(*u) @@ -74,7 +74,7 @@ func TestStripUser(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc), func(t *testing.T) { + t.Run(tu.Name(i, tc), func(t *testing.T) { u, err := url.Parse(tc.in) require.NoError(t, err) got := urlz.StripUser(*u) @@ -94,7 +94,7 @@ func TestStripScheme(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc), func(t 
*testing.T) { + t.Run(tu.Name(i, tc), func(t *testing.T) { u, err := url.Parse(tc.in) require.NoError(t, err) got := urlz.StripScheme(*u) @@ -116,7 +116,7 @@ func TestStripSchemeAndUser(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc), func(t *testing.T) { + t.Run(tu.Name(i, tc), func(t *testing.T) { u, err := url.Parse(tc.in) require.NoError(t, err) got := urlz.StripSchemeAndUser(*u) @@ -145,7 +145,7 @@ func TestRenameQueryParamKey(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.q, tc.oldKey, tc.newKey), func(t *testing.T) { + t.Run(tu.Name(i, tc.q, tc.oldKey, tc.newKey), func(t *testing.T) { got := urlz.RenameQueryParamKey(tc.q, tc.oldKey, tc.newKey) require.Equal(t, tc.want, got) }) @@ -167,7 +167,7 @@ func TestURLStripQuery(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc), func(t *testing.T) { + t.Run(tu.Name(i, tc), func(t *testing.T) { u, err := url.Parse(tc.in) require.NoError(t, err) got := urlz.StripQuery(*u) diff --git a/libsq/driver/driver.go b/libsq/driver/driver.go index e47b8b64e..ce41d7384 100644 --- a/libsq/driver/driver.go +++ b/libsq/driver/driver.go @@ -161,20 +161,6 @@ type PoolOpener interface { Open(ctx context.Context, src *source.Source) (Pool, error) } -// JoinPoolOpener can open a join database. -type JoinPoolOpener interface { - // OpenJoin opens an appropriate Pool for use as - // a work DB for joining across sources. - OpenJoin(ctx context.Context, srcs ...*source.Source) (Pool, error) -} - -// ScratchPoolOpener opens a scratch database pool. A scratch database is -// a short-lived database used for ephemeral purposes. -type ScratchPoolOpener interface { - // OpenScratch returns a pool for scratch use. - OpenScratch(ctx context.Context, src *source.Source) (Pool, error) -} - // IngestOpener opens a pool for ingest use. 
type IngestOpener interface { // OpenIngest opens a pool for src by executing ingestFn, which is diff --git a/libsq/driver/driver_test.go b/libsq/driver/driver_test.go index 49511277c..bccf7db81 100644 --- a/libsq/driver/driver_test.go +++ b/libsq/driver/driver_test.go @@ -26,7 +26,7 @@ import ( "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/fixt" "github.com/neilotoole/sq/testh/sakila" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestDriver_DropTable(t *testing.T) { @@ -154,7 +154,7 @@ func TestDriver_TableColumnTypes(t *testing.T) { //nolint:tparallel handle := handle t.Run(handle, func(t *testing.T) { - tutil.SkipShort(t, handle == sakila.XLSX) + tu.SkipShort(t, handle == sakila.XLSX) t.Parallel() th, src, drvr, _, db := testh.NewWith(t, handle) @@ -194,7 +194,7 @@ func TestSQLDriver_PrepareUpdateStmt(t *testing.T) { //nolint:tparallel handle := handle t.Run(handle, func(t *testing.T) { - tutil.SkipShort(t, handle == sakila.XLSX) + tu.SkipShort(t, handle == sakila.XLSX) t.Parallel() th, src, drvr, _, db := testh.NewWith(t, handle) @@ -239,7 +239,7 @@ func TestDriver_Ping(t *testing.T) { handle := handle t.Run(handle, func(t *testing.T) { - tutil.SkipShort(t, handle == sakila.XLSX) + tu.SkipShort(t, handle == sakila.XLSX) th := testh.New(t) src := th.Source(handle) @@ -260,7 +260,7 @@ func TestDriver_Open(t *testing.T) { handle := handle t.Run(handle, func(t *testing.T) { - tutil.SkipShort(t, handle == sakila.XLSX) + tu.SkipShort(t, handle == sakila.XLSX) t.Parallel() th := testh.New(t) @@ -746,7 +746,7 @@ func TestMungeColNames(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { ctx := options.NewContext(context.Background(), options.Options{}) got, err := driver.MungeResultColNames(ctx, tc.in) require.NoError(t, err) diff --git a/libsq/driver/sources.go b/libsq/driver/sources.go index 
5f5fe8ba1..eb2190cae 100644 --- a/libsq/driver/sources.go +++ b/libsq/driver/sources.go @@ -9,6 +9,8 @@ import ( "sync" "time" + "github.com/neilotoole/sq/libsq/source/drivertype" + "github.com/nightlyone/lockfile" "github.com/neilotoole/sq/libsq/core/cleanup" @@ -22,11 +24,7 @@ import ( "github.com/neilotoole/sq/libsq/source" ) -var ( - _ PoolOpener = (*Sources)(nil) - _ JoinPoolOpener = (*Sources)(nil) - _ ScratchPoolOpener = (*Sources)(nil) -) +var _ PoolOpener = (*Sources)(nil) // ScratchSrcFunc is a function that returns a scratch source. // The caller is responsible for invoking cleanFn. @@ -77,6 +75,29 @@ func (ss *Sources) Open(ctx context.Context, src *source.Source) (Pool, error) { return ss.doOpen(ctx, src) } +// DriverFor returns the driver for typ. +func (ss *Sources) DriverFor(typ drivertype.Type) (Driver, error) { + return ss.drvrs.DriverFor(typ) +} + +// IsSQLSource returns true if src's driver is a SQLDriver. +func (ss *Sources) IsSQLSource(src *source.Source) bool { + if src == nil { + return false + } + + drvr, err := ss.drvrs.DriverFor(src.Type) + if err != nil { + return false + } + + if _, ok := drvr.(SQLDriver); ok { + return true + } + + return false +} + func (ss *Sources) doOpen(ctx context.Context, src *source.Source) (Pool, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) key := src.Handle + "_" + src.Hash() @@ -109,8 +130,6 @@ func (ss *Sources) doOpen(ctx context.Context, src *source.Source) (Pool, error) // OpenScratch returns a scratch database instance. It is not // necessary for the caller to close the returned Pool as // its Close method will be invoked by d.Close. -// -// OpenScratch implements ScratchPoolOpener. 
func (ss *Sources) OpenScratch(ctx context.Context, src *source.Source) (Pool, error) { const msgCloseScratch = "Close scratch db" @@ -184,6 +203,7 @@ func (ss *Sources) openIngestNoCache(ctx context.Context, src *source.Source, lga.Elapsed, elapsed, lga.Err, err, ) lg.WarnIfCloseError(log, lgm.CloseDB, impl) + return nil, err } ss.log.Debug("Ingest completed", @@ -282,7 +302,7 @@ func (ss *Sources) openIngestCache(ctx context.Context, src *source.Source, // There is no guarantee that these files exist, or are accessible. // It's just the paths. func (ss *Sources) getCachePaths(src *source.Source) (srcCacheDir, cacheDB, checksums string, err error) { - if srcCacheDir, err = source.CacheDirFor(src); err != nil { + if srcCacheDir, err = ss.files.CacheDirFor(src); err != nil { return "", "", "", err } @@ -361,6 +381,7 @@ func (ss *Sources) openCachedFor(ctx context.Context, src *source.Source) (Pool, src.Handle, src.Type) } + // FIXME: Not too sure invoking files.Filepath here is the right approach? srcFilepath, err := ss.files.Filepath(ctx, src) if err != nil { return nil, false, err @@ -416,8 +437,6 @@ func (ss *Sources) openCachedFor(ctx context.Context, src *source.Source) (Pool, // location for the join to occur (to minimize copying of data for // the join etc.). Currently the implementation simply delegates // to OpenScratch. -// -// OpenJoin implements JoinPoolOpener. 
func (ss *Sources) OpenJoin(ctx context.Context, srcs ...*source.Source) (Pool, error) { var names []string for _, src := range srcs { diff --git a/libsq/libsq.go b/libsq/libsq.go index edc1a56c1..d901ee02d 100644 --- a/libsq/libsq.go +++ b/libsq/libsq.go @@ -12,10 +12,9 @@ package libsq import ( "context" - "github.com/neilotoole/sq/libsq/core/lg/lga" - "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/record" "github.com/neilotoole/sq/libsq/core/sqlz" @@ -28,14 +27,8 @@ type QueryContext struct { // Collection is the set of sources. Collection *source.Collection - // PoolOpener is used to open databases. - PoolOpener driver.PoolOpener - - // JoinPoolOpener is used to open the joindb. - JoinPoolOpener driver.JoinPoolOpener - - // ScratchPoolOpener is used to open the scratchdb. - ScratchPoolOpener driver.ScratchPoolOpener + // Sources bridges between source.Source and databases. + Sources *driver.Sources // Args defines variables that are substituted into the query. // May be nil or empty. diff --git a/libsq/libsq_test.go b/libsq/libsq_test.go index 5e7647e9f..508d4b753 100644 --- a/libsq/libsq_test.go +++ b/libsq/libsq_test.go @@ -13,7 +13,7 @@ import ( "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/sakila" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) // TestQuerySQL_Smoke is a smoke test of testh.QuerySQL. 
@@ -60,7 +60,7 @@ func TestQuerySQL_Smoke(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.handle, func(t *testing.T) { - tutil.SkipShort(t, tc.handle == sakila.XLSX) + tu.SkipShort(t, tc.handle == sakila.XLSX) t.Parallel() th := testh.New(t, testh.OptLongOpen()) diff --git a/libsq/pipeline.go b/libsq/pipeline.go index e072b07a4..c27c2ca9a 100644 --- a/libsq/pipeline.go +++ b/libsq/pipeline.go @@ -184,15 +184,19 @@ func (p *pipeline) prepareNoTable(ctx context.Context, qm *queryModel) error { ) if handle == "" { - if src = p.qc.Collection.Active(); src == nil { - log.Debug("No active source, will use scratchdb.") + src = p.qc.Collection.Active() + if src == nil || !p.qc.Sources.IsSQLSource(src) { + log.Debug("No active SQL source, will use scratchdb.") // REVISIT: ScratchPoolOpener needs a source, so we just make one up. ephemeralSrc := &source.Source{ Type: drivertype.None, - Handle: "@scratch" + stringz.Uniq8(), + Handle: "@scratch_" + stringz.Uniq8(), Ephemeral: true, } - p.targetPool, err = p.qc.ScratchPoolOpener.OpenScratch(ctx, ephemeralSrc) + + // FIXME: We really want to change the signature of OpenScratch to + // just need a name, not a source. + p.targetPool, err = p.qc.Sources.OpenScratch(ctx, ephemeralSrc) if err != nil { return err } @@ -211,7 +215,7 @@ func (p *pipeline) prepareNoTable(ctx context.Context, qm *queryModel) error { } // At this point, src is non-nil. 
- if p.targetPool, err = p.qc.PoolOpener.Open(ctx, src); err != nil { + if p.targetPool, err = p.qc.Sources.Open(ctx, src); err != nil { return err } @@ -243,7 +247,7 @@ func (p *pipeline) prepareFromTable(ctx context.Context, tblSel *ast.TblSelector return "", nil, err } - fromPool, err = p.qc.PoolOpener.Open(ctx, src) + fromPool, err = p.qc.Sources.Open(ctx, src) if err != nil { return "", nil, err } @@ -336,7 +340,7 @@ func (p *pipeline) joinSingleSource(ctx context.Context, jc *joinClause) (fromCl return "", nil, err } - fromPool, err = p.qc.PoolOpener.Open(ctx, src) + fromPool, err = p.qc.Sources.Open(ctx, src) if err != nil { return "", nil, err } @@ -374,7 +378,7 @@ func (p *pipeline) joinCrossSource(ctx context.Context, jc *joinClause) (fromCla } // Open the join db - joinPool, err := p.qc.JoinPoolOpener.OpenJoin(ctx, srcs...) + joinPool, err := p.qc.Sources.OpenJoin(ctx, srcs...) if err != nil { return "", nil, err } @@ -401,7 +405,7 @@ func (p *pipeline) joinCrossSource(ctx context.Context, jc *joinClause) (fromCla return "", nil, err } var db driver.Pool - if db, err = p.qc.PoolOpener.Open(ctx, src); err != nil { + if db, err = p.qc.Sources.Open(ctx, src); err != nil { return "", nil, err } diff --git a/libsq/query_expr_test.go b/libsq/query_expr_test.go index 0e256256f..717095d1c 100644 --- a/libsq/query_expr_test.go +++ b/libsq/query_expr_test.go @@ -7,7 +7,7 @@ import ( "github.com/neilotoole/sq/drivers/mysql" "github.com/neilotoole/sq/testh/sakila" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) //nolint:exhaustive @@ -157,7 +157,7 @@ func TestQuery_expr_literal(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.name), func(t *testing.T) { + t.Run(tu.Name(i, tc.name), func(t *testing.T) { execQueryTestCase(t, tc) }) } diff --git a/libsq/query_join_test.go b/libsq/query_join_test.go index f0d8563a1..286e8d00c 100644 --- a/libsq/query_join_test.go +++ b/libsq/query_join_test.go 
@@ -14,7 +14,7 @@ import ( "github.com/neilotoole/sq/libsq/core/jointype" "github.com/neilotoole/sq/libsq/source/drivertype" "github.com/neilotoole/sq/testh/sakila" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestQuery_join_args(t *testing.T) { @@ -384,7 +384,7 @@ func TestQuery_table_alias(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.name), func(t *testing.T) { + t.Run(tu.Name(i, tc.name), func(t *testing.T) { execQueryTestCase(t, tc) }) } diff --git a/libsq/query_no_src_test.go b/libsq/query_no_src_test.go index 03b76ae8e..d41783b52 100644 --- a/libsq/query_no_src_test.go +++ b/libsq/query_no_src_test.go @@ -8,7 +8,7 @@ import ( "github.com/neilotoole/sq/libsq" "github.com/neilotoole/sq/testh" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestQuery_no_source(t *testing.T) { @@ -26,17 +26,15 @@ func TestQuery_no_source(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { t.Logf("\nquery: %s\n want: %s", tc.in, tc.want) th := testh.New(t) coll := th.NewCollection() sources := th.Sources() qc := &libsq.QueryContext{ - Collection: coll, - PoolOpener: sources, - JoinPoolOpener: sources, - ScratchPoolOpener: sources, + Collection: coll, + Sources: sources, } gotSQL, gotErr := libsq.SLQ2SQL(th.Context, qc, tc.in) diff --git a/libsq/query_test.go b/libsq/query_test.go index 4d41ac4aa..593844ad0 100644 --- a/libsq/query_test.go +++ b/libsq/query_test.go @@ -13,7 +13,7 @@ import ( "github.com/neilotoole/sq/libsq/source/drivertype" "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/sakila" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) // driverMap is a map of drivertype.Type to a string. 
@@ -109,7 +109,7 @@ func execQueryTestCase(t *testing.T, tc queryTestCase) { subTests := make([]queryTestCase, len(tc.repeatReplace)) for i := range tc.repeatReplace { subTests[i] = tc - subTests[i].name = tutil.Name(tc.repeatReplace[i]) + subTests[i].name = tu.Name(tc.repeatReplace[i]) if i == 0 { // No need for replacement on first item, it's the original. continue @@ -166,11 +166,9 @@ func doExecQueryTestCase(t *testing.T, tc queryTestCase) { sources := th.Sources() qc := &libsq.QueryContext{ - Collection: coll, - PoolOpener: sources, - JoinPoolOpener: sources, - ScratchPoolOpener: sources, - Args: tc.args, + Collection: coll, + Sources: sources, + Args: tc.args, } if tc.beforeRun != nil { diff --git a/libsq/source/cache.go b/libsq/source/cache.go index 65a77d4c3..36a0f2fbe 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -11,7 +11,7 @@ import ( // CacheDirFor gets the cache dir for handle. It is not guaranteed // that the returned dir exists or is accessible. -func CacheDirFor(src *Source) (dir string, err error) { +func (fs *Files) CacheDirFor(src *Source) (dir string, err error) { handle := src.Handle if err = ValidHandle(handle); err != nil { return "", errz.Wrapf(err, "cache dir: invalid handle: %s", handle) @@ -24,7 +24,7 @@ func CacheDirFor(src *Source) (dir string, err error) { } dir = filepath.Join( - CacheDirPath(), + fs.cacheDir, "sources", filepath.Join(strings.Split(strings.TrimPrefix(handle, "@"), "/")...), src.Hash(), @@ -33,16 +33,16 @@ func CacheDirFor(src *Source) (dir string, err error) { return dir, nil } -// CacheDirPath returns the sq cache dir. This is generally -// in USER_CACHE_DIR/sq/cache, but could also be in TEMP_DIR/sq/cache +// DefaultCacheDir returns the sq cache dir. This is generally +// in USER_CACHE_DIR/*/sq, but could also be in TEMP_DIR/*/sq/cache // or similar. It is not guaranteed that the returned dir exists // or is accessible. 
-func CacheDirPath() (dir string) { +func DefaultCacheDir() (dir string) { var err error if dir, err = os.UserCacheDir(); err != nil { // Some systems may not have a user cache dir, so we fall back // to the system temp dir. - dir = filepath.Join(TempDirPath(), "cache") + dir = filepath.Join(DefaultTempDir(), "cache") return dir } @@ -50,9 +50,8 @@ func CacheDirPath() (dir string) { return dir } -// TempDirPath returns the sq temp dir. This is generally -// in TEMP_DIR/sq. It is not guaranteed that the returned dir exists -// or is accessible. -func TempDirPath() (dir string) { +// DefaultTempDir returns the default sq temp dir. It is not +// guaranteed that the returned dir exists or is accessible. +func DefaultTempDir() (dir string) { return filepath.Join(os.TempDir(), "sq") } diff --git a/libsq/source/detect.go b/libsq/source/detect.go new file mode 100644 index 000000000..d7b85ef36 --- /dev/null +++ b/libsq/source/detect.go @@ -0,0 +1,219 @@ +package source + +import ( + "context" + "io" + "mime" + "time" + + "github.com/h2non/filetype" + "github.com/h2non/filetype/matchers" + "golang.org/x/sync/errgroup" + + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/source/drivertype" +) + +// AddDriverDetectors adds driver type detectors. +func (fs *Files) AddDriverDetectors(detectFns ...DriverDetectFunc) { + fs.detectFns = append(fs.detectFns, detectFns...) +} + +// DetectStdinType detects the type of stdin as previously added +// by AddStdin. An error is returned if AddStdin was not +// first invoked. If the type cannot be detected, TypeNone and +// nil are returned. 
+func (fs *Files) DetectStdinType(ctx context.Context) (drivertype.Type, error) { + if !fs.fcache.Exists(StdinHandle) { + return drivertype.None, errz.New("must invoke Files.AddStdin before invoking DetectStdinType") + } + + typ, ok, err := fs.detectType(ctx, StdinHandle) + if err != nil { + return drivertype.None, err + } + + if !ok { + return drivertype.None, nil + } + + return typ, nil +} + +// DriverType returns the driver type of loc. +func (fs *Files) DriverType(ctx context.Context, loc string) (drivertype.Type, error) { + log := lg.FromContext(ctx).With(lga.Loc, loc) + ploc, err := parseLoc(loc) + if err != nil { + return drivertype.None, err + } + + if ploc.typ != drivertype.None { + return ploc.typ, nil + } + + if ploc.ext != "" { + mtype := mime.TypeByExtension(ploc.ext) + if mtype == "" { + log.Debug("unknown mime type", lga.Type, mtype) + } else { + if typ, ok := typeFromMediaType(mtype); ok { + return typ, nil + } + log.Debug("unknown driver type for media type", lga.Type, mtype) + } + } + + // Fall back to the byte detectors + typ, ok, err := fs.detectType(ctx, loc) + if err != nil { + return drivertype.None, err + } + + if !ok { + return drivertype.None, errz.Errorf("unable to determine driver type: %s", loc) + } + + return typ, nil +} + +func (fs *Files) detectType(ctx context.Context, loc string) (typ drivertype.Type, ok bool, err error) { + if len(fs.detectFns) == 0 { + return drivertype.None, false, nil + } + log := lg.FromContext(ctx).With(lga.Loc, loc) + start := time.Now() + + type result struct { + typ drivertype.Type + score float32 + } + + resultCh := make(chan result, len(fs.detectFns)) + openFn := func(ctx context.Context) (io.ReadCloser, error) { + fs.mu.Lock() + defer fs.mu.Unlock() + + return fs.newReader(ctx, loc) + } + + select { + case <-ctx.Done(): + return drivertype.None, false, ctx.Err() + default: + } + + g, gCtx := errgroup.WithContext(ctx) + + for _, detectFn := range fs.detectFns { + detectFn := detectFn + + g.Go(func() 
error { + select { + case <-gCtx.Done(): + return gCtx.Err() + default: + } + + gTyp, gScore, gErr := detectFn(gCtx, openFn) + if gErr != nil { + return gErr + } + + if gScore > 0 { + resultCh <- result{typ: gTyp, score: gScore} + } + return nil + }) + } + + // REVISIT: We shouldn't have to wait for all goroutines to complete. + // This logic could be refactored to return as soon as a single + // goroutine returns a score >= 1.0 (then cancelling the other detector + // goroutines). + + err = g.Wait() + if err != nil { + log.Error(err.Error()) + return drivertype.None, false, errz.Err(err) + } + close(resultCh) + + var highestScore float32 + for res := range resultCh { + if res.score > highestScore { + highestScore = res.score + typ = res.typ + } + } + + const detectScoreThreshold = 0.5 + if highestScore >= detectScoreThreshold { + log.Debug("Type detected", lga.Type, typ, lga.Elapsed, time.Since(start)) + return typ, true, nil + } + + log.Warn("No type detected", lga.Type, typ, lga.Elapsed, time.Since(start)) + return drivertype.None, false, nil +} + +// DriverDetectFunc interrogates a byte stream to determine +// the source driver type. A score is returned indicating +// the confidence that the driver type has been detected. +// A score <= 0 is failure, a score >= 1 is success; intermediate +// values indicate some level of confidence. +// An error is returned only if an IO problem occurred. +// The implementation gets access to the byte stream by invoking openFn, +// and is responsible for closing any reader it opens. +type DriverDetectFunc func(ctx context.Context, openFn FileOpenFunc) ( + detected drivertype.Type, score float32, err error) + +var _ DriverDetectFunc = DetectMagicNumber + +// DetectMagicNumber is a DriverDetectFunc that uses an external +// pkg (h2non/filetype) to detect the "magic number" from +// the start of files. 
+func DetectMagicNumber(ctx context.Context, openFn FileOpenFunc, +) (detected drivertype.Type, score float32, err error) { + log := lg.FromContext(ctx) + var r io.ReadCloser + r, err = openFn(ctx) + if err != nil { + return drivertype.None, 0, errz.Err(err) + } + defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) + + // We only have to pass the file header = first 261 bytes + head := make([]byte, 261) + _, err = r.Read(head) + if err != nil { + return drivertype.None, 0, errz.Wrapf(err, "failed to read header") + } + + ftype, err := filetype.Match(head) + if err != nil { + if err != nil { + return drivertype.None, 0, errz.Err(err) + } + } + + switch ftype { + default: + return drivertype.None, 0, nil + case matchers.TypeXlsx: + // This doesn't seem to work, because .xlsx files are + // zipped, so the type returns as a zip. Perhaps there's + // something we can do about it, such as first extracting + // the zip, and then reading the inner magic number, but + // the xlsx.DetectXLSX func should catch the type anyway. + return typeXLSX, 1.0, nil + case matchers.TypeXls: + // TODO: our xlsx driver doesn't yet support XLS + return typeXLSX, 1.0, errz.Errorf("Microsoft XLS (%s) not currently supported", ftype) + case matchers.TypeSqlite: + return typeSL3, 1.0, nil + } +} diff --git a/libsq/source/download.go b/libsq/source/download.go new file mode 100644 index 000000000..764d1f127 --- /dev/null +++ b/libsq/source/download.go @@ -0,0 +1,46 @@ +package source + +import ( + "context" + "net/url" + "os" + + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/source/fetcher" +) + +// fetch ensures that loc exists locally as a file. This may +// entail downloading the file via HTTPS etc. +func (fs *Files) fetch(ctx context.Context, loc string) (fpath string, err error) { + // This impl is a vestigial abomination from an early + // experiment. 
+ + var ok bool + if fpath, ok = isFpath(loc); ok { + // loc is already a local file path + return fpath, nil + } + + var u *url.URL + if u, ok = httpURL(loc); !ok { + return "", errz.Errorf("not a valid file location: %s", loc) + } + + var dlFile *os.File + dlFile, err = os.CreateTemp("", "") + if err != nil { + return "", errz.Err(err) + } + + fetchr := &fetcher.Fetcher{} + // TOOD: ultimately should be passing a real context here + err = fetchr.Fetch(ctx, u.String(), dlFile) + if err != nil { + return "", errz.Err(err) + } + + // dlFile is kept open until fs is closed. + fs.clnup.AddC(dlFile) + + return dlFile.Name(), nil +} diff --git a/libsq/source/files.go b/libsq/source/files.go index e3047c67c..f5a903095 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -4,16 +4,12 @@ import ( "context" "io" "log/slog" - "mime" "net/url" "os" + "path/filepath" "sync" "time" - "github.com/h2non/filetype" - "github.com/h2non/filetype/matchers" - "golang.org/x/sync/errgroup" - "github.com/neilotoole/fscache" "github.com/neilotoole/sq/libsq/core/cleanup" @@ -24,8 +20,6 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/progress" - "github.com/neilotoole/sq/libsq/source/drivertype" - "github.com/neilotoole/sq/libsq/source/fetcher" ) // Files is the centralized API for interacting with files. @@ -41,6 +35,8 @@ import ( // if we're reading long-running pipe from stdin). This entire thing // needs to be revisited. Maybe Files even becomes a fs.FS. type Files struct { + cacheDir string + tempDir string log *slog.Logger mu sync.Mutex clnup *cleanup.Cleanup @@ -50,19 +46,35 @@ type Files struct { // stdinLength is a func that returns number of bytes read from stdin. // It is nil if stdin has not been read. The func may block until reading // of stdin has completed. 
+ // + // FIXME: This should probably be a map of location to length func, + // because downloaded files can use this mechanism too. + // See Files.Size. stdinLength func() int64 } // NewFiles returns a new Files instance. -func NewFiles(ctx context.Context) (*Files, error) { - fs := &Files{clnup: cleanup.New(), log: lg.FromContext(ctx)} +func NewFiles(ctx context.Context, tmpDir, cacheDir string) (*Files, error) { + if tmpDir == "" { + return nil, errz.Errorf("tmpDir is empty") + } + if cacheDir == "" { + return nil, errz.Errorf("cacheDir is empty") + } - tmpdir, err := os.MkdirTemp("", "sq_files_fscache_*") - if err != nil { + fs := &Files{ + cacheDir: cacheDir, + tempDir: tmpDir, + clnup: cleanup.New(), + log: lg.FromContext(ctx), + } + + fcacheTmpDir := filepath.Join(cacheDir, "fscache") + if err := ioz.RequireDir(fcacheTmpDir); err != nil { return nil, errz.Err(err) } - fcache, err := fscache.New(tmpdir, os.ModePerm, time.Hour) + fcache, err := fscache.New(fcacheTmpDir, os.ModePerm, time.Hour) if err != nil { return nil, errz.Err(err) } @@ -72,32 +84,6 @@ func NewFiles(ctx context.Context) (*Files, error) { return fs, nil } -// AddDriverDetectors adds driver type detectors. -func (fs *Files) AddDriverDetectors(detectFns ...DriverDetectFunc) { - fs.detectFns = append(fs.detectFns, detectFns...) -} - -// DetectStdinType detects the type of stdin as previously added -// by AddStdin. An error is returned if AddStdin was not -// first invoked. If the type cannot be detected, TypeNone and -// nil are returned. -func (fs *Files) DetectStdinType(ctx context.Context) (drivertype.Type, error) { - if !fs.fcache.Exists(StdinHandle) { - return drivertype.None, errz.New("must invoke Files.AddStdin before invoking DetectStdinType") - } - - typ, ok, err := fs.detectType(ctx, StdinHandle) - if err != nil { - return drivertype.None, err - } - - if !ok { - return drivertype.None, nil - } - - return typ, nil -} - // Size returns the file size of src.Location. 
If the source is being // loaded asynchronously, this function may block until loading completes. func (fs *Files) Size(ctx context.Context, src *Source) (size int64, err error) { @@ -404,42 +390,6 @@ func (fs *Files) openLocation(ctx context.Context, loc string) (*os.File, error) return f, errz.Err(err) } -// fetch ensures that loc exists locally as a file. This may -// entail downloading the file via HTTPS etc. -func (fs *Files) fetch(ctx context.Context, loc string) (fpath string, err error) { - // This impl is a vestigial abomination from an early - // experiment. - - var ok bool - if fpath, ok = isFpath(loc); ok { - // loc is already a local file path - return fpath, nil - } - - var u *url.URL - if u, ok = httpURL(loc); !ok { - return "", errz.Errorf("not a valid file location: %s", loc) - } - - var dlFile *os.File - dlFile, err = os.CreateTemp("", "") - if err != nil { - return "", errz.Err(err) - } - - fetchr := &fetcher.Fetcher{} - // TOOD: ultimately should be passing a real context here - err = fetchr.Fetch(ctx, u.String(), dlFile) - if err != nil { - return "", errz.Err(err) - } - - // dlFile is kept open until fs is closed. - fs.clnup.AddC(dlFile) - - return dlFile.Name(), nil -} - // Close closes any open resources. func (fs *Files) Close() error { fs.log.Debug("Files.Close invoked: executing clean funcs", lga.Count, fs.clnup.Len()) @@ -452,185 +402,10 @@ func (fs *Files) CleanupE(fn func() error) { fs.clnup.AddE(fn) } -// DriverType returns the driver type of loc. 
-func (fs *Files) DriverType(ctx context.Context, loc string) (drivertype.Type, error) { - log := lg.FromContext(ctx).With(lga.Loc, loc) - ploc, err := parseLoc(loc) - if err != nil { - return drivertype.None, err - } - - if ploc.typ != drivertype.None { - return ploc.typ, nil - } - - if ploc.ext != "" { - mtype := mime.TypeByExtension(ploc.ext) - if mtype == "" { - log.Debug("unknown mime type", lga.Type, mtype) - } else { - if typ, ok := typeFromMediaType(mtype); ok { - return typ, nil - } - log.Debug("unknown driver type for media type", lga.Type, mtype) - } - } - - // Fall back to the byte detectors - typ, ok, err := fs.detectType(ctx, loc) - if err != nil { - return drivertype.None, err - } - - if !ok { - return drivertype.None, errz.Errorf("unable to determine driver type: %s", loc) - } - - return typ, nil -} - -func (fs *Files) detectType(ctx context.Context, loc string) (typ drivertype.Type, ok bool, err error) { - if len(fs.detectFns) == 0 { - return drivertype.None, false, nil - } - log := lg.FromContext(ctx).With(lga.Loc, loc) - start := time.Now() - - type result struct { - typ drivertype.Type - score float32 - } - - resultCh := make(chan result, len(fs.detectFns)) - openFn := func(ctx context.Context) (io.ReadCloser, error) { - fs.mu.Lock() - defer fs.mu.Unlock() - - return fs.newReader(ctx, loc) - } - - select { - case <-ctx.Done(): - return drivertype.None, false, ctx.Err() - default: - } - - g, gCtx := errgroup.WithContext(ctx) - - for _, detectFn := range fs.detectFns { - detectFn := detectFn - - g.Go(func() error { - select { - case <-gCtx.Done(): - return gCtx.Err() - default: - } - - gTyp, gScore, gErr := detectFn(gCtx, openFn) - if gErr != nil { - return gErr - } - - if gScore > 0 { - resultCh <- result{typ: gTyp, score: gScore} - } - return nil - }) - } - - // REVISIT: We shouldn't have to wait for all goroutines to complete. 
- // This logic could be refactored to return as soon as a single - // goroutine returns a score >= 1.0 (then cancelling the other detector - // goroutines). - - err = g.Wait() - if err != nil { - log.Error(err.Error()) - return drivertype.None, false, errz.Err(err) - } - close(resultCh) - - var highestScore float32 - for res := range resultCh { - if res.score > highestScore { - highestScore = res.score - typ = res.typ - } - } - - const detectScoreThreshold = 0.5 - if highestScore >= detectScoreThreshold { - log.Debug("Type detected", lga.Type, typ, lga.Elapsed, time.Since(start)) - return typ, true, nil - } - - log.Warn("No type detected", lga.Type, typ, lga.Elapsed, time.Since(start)) - return drivertype.None, false, nil -} - // FileOpenFunc returns a func that opens a ReadCloser. The caller // is responsible for closing the returned ReadCloser. type FileOpenFunc func(ctx context.Context) (io.ReadCloser, error) -// DriverDetectFunc interrogates a byte stream to determine -// the source driver type. A score is returned indicating -// the confidence that the driver type has been detected. -// A score <= 0 is failure, a score >= 1 is success; intermediate -// values indicate some level of confidence. -// An error is returned only if an IO problem occurred. -// The implementation gets access to the byte stream by invoking openFn, -// and is responsible for closing any reader it opens. -type DriverDetectFunc func(ctx context.Context, openFn FileOpenFunc) ( - detected drivertype.Type, score float32, err error) - -var _ DriverDetectFunc = DetectMagicNumber - -// DetectMagicNumber is a DriverDetectFunc that uses an external -// pkg (h2non/filetype) to detect the "magic number" from -// the start of files. 
-func DetectMagicNumber(ctx context.Context, openFn FileOpenFunc, -) (detected drivertype.Type, score float32, err error) { - log := lg.FromContext(ctx) - var r io.ReadCloser - r, err = openFn(ctx) - if err != nil { - return drivertype.None, 0, errz.Err(err) - } - defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) - - // We only have to pass the file header = first 261 bytes - head := make([]byte, 261) - _, err = r.Read(head) - if err != nil { - return drivertype.None, 0, errz.Wrapf(err, "failed to read header") - } - - ftype, err := filetype.Match(head) - if err != nil { - if err != nil { - return drivertype.None, 0, errz.Err(err) - } - } - - switch ftype { - default: - return drivertype.None, 0, nil - case matchers.TypeXlsx: - // This doesn't seem to work, because .xlsx files are - // zipped, so the type returns as a zip. Perhaps there's - // something we can do about it, such as first extracting - // the zip, and then reading the inner magic number, but - // the xlsx.DetectXLSX func should catch the type anyway. - return typeXLSX, 1.0, nil - case matchers.TypeXls: - // TODO: our xlsx driver doesn't yet support XLS - return typeXLSX, 1.0, errz.Errorf("Microsoft XLS (%s) not currently supported", ftype) - case matchers.TypeSqlite: - return typeSL3, 1.0, nil - } -} - // httpURL tests if s is a well-structured HTTP or HTTPS url, and // if so, returns the url and true. 
func httpURL(s string) (u *url.URL, ok bool) { diff --git a/libsq/source/files_test.go b/libsq/source/files_test.go index bc2fd9d2e..eb9990ef7 100644 --- a/libsq/source/files_test.go +++ b/libsq/source/files_test.go @@ -27,7 +27,7 @@ import ( "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/sakila" "github.com/neilotoole/sq/testh/testsrc" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestFiles_Type(t *testing.T) { @@ -54,10 +54,10 @@ func TestFiles_Type(t *testing.T) { for _, tc := range testCases { tc := tc - t.Run(tutil.Name(source.RedactLocation(tc.loc)), func(t *testing.T) { + t.Run(tu.Name(source.RedactLocation(tc.loc)), func(t *testing.T) { ctx := lg.NewContext(context.Background(), slogt.New(t)) - fs, err := source.NewFiles(ctx) + fs, err := source.NewFiles(ctx, tu.TempDir(t), tu.CacheDir(t)) require.NoError(t, err) fs.AddDriverDetectors(testh.DriverDetectors()...) @@ -99,7 +99,7 @@ func TestFiles_DetectType(t *testing.T) { t.Run(filepath.Base(tc.loc), func(t *testing.T) { ctx := lg.NewContext(context.Background(), slogt.New(t)) - fs, err := source.NewFiles(ctx) + fs, err := source.NewFiles(ctx, tu.TempDir(t), tu.CacheDir(t)) require.NoError(t, err) fs.AddDriverDetectors(testh.DriverDetectors()...) 
@@ -159,7 +159,7 @@ func TestFiles_NewReader(t *testing.T) { Location: proj.Abs(fpath), } - fs, err := source.NewFiles(ctx) + fs, err := source.NewFiles(ctx, tu.TempDir(t), tu.CacheDir(t)) require.NoError(t, err) g := &errgroup.Group{} @@ -195,7 +195,7 @@ func TestFiles_Stdin(t *testing.T) { for _, tc := range testCases { tc := tc - t.Run(tutil.Name(tc.fpath), func(t *testing.T) { + t.Run(tu.Name(tc.fpath), func(t *testing.T) { th := testh.New(t) fs := th.Files() diff --git a/libsq/source/handle_test.go b/libsq/source/handle_test.go index ce04e11dc..4bc2c335c 100644 --- a/libsq/source/handle_test.go +++ b/libsq/source/handle_test.go @@ -15,7 +15,7 @@ import ( "github.com/neilotoole/sq/drivers/xlsx" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestIsValidGroup(t *testing.T) { @@ -38,7 +38,7 @@ func TestIsValidGroup(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { gotValid := source.IsValidGroup(tc.in) require.Equal(t, tc.valid, gotValid) }) @@ -78,7 +78,7 @@ func TestValidHandle(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.in), func(t *testing.T) { + t.Run(tu.Name(i, tc.in), func(t *testing.T) { gotErr := source.ValidHandle(tc.in) if tc.wantErr { require.Error(t, gotErr) @@ -202,7 +202,7 @@ func TestSuggestHandle(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.typ, tc.loc), func(t *testing.T) { + t.Run(tu.Name(i, tc.typ, tc.loc), func(t *testing.T) { set := &source.Collection{} for i := range tc.taken { err := set.Add(&source.Source{ diff --git a/libsq/source/internal_test.go b/libsq/source/internal_test.go index 051db7abd..b3a408865 100644 --- a/libsq/source/internal_test.go +++ b/libsq/source/internal_test.go @@ -16,7 +16,7 @@ import ( 
"github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/sakila" "github.com/neilotoole/sq/testh/testsrc" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) // Export for testing. @@ -30,7 +30,7 @@ var ( func TestFiles_Open(t *testing.T) { ctx := lg.NewContext(context.Background(), slogt.New(t)) - fs, err := NewFiles(ctx) + fs, err := NewFiles(ctx, tu.TempDir(t), tu.CacheDir(t)) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, fs.Close()) }) @@ -181,7 +181,7 @@ func TestParseLoc(t *testing.T) { for _, tc := range testCases { tc := tc - t.Run(tutil.Name(1, RedactLocation(tc.loc)), func(t *testing.T) { + t.Run(tu.Name(1, RedactLocation(tc.loc)), func(t *testing.T) { if tc.windows && runtime.GOOS != "windows" { return } @@ -220,7 +220,7 @@ func TestGroupsFilterOnlyDirectChildren(t *testing.T) { for i, tc := range testCases { tc := tc - t.Run(tutil.Name(i, tc.want), func(t *testing.T) { + t.Run(tu.Name(i, tc.want), func(t *testing.T) { got := GroupsFilterOnlyDirectChildren(tc.parent, tc.groups) require.EqualValues(t, tc.want, got) }) diff --git a/libsq/source/location_test.go b/libsq/source/location_test.go index c76fa5ea6..962f5b33a 100644 --- a/libsq/source/location_test.go +++ b/libsq/source/location_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" "github.com/neilotoole/sq/libsq/source" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestIsSQL(t *testing.T) { @@ -73,7 +73,7 @@ func TestLocationWithPassword(t *testing.T) { for _, tc := range testCases { tc := tc - t.Run(tutil.Name(tc.loc), func(t *testing.T) { + t.Run(tu.Name(tc.loc), func(t *testing.T) { t.Parallel() beforeURL, err := url.ParseRequestURI(tc.loc) diff --git a/libsq/source/source_test.go b/libsq/source/source_test.go index 297558eb8..d623a8b69 100644 --- a/libsq/source/source_test.go +++ b/libsq/source/source_test.go @@ -12,7 +12,7 @@ import ( 
"github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/testh/proj" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) const ( @@ -142,7 +142,7 @@ func TestRedactedLocation(t *testing.T) { for _, tc := range testCases { tc := tc - t.Run(tutil.Name(tc.loc), func(t *testing.T) { + t.Run(tu.Name(tc.loc), func(t *testing.T) { src := &source.Source{Location: tc.loc} got := src.RedactedLocation() t.Logf("%s --> %s", src.Location, got) diff --git a/testh/sakila/sakila_test.go b/testh/sakila/sakila_test.go index 63ec266d8..196e0767b 100644 --- a/testh/sakila/sakila_test.go +++ b/testh/sakila/sakila_test.go @@ -7,7 +7,7 @@ import ( "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/sakila" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) // TestSakila_SQL is a sanity check for Sakila SQL test sources. @@ -34,7 +34,7 @@ func TestSakila_SQL(t *testing.T) { //nolint:tparallel // TestSakila_XLSX is a sanity check for Sakila XLSX test sources. func TestSakila_XLSX(t *testing.T) { - tutil.SkipWindows(t, "XLSX fails on windows pipeline (too slow)") + tu.SkipWindows(t, "XLSX fails on windows pipeline (too slow)") handles := []string{sakila.XLSXSubset} // TODO: Add sakila.XLSX to handles when performance is reasonable diff --git a/testh/testh.go b/testh/testh.go index ee75fb758..1bc8deed6 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -55,6 +55,7 @@ import ( "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/sakila" "github.com/neilotoole/sq/testh/testsrc" + "github.com/neilotoole/sq/testh/tu" ) // defaultDBOpenTimeout is the timeout for tests to open (and ping) their DBs. 
@@ -164,7 +165,7 @@ func (h *Helper) init() { h.registry = driver.NewRegistry(log) var err error - h.files, err = source.NewFiles(h.Context) + h.files, err = source.NewFiles(h.Context, tu.TempDir(h.T), tu.CacheDir(h.T)) require.NoError(h.T, err) h.Cleanup.Add(func() { @@ -629,11 +630,9 @@ func (h *Helper) QuerySLQ(query string, args map[string]string) (*RecordSink, er } qc := &libsq.QueryContext{ - Collection: h.coll, - PoolOpener: h.sources, - JoinPoolOpener: h.sources, - ScratchPoolOpener: h.sources, - Args: args, + Collection: h.coll, + Sources: h.sources, + Args: args, } sink := &RecordSink{} diff --git a/testh/testh_test.go b/testh/testh_test.go index 213606839..ff7f7d9a6 100644 --- a/testh/testh_test.go +++ b/testh/testh_test.go @@ -16,7 +16,7 @@ import ( "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/sakila" - "github.com/neilotoole/sq/testh/tutil" + "github.com/neilotoole/sq/testh/tu" ) func TestVal(t *testing.T) { @@ -173,7 +173,7 @@ func TestTName(t *testing.T) { } for _, tc := range testCases { - got := tutil.Name(tc.a...) + got := tu.Name(tc.a...) require.Equal(t, tc.want, got) } } diff --git a/testh/tutil/tutil.go b/testh/tu/tutil.go similarity index 94% rename from testh/tutil/tutil.go rename to testh/tu/tutil.go index a5bee0fbb..46bcf55e5 100644 --- a/testh/tutil/tutil.go +++ b/testh/tu/tutil.go @@ -1,5 +1,5 @@ -// Package tutil contains basic generic test utilities. -package tutil +// Package tu contains basic generic test utilities. 
+package tu import ( "fmt" @@ -26,9 +26,9 @@ import ( // // Examples: // -// tutil.SkipIff(t, a == b) -// tutil.SkipIff(t, a == b, "skipping because a == b") -// tutil.SkipIff(t, a == b, "skipping because a is %v and b is %v", a, b) +// tu.SkipIff(t, a == b) +// tu.SkipIff(t, a == b, "skipping because a == b") +// tu.SkipIff(t, a == b, "skipping because a is %v and b is %v", a, b) func SkipIff(t testing.TB, b bool, format string, args ...any) { if b { if format == "" { @@ -379,3 +379,13 @@ func MustAbsFilepath(elems ...string) string { } return s } + +// TempDir is the standard means for obtaining a temp dir for tests. +func TempDir(t testing.TB) string { + return filepath.Join(t.TempDir(), "sq", "tmp") +} + +// CacheDir is the standard means for obtaining a cache dir for tests. +func CacheDir(t testing.TB) string { + return filepath.Join(t.TempDir(), "sq", "cache") +} diff --git a/testh/tutil/tutil_test.go b/testh/tu/tutil_test.go similarity index 99% rename from testh/tutil/tutil_test.go rename to testh/tu/tutil_test.go index 759c1a9b0..c4f98437e 100644 --- a/testh/tutil/tutil_test.go +++ b/testh/tu/tutil_test.go @@ -1,4 +1,4 @@ -package tutil +package tu import ( "testing" From 4922c1f8174fdee0eef5439fe98cf3d03166fd11 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 22:27:55 -0700 Subject: [PATCH 049/195] All drivers tests now working --- drivers/userdriver/userdriver_test.go | 9 +-------- libsq/driver/{sources.go => grips.go} | 26 ++++++++++++++++++++------ libsq/pipeline.go | 5 ++--- libsq/source/source.go | 14 +------------- testh/testh.go | 10 +++------- 5 files changed, 27 insertions(+), 37 deletions(-) rename libsq/driver/{sources.go => grips.go} (95%) diff --git a/drivers/userdriver/userdriver_test.go b/drivers/userdriver/userdriver_test.go index df2c757a5..2525bac7c 100644 --- a/drivers/userdriver/userdriver_test.go +++ b/drivers/userdriver/userdriver_test.go @@ -3,7 +3,6 @@ package userdriver_test import ( "testing" - 
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/neilotoole/sq/cli/config" @@ -35,13 +34,7 @@ func TestDriver(t *testing.T) { th := testh.New(t, testh.OptLongOpen()) src := th.Source(tc.handle) - drvr := th.DriverFor(src) - err := drvr.Ping(th.Context, src) - require.NoError(t, err) - - pool, err := drvr.Open(th.Context, src) - require.NoError(t, err) - t.Cleanup(func() { assert.NoError(t, pool.Close()) }) + pool := th.Open(src) srcMeta, err := pool.SourceMetadata(th.Context, false) require.NoError(t, err) diff --git a/libsq/driver/sources.go b/libsq/driver/grips.go similarity index 95% rename from libsq/driver/sources.go rename to libsq/driver/grips.go index eb2190cae..44ced2f31 100644 --- a/libsq/driver/sources.go +++ b/libsq/driver/grips.go @@ -98,9 +98,13 @@ func (ss *Sources) IsSQLSource(src *source.Source) bool { return false } +func (ss *Sources) getKey(src *source.Source) string { + return src.Handle + "_" + src.Hash() +} + func (ss *Sources) doOpen(ctx context.Context, src *source.Source) (Pool, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) - key := src.Handle + "_" + src.Hash() + key := ss.getKey(src) pool, ok := ss.pools[key] if ok { @@ -172,16 +176,26 @@ func (ss *Sources) OpenScratch(ctx context.Context, src *source.Source) (Pool, e return backingPool, nil } -// OpenIngest implements driver.ScratchPoolOpener. +// OpenIngest opens a pool for src, using ingestFn to ingest +// the source data if necessary. func (ss *Sources) OpenIngest(ctx context.Context, src *source.Source, allowCache bool, ingestFn func(ctx context.Context, dest Pool) error, ) (Pool, error) { + var pool Pool + var err error + if !allowCache || src.Handle == source.StdinHandle { // We don't currently cache stdin. Probably we never will? 
- return ss.openIngestNoCache(ctx, src, ingestFn) + pool, err = ss.openIngestNoCache(ctx, src, ingestFn) + } else { + pool, err = ss.openIngestCache(ctx, src, ingestFn) + } + + if err != nil { + return nil, err } - return ss.openIngestCache(ctx, src, ingestFn) + return pool, nil } func (ss *Sources) openIngestNoCache(ctx context.Context, src *source.Source, @@ -206,7 +220,7 @@ func (ss *Sources) openIngestNoCache(ctx context.Context, src *source.Source, return nil, err } - ss.log.Debug("Ingest completed", + ss.log.Info("Ingest completed", lga.Src, src, lga.Dest, impl.Source(), lga.Elapsed, elapsed) return impl, nil @@ -280,7 +294,7 @@ func (ss *Sources) openIngestCache(ctx context.Context, src *source.Source, return nil, err } - log.Debug("Ingest completed", lga.Src, src, lga.Dest, impl.Source(), lga.Elapsed, elapsed) + log.Info("Ingest completed", lga.Src, src, lga.Dest, impl.Source(), lga.Elapsed, elapsed) // Write the checksums file. var sum ioz.Checksum diff --git a/libsq/pipeline.go b/libsq/pipeline.go index c27c2ca9a..264eceb90 100644 --- a/libsq/pipeline.go +++ b/libsq/pipeline.go @@ -189,9 +189,8 @@ func (p *pipeline) prepareNoTable(ctx context.Context, qm *queryModel) error { log.Debug("No active SQL source, will use scratchdb.") // REVISIT: ScratchPoolOpener needs a source, so we just make one up. 
ephemeralSrc := &source.Source{ - Type: drivertype.None, - Handle: "@scratch_" + stringz.Uniq8(), - Ephemeral: true, + Type: drivertype.None, + Handle: "@scratch_" + stringz.Uniq8(), } // FIXME: We really want to change the signature of OpenScratch to diff --git a/libsq/source/source.go b/libsq/source/source.go index bd639e3ae..39a73085a 100644 --- a/libsq/source/source.go +++ b/libsq/source/source.go @@ -14,7 +14,6 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/options" - "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/source/drivertype" ) @@ -92,18 +91,10 @@ type Source struct { // Options are additional params, typically empty. Options options.Options `yaml:"options,omitempty" json:"options,omitempty"` - - // Ephemeral is a flag that indicates that the source is ephemeral. This - // value is not persisted to config. It is used by the Source.Hash method, - // resulting in a different hash value for each ephemeral source. - Ephemeral bool } // Hash returns an SHA256 hash of all fields of s. The Source.Options -// field is ignored. If s is nil, the empty string is returned. If -// Source.Ephemeral is true, the hash value will be different for -// each invocation. This is useful for preventing cache collisions -// when using ephemeral sources. +// field is ignored. If s is nil, the empty string is returned. 
func (s *Source) Hash() string { if s == nil { return "" @@ -116,9 +107,6 @@ func (s *Source) Hash() string { buf.WriteString(s.Catalog) buf.WriteString(s.Schema) buf.WriteString(s.Options.Hash()) - if s.Ephemeral { - buf.WriteString(stringz.Uniq32()) - } sum := sha256.Sum256(buf.Bytes()) return fmt.Sprintf("%x", sum) diff --git a/testh/testh.go b/testh/testh.go index 1bc8deed6..764293f20 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -14,6 +14,8 @@ import ( "testing" "time" + "github.com/neilotoole/sq/libsq/core/lg/devlog" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -65,12 +67,7 @@ const defaultDBOpenTimeout = time.Second * 5 func init() { //nolint:gochecknoinits slogt.Default = slogt.Factory(func(w io.Writer) slog.Handler { - h := &slog.HandlerOptions{ - Level: slog.LevelDebug, - AddSource: true, - } - - return slog.NewTextHandler(w, h) + return devlog.NewHandler(w, slog.LevelDebug) }) } @@ -312,7 +309,6 @@ func (h *Helper) Source(handle string) *source.Source { src, err := h.coll.Get(handle) require.NoError(t, err, "source %s was not found in %s", handle, testsrc.PathSrcsConfig) - src.Ephemeral = true if src.Type == sqlite3.Type { // This could be easily generalized for CSV/XLSX etc. 
From 3e7d0741939d9c86fcdf716b5174679f9c479e2b Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 23:10:48 -0700 Subject: [PATCH 050/195] Renamed driver.Pool to driver.Grip --- cli/cmd_inspect.go | 20 ++--- cli/cmd_slq.go | 12 +-- cli/cmd_sql.go | 18 ++--- cli/cmd_tbl.go | 12 +-- cli/complete.go | 8 +- cli/diff/source.go | 4 +- cli/diff/table.go | 4 +- cli/output/adapter_test.go | 4 +- cli/run/run.go | 2 +- drivers/csv/csv.go | 44 +++++------ drivers/csv/ingest.go | 14 ++-- drivers/csv/insert.go | 6 +- drivers/json/ingest.go | 12 +-- drivers/json/ingest_json.go | 4 +- drivers/json/ingest_jsona.go | 14 ++-- drivers/json/ingest_jsonl.go | 4 +- drivers/json/ingest_test.go | 8 +- drivers/json/internal_test.go | 4 +- drivers/json/json.go | 68 ++++++++--------- drivers/mysql/metadata.go | 4 +- drivers/mysql/metadata_test.go | 8 +- drivers/mysql/mysql.go | 48 ++++++------ drivers/postgres/postgres.go | 50 ++++++------- drivers/postgres/postgres_test.go | 10 +-- drivers/sqlite3/metadata_test.go | 8 +- drivers/sqlite3/sqlite3.go | 71 +++++++++--------- drivers/sqlserver/sqlserver.go | 52 ++++++------- drivers/userdriver/userdriver.go | 68 ++++++++--------- drivers/userdriver/userdriver_test.go | 6 +- drivers/userdriver/xmlud/xmlimport.go | 24 +++--- drivers/xlsx/grip.go | 73 ++++++++++++++++++ drivers/xlsx/ingest.go | 27 +++---- drivers/xlsx/pool.go | 73 ------------------ drivers/xlsx/xlsx.go | 12 +-- drivers/xlsx/xlsx_test.go | 4 +- libsq/dbwriter.go | 34 ++++----- libsq/driver/driver.go | 53 ++++++------- libsq/driver/driver_test.go | 30 ++++---- libsq/driver/grips.go | 73 +++++++++--------- libsq/libsq.go | 16 ++-- libsq/pipeline.go | 98 ++++++++++++------------ libsq/prepare.go | 8 +- testh/testh.go | 104 +++++++++++++------------- 43 files changed, 606 insertions(+), 610 deletions(-) create mode 100644 drivers/xlsx/grip.go delete mode 100644 drivers/xlsx/pool.go diff --git a/cli/cmd_inspect.go b/cli/cmd_inspect.go index 2a64d1806..05c543604 100644 --- 
a/cli/cmd_inspect.go +++ b/cli/cmd_inspect.go @@ -125,7 +125,7 @@ func execInspect(cmd *cobra.Command, args []string) error { return err } - pool, err := ru.Sources.Open(ctx, src) + grip, err := ru.Sources.Open(ctx, src) if err != nil { return errz.Wrapf(err, "failed to inspect %s", src.Handle) } @@ -142,7 +142,7 @@ func execInspect(cmd *cobra.Command, args []string) error { } var tblMeta *metadata.Table - tblMeta, err = pool.TableMetadata(ctx, table) + tblMeta, err = grip.TableMetadata(ctx, table) if err != nil { return err } @@ -152,11 +152,11 @@ func execInspect(cmd *cobra.Command, args []string) error { if cmdFlagIsSetTrue(cmd, flag.InspectCatalogs) { var db *sql.DB - if db, err = pool.DB(ctx); err != nil { + if db, err = grip.DB(ctx); err != nil { return err } var catalogs []string - if catalogs, err = pool.SQLDriver().ListCatalogs(ctx, db); err != nil { + if catalogs, err = grip.SQLDriver().ListCatalogs(ctx, db); err != nil { return err } @@ -171,16 +171,16 @@ func execInspect(cmd *cobra.Command, args []string) error { if cmdFlagIsSetTrue(cmd, flag.InspectSchemata) { var db *sql.DB - if db, err = pool.DB(ctx); err != nil { + if db, err = grip.DB(ctx); err != nil { return err } var schemas []*metadata.Schema - if schemas, err = pool.SQLDriver().ListSchemaMetadata(ctx, db); err != nil { + if schemas, err = grip.SQLDriver().ListSchemaMetadata(ctx, db); err != nil { return err } var currentSchema string - if currentSchema, err = pool.SQLDriver().CurrentSchema(ctx, db); err != nil { + if currentSchema, err = grip.SQLDriver().CurrentSchema(ctx, db); err != nil { return err } @@ -189,12 +189,12 @@ func execInspect(cmd *cobra.Command, args []string) error { if cmdFlagIsSetTrue(cmd, flag.InspectDBProps) { var db *sql.DB - if db, err = pool.DB(ctx); err != nil { + if db, err = grip.DB(ctx); err != nil { return err } defer lg.WarnIfCloseError(log, lgm.CloseDB, db) var props map[string]any - if props, err = pool.SQLDriver().DBProperties(ctx, db); err != nil { + if props, 
err = grip.SQLDriver().DBProperties(ctx, db); err != nil { return err } @@ -203,7 +203,7 @@ func execInspect(cmd *cobra.Command, args []string) error { overviewOnly := cmdFlagIsSetTrue(cmd, flag.InspectOverview) - srcMeta, err := pool.SourceMetadata(ctx, overviewOnly) + srcMeta, err := grip.SourceMetadata(ctx, overviewOnly) if err != nil { return errz.Wrapf(err, "failed to read %s source metadata", src.Handle) } diff --git a/cli/cmd_slq.go b/cli/cmd_slq.go index 9c47435e5..e7c7efec1 100644 --- a/cli/cmd_slq.go +++ b/cli/cmd_slq.go @@ -135,7 +135,7 @@ func execSLQInsert(ctx context.Context, ru *run.Run, mArgs map[string]string, ctx, cancelFn := context.WithCancel(ctx) defer cancelFn() - destPool, err := ru.Sources.Open(ctx, destSrc) + destGrip, err := ru.Sources.Open(ctx, destSrc) if err != nil { return err } @@ -147,7 +147,7 @@ func execSLQInsert(ctx context.Context, ru *run.Run, mArgs map[string]string, inserter := libsq.NewDBWriter( "Insert records", - destPool, + destGrip, destTbl, driver.OptTuningRecChanSize.Get(destSrc.Options), libsq.DBWriterCreateTableIfNotExistsHook(destTbl), @@ -204,7 +204,7 @@ func execSLQPrint(ctx context.Context, ru *run.Run, mArgs map[string]string) err // // $ cat something.xlsx | sq @stdin.sheet1 func preprocessUserSLQ(ctx context.Context, ru *run.Run, args []string) (string, error) { - log, reg, pools, coll := lg.FromContext(ctx), ru.DriverRegistry, ru.Sources, ru.Config.Collection + log, reg, grips, coll := lg.FromContext(ctx), ru.DriverRegistry, ru.Sources, ru.Config.Collection activeSrc := coll.Active() if len(args) == 0 { @@ -235,13 +235,13 @@ func preprocessUserSLQ(ctx context.Context, ru *run.Run, args []string) (string, // This isn't a monotable src, so we can't // just select @stdin.data. Instead we'll select // the first table name, as found in the source meta. 
- pool, err := pools.Open(ctx, activeSrc) + grip, err := grips.Open(ctx, activeSrc) if err != nil { return "", err } - defer lg.WarnIfCloseError(log, lgm.CloseDB, pool) + defer lg.WarnIfCloseError(log, lgm.CloseDB, grip) - srcMeta, err := pool.SourceMetadata(ctx, false) + srcMeta, err := grip.SourceMetadata(ctx, false) if err != nil { return "", err } diff --git a/cli/cmd_sql.go b/cli/cmd_sql.go index e2bcf9a0e..a14c58660 100644 --- a/cli/cmd_sql.go +++ b/cli/cmd_sql.go @@ -120,13 +120,13 @@ func execSQL(cmd *cobra.Command, args []string) error { // to the configured writer. func execSQLPrint(ctx context.Context, ru *run.Run, fromSrc *source.Source) error { args := ru.Args - pool, err := ru.Sources.Open(ctx, fromSrc) + grip, err := ru.Sources.Open(ctx, fromSrc) if err != nil { return err } recw := output.NewRecordWriterAdapter(ctx, ru.Writers.Record) - err = libsq.QuerySQL(ctx, pool, nil, recw, args[0]) + err = libsq.QuerySQL(ctx, grip, nil, recw, args[0]) if err != nil { return err } @@ -140,32 +140,32 @@ func execSQLInsert(ctx context.Context, ru *run.Run, fromSrc, destSrc *source.Source, destTbl string, ) error { args := ru.Args - pools := ru.Sources + grips := ru.Sources ctx, cancelFn := context.WithCancel(ctx) defer cancelFn() - fromPool, err := pools.Open(ctx, fromSrc) + fromGrip, err := grips.Open(ctx, fromSrc) if err != nil { return err } - destPool, err := pools.Open(ctx, destSrc) + destGrip, err := grips.Open(ctx, destSrc) if err != nil { return err } - // Note: We don't need to worry about closing fromPool and - // destPool because they are closed by pools.Close, which + // Note: We don't need to worry about closing fromGrip and + // destGrip because they are closed by grips.Close, which // is invoked by ru.Close, and ru is closed further up the // stack. 
inserter := libsq.NewDBWriter( "Insert records", - destPool, + destGrip, destTbl, driver.OptTuningRecChanSize.Get(destSrc.Options), libsq.DBWriterCreateTableIfNotExistsHook(destTbl), ) - err = libsq.QuerySQL(ctx, fromPool, nil, inserter, args[0]) + err = libsq.QuerySQL(ctx, fromGrip, nil, inserter, args[0]) if err != nil { return errz.Wrapf(err, "insert to {%s} failed", source.Target(destSrc, destTbl)) } diff --git a/cli/cmd_tbl.go b/cli/cmd_tbl.go index 52ff32504..c92b3f713 100644 --- a/cli/cmd_tbl.go +++ b/cli/cmd_tbl.go @@ -121,13 +121,13 @@ func execTblCopy(cmd *cobra.Command, args []string) error { return err } - var pool driver.Pool - pool, err = ru.Sources.Open(ctx, tblHandles[0].src) + var grip driver.Grip + grip, err = ru.Sources.Open(ctx, tblHandles[0].src) if err != nil { return err } - db, err := pool.DB(ctx) + db, err := grip.DB(ctx) if err != nil { return err } @@ -254,13 +254,13 @@ func execTblDrop(cmd *cobra.Command, args []string) (err error) { return errz.Errorf("driver type {%s} (%s) doesn't support dropping tables", tblH.src.Type, tblH.src.Handle) } - var pool driver.Pool - if pool, err = ru.Sources.Open(ctx, tblH.src); err != nil { + var grip driver.Grip + if grip, err = ru.Sources.Open(ctx, tblH.src); err != nil { return err } var db *sql.DB - if db, err = pool.DB(ctx); err != nil { + if db, err = grip.DB(ctx); err != nil { return err } diff --git a/cli/complete.go b/cli/complete.go index 72d027283..0959cbb8e 100644 --- a/cli/complete.go +++ b/cli/complete.go @@ -383,13 +383,13 @@ func (c activeSchemaCompleter) complete(cmd *cobra.Command, args []string, toCom ctx, cancelFn := context.WithTimeout(cmd.Context(), OptShellCompletionTimeout.Get(ru.Config.Options)) defer cancelFn() - pool, err := ru.Sources.Open(ctx, src) + grip, err := ru.Sources.Open(ctx, src) if err != nil { lg.Unexpected(log, err) return nil, cobra.ShellCompDirectiveError } - db, err := pool.DB(ctx) + db, err := grip.DB(ctx) if err != nil { lg.Unexpected(log, err) return nil, 
cobra.ShellCompDirectiveError @@ -759,14 +759,14 @@ func getTableNamesForHandle(ctx context.Context, ru *run.Run, handle string) ([] return nil, err } - pool, err := ru.Sources.Open(ctx, src) + grip, err := ru.Sources.Open(ctx, src) if err != nil { return nil, err } // TODO: We shouldn't have to load the full metadata just to get // the table names. driver.SQLDriver should have a method ListTables. - md, err := pool.SourceMetadata(ctx, false) + md, err := grip.SourceMetadata(ctx, false) if err != nil { return nil, err } diff --git a/cli/diff/source.go b/cli/diff/source.go index caeb6e895..baba129e1 100644 --- a/cli/diff/source.go +++ b/cli/diff/source.go @@ -196,11 +196,11 @@ func fetchSourceMeta(ctx context.Context, ru *run.Run, handle string) (*source.S if err != nil { return nil, nil, err } - pool, err := ru.Sources.Open(ctx, src) + grip, err := ru.Sources.Open(ctx, src) if err != nil { return nil, nil, err } - md, err := pool.SourceMetadata(ctx, false) + md, err := grip.SourceMetadata(ctx, false) if err != nil { return nil, nil, err } diff --git a/cli/diff/table.go b/cli/diff/table.go index 11a347086..6e6531c15 100644 --- a/cli/diff/table.go +++ b/cli/diff/table.go @@ -114,11 +114,11 @@ func buildTableStructureDiff(cfg *Config, showRowCounts bool, td1, td2 *tableDat func fetchTableMeta(ctx context.Context, ru *run.Run, src *source.Source, table string) ( *metadata.Table, error, ) { - pool, err := ru.Sources.Open(ctx, src) + grip, err := ru.Sources.Open(ctx, src) if err != nil { return nil, err } - md, err := pool.TableMetadata(ctx, table) + md, err := grip.TableMetadata(ctx, table) if err != nil { if errz.IsErrNotExist(err) { return nil, nil //nolint:nilnil diff --git a/cli/output/adapter_test.go b/cli/output/adapter_test.go index dcee9c3cb..434ffaa14 100644 --- a/cli/output/adapter_test.go +++ b/cli/output/adapter_test.go @@ -50,11 +50,11 @@ func TestRecordWriterAdapter(t *testing.T) { th := testh.New(t) src := th.Source(tc.handle) - pool := th.Open(src) + 
grip := th.Open(src) sink := &testh.RecordSink{} recw := output.NewRecordWriterAdapter(th.Context, sink) - err := libsq.QuerySQL(th.Context, pool, nil, recw, tc.sqlQuery) + err := libsq.QuerySQL(th.Context, grip, nil, recw, tc.sqlQuery) require.NoError(t, err) written, err := recw.Wait() require.NoError(t, err) diff --git a/cli/run/run.go b/cli/run/run.go index eed1fcc7c..e5c76ebbc 100644 --- a/cli/run/run.go +++ b/cli/run/run.go @@ -73,7 +73,7 @@ type Run struct { // Files manages file access. Files *source.Files - // Sources mediates access to db pools. + // Sources mediates access to driver.Grip instances. Sources *driver.Sources // Writers holds the various writer types that diff --git a/drivers/csv/csv.go b/drivers/csv/csv.go index 91b0eab5f..afe6eca04 100644 --- a/drivers/csv/csv.go +++ b/drivers/csv/csv.go @@ -65,12 +65,12 @@ func (d *driveri) DriverMetadata() driver.Metadata { return md } -// Open implements driver.PoolOpener. -func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, error) { +// Open implements driver.GripOpener. 
+func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, error) { log := lg.FromContext(ctx) log.Debug(lgm.OpenSrc, lga.Src, src) - p := &pool{ + g := &grip{ log: d.log, src: src, files: d.files, @@ -78,18 +78,18 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er allowCache := driver.OptIngestCache.Get(options.FromContext(ctx)) - ingestFn := func(ctx context.Context, destPool driver.Pool) error { + ingestFn := func(ctx context.Context, destGrip driver.Grip) error { openFn := d.files.OpenFunc(src) log.Debug("Ingest func invoked", lga.Src, src) - return ingestCSV(ctx, src, openFn, destPool) + return ingestCSV(ctx, src, openFn, destGrip) } var err error - if p.impl, err = d.ingester.OpenIngest(ctx, src, allowCache, ingestFn); err != nil { + if g.impl, err = d.ingester.OpenIngest(ctx, src, allowCache, ingestFn); err != nil { return nil, err } - return p, nil + return g, nil } // Truncate implements driver.Driver. @@ -120,31 +120,31 @@ func (d *driveri) Ping(ctx context.Context, src *source.Source) error { return nil } -// pool implements driver.Pool. -type pool struct { +// grip implements driver.Grip. +type grip struct { log *slog.Logger src *source.Source - impl driver.Pool + impl driver.Grip files *source.Files } -// DB implements driver.Pool. -func (p *pool) DB(ctx context.Context) (*sql.DB, error) { +// DB implements driver.Grip. +func (p *grip) DB(ctx context.Context) (*sql.DB, error) { return p.impl.DB(ctx) } -// SQLDriver implements driver.Pool. -func (p *pool) SQLDriver() driver.SQLDriver { +// SQLDriver implements driver.Grip. +func (p *grip) SQLDriver() driver.SQLDriver { return p.impl.SQLDriver() } -// Source implements driver.Pool. -func (p *pool) Source() *source.Source { +// Source implements driver.Grip. +func (p *grip) Source() *source.Source { return p.src } -// TableMetadata implements driver.Pool. 
-func (p *pool) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { +// TableMetadata implements driver.Grip. +func (p *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { if tblName != source.MonotableName { return nil, errz.Errorf("table name should be %s for CSV/TSV etc., but got: %s", source.MonotableName, tblName) @@ -159,8 +159,8 @@ func (p *pool) TableMetadata(ctx context.Context, tblName string) (*metadata.Tab return srcMeta.Tables[0], nil } -// SourceMetadata implements driver.Pool. -func (p *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { +// SourceMetadata implements driver.Grip. +func (p *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { md, err := p.impl.SourceMetadata(ctx, noSchema) if err != nil { return nil, err @@ -184,8 +184,8 @@ func (p *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou return md, nil } -// Close implements driver.Pool. -func (p *pool) Close() error { +// Close implements driver.Grip. +func (p *grip) Close() error { p.log.Debug(lgm.CloseDB, lga.Handle, p.src.Handle) return errz.Err(p.impl.Close()) diff --git a/drivers/csv/ingest.go b/drivers/csv/ingest.go index d5cd53d0f..b4cde9009 100644 --- a/drivers/csv/ingest.go +++ b/drivers/csv/ingest.go @@ -51,7 +51,7 @@ Possible values are: comma, space, pipe, tab, colon, semi, period.`, ) // ingestCSV loads the src CSV data into scratchDB. 
-func ingestCSV(ctx context.Context, src *source.Source, openFn source.FileOpenFunc, destPool driver.Pool) error { +func ingestCSV(ctx context.Context, src *source.Source, openFn source.FileOpenFunc, destGrip driver.Grip) error { log := lg.FromContext(ctx) startUTC := time.Now().UTC() @@ -107,17 +107,17 @@ func ingestCSV(ctx context.Context, src *source.Source, openFn source.FileOpenFu // And now we need to create the dest table in scratchDB tblDef := createTblDef(source.MonotableName, header, kinds) - db, err := destPool.DB(ctx) + db, err := destGrip.DB(ctx) if err != nil { return err } - err = destPool.SQLDriver().CreateTable(ctx, db, tblDef) + err = destGrip.SQLDriver().CreateTable(ctx, db, tblDef) if err != nil { return errz.Wrap(err, "csv: failed to create dest scratch table") } - recMeta, err := getIngestRecMeta(ctx, destPool, tblDef) + recMeta, err := getIngestRecMeta(ctx, destGrip, tblDef) if err != nil { return err } @@ -128,9 +128,9 @@ func ingestCSV(ctx context.Context, src *source.Source, openFn source.FileOpenFu insertWriter := libsq.NewDBWriter( libsq.MsgIngestRecords, - destPool, + destGrip, tblDef.Name, - driver.OptTuningRecChanSize.Get(destPool.Source().Options), + driver.OptTuningRecChanSize.Get(destGrip.Source().Options), ) err = execInsert(ctx, insertWriter, recMeta, mungers, recs, cr) if err != nil { @@ -145,7 +145,7 @@ func ingestCSV(ctx context.Context, src *source.Source, openFn source.FileOpenFu log.Debug("Inserted rows", lga.Count, inserted, lga.Elapsed, time.Since(startUTC).Round(time.Millisecond), - lga.Target, source.Target(destPool.Source(), tblDef.Name), + lga.Target, source.Target(destGrip.Source(), tblDef.Name), ) return nil } diff --git a/drivers/csv/insert.go b/drivers/csv/insert.go index 412cdba3d..ed12225da 100644 --- a/drivers/csv/insert.go +++ b/drivers/csv/insert.go @@ -118,13 +118,13 @@ func createTblDef(tblName string, colNames []string, kinds []kind.Kind) *sqlmode } // getIngestRecMeta returns record.Meta to use with 
RecordWriter.Open. -func getIngestRecMeta(ctx context.Context, destPool driver.Pool, tblDef *sqlmodel.TableDef) (record.Meta, error) { - db, err := destPool.DB(ctx) +func getIngestRecMeta(ctx context.Context, destGrip driver.Grip, tblDef *sqlmodel.TableDef) (record.Meta, error) { + db, err := destGrip.DB(ctx) if err != nil { return nil, err } - drvr := destPool.SQLDriver() + drvr := destGrip.SQLDriver() colTypes, err := drvr.TableColumnTypes(ctx, db, tblDef.Name, tblDef.ColNames()) if err != nil { diff --git a/drivers/json/ingest.go b/drivers/json/ingest.go index 2afc8e3d6..e798bf4e5 100644 --- a/drivers/json/ingest.go +++ b/drivers/json/ingest.go @@ -26,11 +26,11 @@ import ( // ingestJob describes a single ingest job, where the JSON // at fromSrc is read via openFn and the resulting records -// are written to destPool. +// are written to destGrip. type ingestJob struct { fromSrc *source.Source openFn source.FileOpenFunc - destPool driver.Pool + destGrip driver.Grip // sampleSize is the maximum number of values to // sample to determine the kind of an element. @@ -53,18 +53,18 @@ var ( ) // getRecMeta returns record.Meta to use with RecordWriter.Open. 
-func getRecMeta(ctx context.Context, pool driver.Pool, tblDef *sqlmodel.TableDef) (record.Meta, error) { - db, err := pool.DB(ctx) +func getRecMeta(ctx context.Context, grip driver.Grip, tblDef *sqlmodel.TableDef) (record.Meta, error) { + db, err := grip.DB(ctx) if err != nil { return nil, err } - colTypes, err := pool.SQLDriver().TableColumnTypes(ctx, db, tblDef.Name, tblDef.ColNames()) + colTypes, err := grip.SQLDriver().TableColumnTypes(ctx, db, tblDef.Name, tblDef.ColNames()) if err != nil { return nil, err } - destMeta, _, err := pool.SQLDriver().RecordMeta(ctx, colTypes) + destMeta, _, err := grip.SQLDriver().RecordMeta(ctx, colTypes) if err != nil { return nil, err } diff --git a/drivers/json/ingest_json.go b/drivers/json/ingest_json.go index bca1bb7fc..592e76bfc 100644 --- a/drivers/json/ingest_json.go +++ b/drivers/json/ingest_json.go @@ -141,9 +141,9 @@ func ingestJSON(ctx context.Context, job ingestJob) error { } defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) - drvr := job.destPool.SQLDriver() + drvr := job.destGrip.SQLDriver() - db, err := job.destPool.DB(ctx) + db, err := job.destGrip.DB(ctx) if err != nil { return err } diff --git a/drivers/json/ingest_jsona.go b/drivers/json/ingest_jsona.go index 0c21f3216..28e3a8af3 100644 --- a/drivers/json/ingest_jsona.go +++ b/drivers/json/ingest_jsona.go @@ -119,19 +119,19 @@ func ingestJSONA(ctx context.Context, job ingestJob) error { colNames[i] = stringz.GenerateAlphaColName(i, true) } - // And now we need to create the dest table in destPool + // And now we need to create the dest table in destGrip tblDef := sqlmodel.NewTableDef(source.MonotableName, colNames, colKinds) - db, err := job.destPool.DB(ctx) + db, err := job.destGrip.DB(ctx) if err != nil { return err } - err = job.destPool.SQLDriver().CreateTable(ctx, db, tblDef) + err = job.destGrip.SQLDriver().CreateTable(ctx, db, tblDef) if err != nil { return errz.Wrapf(err, "import %s: failed to create dest scratch table", TypeJSONA) } - recMeta, 
err := getRecMeta(ctx, job.destPool, tblDef) + recMeta, err := getRecMeta(ctx, job.destGrip, tblDef) if err != nil { return err } @@ -144,9 +144,9 @@ func ingestJSONA(ctx context.Context, job ingestJob) error { insertWriter := libsq.NewDBWriter( libsq.MsgIngestRecords, - job.destPool, + job.destGrip, tblDef.Name, - driver.OptTuningRecChanSize.Get(job.destPool.Source().Options), + driver.OptTuningRecChanSize.Get(job.destGrip.Source().Options), ) var cancelFn context.CancelFunc @@ -172,7 +172,7 @@ func ingestJSONA(ctx context.Context, job ingestJob) error { log.Debug("Inserted rows", lga.Count, inserted, - lga.Target, source.Target(job.destPool.Source(), tblDef.Name), + lga.Target, source.Target(job.destGrip.Source(), tblDef.Name), ) return nil } diff --git a/drivers/json/ingest_jsonl.go b/drivers/json/ingest_jsonl.go index 07484e02c..7e8c36a33 100644 --- a/drivers/json/ingest_jsonl.go +++ b/drivers/json/ingest_jsonl.go @@ -92,8 +92,8 @@ func ingestJSONL(ctx context.Context, job ingestJob) error { //nolint:gocognit } defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) - drvr := job.destPool.SQLDriver() - db, err := job.destPool.DB(ctx) + drvr := job.destGrip.SQLDriver() + db, err := job.destGrip.DB(ctx) if err != nil { return err } diff --git a/drivers/json/ingest_test.go b/drivers/json/ingest_test.go index 42348631b..16b18a6b1 100644 --- a/drivers/json/ingest_test.go +++ b/drivers/json/ingest_test.go @@ -85,8 +85,8 @@ func TestImportJSONL_Flat(t *testing.T) { } } - th, src, _, pool, _ := testh.NewWith(t, testsrc.EmptyDB) - job := json.NewImportJob(src, openFn, pool, 0, true) + th, src, _, grip, _ := testh.NewWith(t, testsrc.EmptyDB) + job := json.NewImportJob(src, openFn, grip, 0, true) err := json.ImportJSONL(th.Context, job) if tc.wantErr { @@ -110,8 +110,8 @@ func TestImportJSON_Flat(t *testing.T) { return os.Open("testdata/actor.json") } - th, src, _, pool, _ := testh.NewWith(t, testsrc.EmptyDB) - job := json.NewImportJob(src, openFn, pool, 0, true) + th, 
src, _, grip, _ := testh.NewWith(t, testsrc.EmptyDB) + job := json.NewImportJob(src, openFn, grip, 0, true) err := json.ImportJSON(th.Context, job) require.NoError(t, err) diff --git a/drivers/json/internal_test.go b/drivers/json/internal_test.go index c95a64835..c35e56ac3 100644 --- a/drivers/json/internal_test.go +++ b/drivers/json/internal_test.go @@ -26,7 +26,7 @@ var ( // newImportJob is a constructor for the unexported ingestJob type. // If sampleSize <= 0, a default value is used. -func newImportJob(fromSrc *source.Source, openFn source.FileOpenFunc, destPool driver.Pool, sampleSize int, +func newImportJob(fromSrc *source.Source, openFn source.FileOpenFunc, destGrip driver.Grip, sampleSize int, flatten bool, ) ingestJob { if sampleSize <= 0 { @@ -36,7 +36,7 @@ func newImportJob(fromSrc *source.Source, openFn source.FileOpenFunc, destPool d return ingestJob{ fromSrc: fromSrc, openFn: openFn, - destPool: destPool, + destGrip: destGrip, sampleSize: sampleSize, flatten: flatten, } diff --git a/drivers/json/json.go b/drivers/json/json.go index df5a993a6..240549be1 100644 --- a/drivers/json/json.go +++ b/drivers/json/json.go @@ -90,12 +90,12 @@ func (d *driveri) DriverMetadata() driver.Metadata { return md } -// Open implements driver.PoolOpener. -func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, error) { +// Open implements driver.GripOpener. 
+func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, error) { log := lg.FromContext(ctx) log.Debug(lgm.OpenSrc, lga.Src, src) - p := &pool{ + g := &grip{ log: log, src: src, clnup: cleanup.New(), @@ -104,11 +104,11 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er allowCache := driver.OptIngestCache.Get(options.FromContext(ctx)) - ingestFn := func(ctx context.Context, destPool driver.Pool) error { + ingestFn := func(ctx context.Context, destGrip driver.Grip) error { job := ingestJob{ fromSrc: src, openFn: d.files.OpenFunc(src), - destPool: destPool, + destGrip: destGrip, sampleSize: driver.OptIngestSampleSize.Get(src.Options), flatten: true, } @@ -117,11 +117,11 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er } var err error - if p.impl, err = d.ingester.OpenIngest(ctx, src, allowCache, ingestFn); err != nil { + if g.impl, err = d.ingester.OpenIngest(ctx, src, allowCache, ingestFn); err != nil { return nil, err } - return p, nil + return g, nil } // Truncate implements driver.Driver. @@ -153,38 +153,38 @@ func (d *driveri) Ping(ctx context.Context, src *source.Source) error { return nil } -// pool implements driver.Pool. -type pool struct { +// grip implements driver.Grip. +type grip struct { log *slog.Logger src *source.Source - impl driver.Pool + impl driver.Grip clnup *cleanup.Cleanup files *source.Files } -// DB implements driver.Pool. -func (p *pool) DB(ctx context.Context) (*sql.DB, error) { - return p.impl.DB(ctx) +// DB implements driver.Grip. +func (g *grip) DB(ctx context.Context) (*sql.DB, error) { + return g.impl.DB(ctx) } -// SQLDriver implements driver.Pool. -func (p *pool) SQLDriver() driver.SQLDriver { - return p.impl.SQLDriver() +// SQLDriver implements driver.Grip. +func (g *grip) SQLDriver() driver.SQLDriver { + return g.impl.SQLDriver() } -// Source implements driver.Pool. 
-func (p *pool) Source() *source.Source { - return p.src +// Source implements driver.Grip. +func (g *grip) Source() *source.Source { + return g.src } -// TableMetadata implements driver.Pool. -func (p *pool) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { +// TableMetadata implements driver.Grip. +func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { if tblName != source.MonotableName { return nil, errz.Errorf("table name should be %s for CSV/TSV etc., but got: %s", source.MonotableName, tblName) } - srcMeta, err := p.SourceMetadata(ctx, false) + srcMeta, err := g.SourceMetadata(ctx, false) if err != nil { return nil, err } @@ -193,23 +193,23 @@ func (p *pool) TableMetadata(ctx context.Context, tblName string) (*metadata.Tab return srcMeta.Tables[0], nil } -// SourceMetadata implements driver.Pool. -func (p *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { - md, err := p.impl.SourceMetadata(ctx, noSchema) +// SourceMetadata implements driver.Grip. +func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { + md, err := g.impl.SourceMetadata(ctx, noSchema) if err != nil { return nil, err } - md.Handle = p.src.Handle - md.Location = p.src.Location - md.Driver = p.src.Type + md.Handle = g.src.Handle + md.Location = g.src.Location + md.Driver = g.src.Type - md.Name, err = source.LocationFileName(p.src) + md.Name, err = source.LocationFileName(g.src) if err != nil { return nil, err } - md.Size, err = p.files.Size(ctx, p.src) + md.Size, err = g.files.Size(ctx, g.src) if err != nil { return nil, err } @@ -218,9 +218,9 @@ func (p *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou return md, nil } -// Close implements driver.Pool. -func (p *pool) Close() error { - p.log.Debug(lgm.CloseDB, lga.Handle, p.src.Handle) +// Close implements driver.Grip. 
+func (g *grip) Close() error { + g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) - return errz.Combine(p.impl.Close(), p.clnup.Run()) + return errz.Combine(g.impl.Close(), g.clnup.Run()) } diff --git a/drivers/mysql/metadata.go b/drivers/mysql/metadata.go index 90b73c6e3..fcbc3111a 100644 --- a/drivers/mysql/metadata.go +++ b/drivers/mysql/metadata.go @@ -170,7 +170,7 @@ func getNewRecordFunc(rowMeta record.Meta) driver.NewRecordFunc { } // getTableMetadata gets the metadata for a single table. It is the -// implementation of driver.Pool.Table. +// implementation of driver.Grip.Table. func getTableMetadata(ctx context.Context, db sqlz.DB, tblName string) (*metadata.Table, error) { query := `SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE, TABLE_COMMENT, (DATA_LENGTH + INDEX_LENGTH) AS table_size, (SELECT COUNT(*) FROM ` + "`" + tblName + "`" + `) AS row_count @@ -248,7 +248,7 @@ ORDER BY cols.ordinal_position ASC` return cols, errw(rows.Err()) } -// getSourceMetadata is the implementation of driver.Pool.SourceMetadata. +// getSourceMetadata is the implementation of driver.Grip.SourceMetadata. // // Multiple queries are required to build the SourceMetadata, and this // impl makes use of errgroup to make concurrent queries. 
In the initial diff --git a/drivers/mysql/metadata_test.go b/drivers/mysql/metadata_test.go index 87cd31530..c5bff62ab 100644 --- a/drivers/mysql/metadata_test.go +++ b/drivers/mysql/metadata_test.go @@ -78,8 +78,8 @@ func TestDatabase_SourceMetadata_MySQL(t *testing.T) { t.Run(handle, func(t *testing.T) { t.Parallel() - th, _, _, pool, _ := testh.NewWith(t, handle) - md, err := pool.SourceMetadata(th.Context, false) + th, _, _, grip, _ := testh.NewWith(t, handle) + md, err := grip.SourceMetadata(th.Context, false) require.NoError(t, err) require.Equal(t, "sakila", md.Name) require.Equal(t, handle, md.Handle) @@ -101,8 +101,8 @@ func TestDatabase_TableMetadata(t *testing.T) { t.Run(handle, func(t *testing.T) { t.Parallel() - th, _, _, pool, _ := testh.NewWith(t, handle) - md, err := pool.TableMetadata(th.Context, sakila.TblActor) + th, _, _, grip, _ := testh.NewWith(t, handle) + md, err := grip.TableMetadata(th.Context, sakila.TblActor) require.NoError(t, err) require.Equal(t, sakila.TblActor, md.Name) }) diff --git a/drivers/mysql/mysql.go b/drivers/mysql/mysql.go index 3ddc97c54..462b5f6f9 100644 --- a/drivers/mysql/mysql.go +++ b/drivers/mysql/mysql.go @@ -459,8 +459,8 @@ func (d *driveri) getTableRecordMeta(ctx context.Context, db sqlz.DB, tblName st return destCols, nil } -// Open implements driver.PoolOpener. -func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, error) { +// Open implements driver.GripOpener. 
+func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) db, err := d.doOpen(ctx, src) @@ -472,7 +472,7 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er return nil, err } - return &pool{log: d.log, db: db, src: src, drvr: d}, nil + return &grip{log: d.log, db: db, src: src, drvr: d}, nil } func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, error) { @@ -568,43 +568,43 @@ func (d *driveri) Truncate(ctx context.Context, src *source.Source, tbl string, return beforeCount, errw(tx.Commit()) } -// pool implements driver.Pool. -type pool struct { +// grip implements driver.Grip. +type grip struct { log *slog.Logger db *sql.DB src *source.Source drvr *driveri } -// DB implements driver.Pool. -func (p *pool) DB(context.Context) (*sql.DB, error) { - return p.db, nil +// DB implements driver.Grip. +func (g *grip) DB(context.Context) (*sql.DB, error) { + return g.db, nil } -// SQLDriver implements driver.Pool. -func (p *pool) SQLDriver() driver.SQLDriver { - return p.drvr +// SQLDriver implements driver.Grip. +func (g *grip) SQLDriver() driver.SQLDriver { + return g.drvr } -// Source implements driver.Pool. -func (p *pool) Source() *source.Source { - return p.src +// Source implements driver.Grip. +func (g *grip) Source() *source.Source { + return g.src } -// TableMetadata implements driver.Pool. -func (p *pool) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { - return getTableMetadata(ctx, p.db, tblName) +// TableMetadata implements driver.Grip. +func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { + return getTableMetadata(ctx, g.db, tblName) } -// SourceMetadata implements driver.Pool. 
-func (p *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { - return getSourceMetadata(ctx, p.src, p.db, noSchema) +// SourceMetadata implements driver.Grip. +func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { + return getSourceMetadata(ctx, g.src, g.db, noSchema) } -// Close implements driver.Pool. -func (p *pool) Close() error { - p.log.Debug(lgm.CloseDB, lga.Handle, p.src.Handle) - return errw(p.db.Close()) +// Close implements driver.Grip. +func (g *grip) Close() error { + g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) + return errw(g.db.Close()) } // dsnFromLocation builds the mysql driver DSN from src.Location. diff --git a/drivers/postgres/postgres.go b/drivers/postgres/postgres.go index 1c803c985..39a98a12c 100644 --- a/drivers/postgres/postgres.go +++ b/drivers/postgres/postgres.go @@ -133,8 +133,8 @@ func (d *driveri) Renderer() *render.Renderer { return r } -// Open implements driver.PoolOpener. -func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, error) { +// Open implements driver.GripOpener. +func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) db, err := d.doOpen(ctx, src) @@ -146,7 +146,7 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er return nil, err } - return &pool{log: d.log, db: db, src: src, drvr: d}, nil + return &grip{log: d.log, db: db, src: src, drvr: d}, nil } func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, error) { @@ -741,32 +741,32 @@ func (d *driveri) RecordMeta(ctx context.Context, colTypes []*sql.ColumnType) ( return recMeta, mungeFn, nil } -// pool is the postgres implementation of driver.Pool. -type pool struct { +// grip is the postgres implementation of driver.Grip. 
+type grip struct { log *slog.Logger drvr *driveri db *sql.DB src *source.Source } -// DB implements driver.Pool. -func (p *pool) DB(context.Context) (*sql.DB, error) { - return p.db, nil +// DB implements driver.Grip. +func (g *grip) DB(context.Context) (*sql.DB, error) { + return g.db, nil } -// SQLDriver implements driver.Pool. -func (p *pool) SQLDriver() driver.SQLDriver { - return p.drvr +// SQLDriver implements driver.Grip. +func (g *grip) SQLDriver() driver.SQLDriver { + return g.drvr } -// Source implements driver.Pool. -func (p *pool) Source() *source.Source { - return p.src +// Source implements driver.Grip. +func (g *grip) Source() *source.Source { + return g.src } -// TableMetadata implements driver.Pool. -func (p *pool) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { - db, err := p.DB(ctx) +// TableMetadata implements driver.Grip. +func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { + db, err := g.DB(ctx) if err != nil { return nil, err } @@ -774,20 +774,20 @@ func (p *pool) TableMetadata(ctx context.Context, tblName string) (*metadata.Tab return getTableMetadata(ctx, db, tblName) } -// SourceMetadata implements driver.Pool. -func (p *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { - db, err := p.DB(ctx) +// SourceMetadata implements driver.Grip. +func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { + db, err := g.DB(ctx) if err != nil { return nil, err } - return getSourceMetadata(ctx, p.src, db, noSchema) + return getSourceMetadata(ctx, g.src, db, noSchema) } -// Close implements driver.Pool. -func (p *pool) Close() error { - p.log.Debug(lgm.CloseDB, lga.Handle, p.src.Handle) +// Close implements driver.Grip. 
+func (g *grip) Close() error { + g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) - err := p.db.Close() + err := g.db.Close() if err != nil { return errw(err) } diff --git a/drivers/postgres/postgres_test.go b/drivers/postgres/postgres_test.go index 1d2024fe5..a7118c4f6 100644 --- a/drivers/postgres/postgres_test.go +++ b/drivers/postgres/postgres_test.go @@ -219,12 +219,12 @@ func TestAlternateSchema(t *testing.T) { src2 := src.Clone() src2.Handle += "2" src2.Location += "?search_path=" + schemaName - pool2 := th.Open(src2) - md2, err := pool2.SourceMetadata(ctx, false) + grip2 := th.Open(src2) + md2, err := grip2.SourceMetadata(ctx, false) require.NoError(t, err) require.Equal(t, schemaName, md2.Schema) - tblMeta2, err := pool2.TableMetadata(ctx, tblName) + tblMeta2, err := grip2.TableMetadata(ctx, tblName) require.NoError(t, err) require.Equal(t, int64(wantRowCount), tblMeta2.RowCount) } @@ -276,10 +276,10 @@ func BenchmarkDatabase_SourceMetadata(b *testing.B) { b.Run(handle, func(b *testing.B) { th := testh.New(b) th.Log = lg.Discard() - pool := th.Open(th.Source(handle)) + grip := th.Open(th.Source(handle)) b.ResetTimer() - md, err := pool.SourceMetadata(th.Context, false) + md, err := grip.SourceMetadata(th.Context, false) require.NoError(b, err) require.Equal(b, "sakila", md.Name) }) diff --git a/drivers/sqlite3/metadata_test.go b/drivers/sqlite3/metadata_test.go index b7a827cd1..7fd73364a 100644 --- a/drivers/sqlite3/metadata_test.go +++ b/drivers/sqlite3/metadata_test.go @@ -201,7 +201,7 @@ func TestRecordMetadata(t *testing.T) { t.Run(tc.tbl, func(t *testing.T) { t.Parallel() - th, _, drvr, pool, db := testh.NewWith(t, sakila.SL3) + th, _, drvr, grip, db := testh.NewWith(t, sakila.SL3) query := fmt.Sprintf("SELECT %s FROM %s", strings.Join(tc.colNames, ", "), tc.tbl) rows, err := db.QueryContext(th.Context, query) //nolint:rowserrcheck @@ -232,7 +232,7 @@ func TestRecordMetadata(t *testing.T) { } // Now check our table metadata - gotTblMeta, err := 
pool.TableMetadata(th.Context, tc.tbl) + gotTblMeta, err := grip.TableMetadata(th.Context, tc.tbl) require.NoError(t, err) require.Equal(t, tc.tbl, gotTblMeta.Name) require.Equal(t, tc.rowCount, gotTblMeta.RowCount) @@ -282,12 +282,12 @@ func TestAggregateFuncsQuery(t *testing.T) { func BenchmarkDatabase_SourceMetadata(b *testing.B) { const numTables = 1000 - th, src, drvr, pool, db := testh.NewWith(b, testsrc.MiscDB) + th, src, drvr, grip, db := testh.NewWith(b, testsrc.MiscDB) tblNames := createTypeTestTbls(th, src, numTables, true) b.ResetTimer() for n := 0; n < b.N; n++ { - srcMeta, err := pool.SourceMetadata(th.Context, false) + srcMeta, err := grip.SourceMetadata(th.Context, false) require.NoError(b, err) require.True(b, len(srcMeta.Tables) > len(tblNames)) } diff --git a/drivers/sqlite3/sqlite3.go b/drivers/sqlite3/sqlite3.go index efd169315..9bc0d9116 100644 --- a/drivers/sqlite3/sqlite3.go +++ b/drivers/sqlite3/sqlite3.go @@ -124,8 +124,8 @@ func (d *driveri) DriverMetadata() driver.Metadata { } } -// Open implements driver.PoolOpener. -func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, error) { +// Open implements driver.GripOpener. +func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) db, err := d.doOpen(ctx, src) @@ -137,7 +137,7 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er return nil, err } - return &pool{log: d.log, db: db, src: src, drvr: d}, nil + return &grip{log: d.log, db: db, src: src, drvr: d}, nil } func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, error) { @@ -915,36 +915,37 @@ func (d *driveri) getTableRecordMeta(ctx context.Context, db sqlz.DB, tblName st return destCols, nil } -// pool implements driver.Pool. -type pool struct { +// grip implements driver.Grip. 
+type grip struct { log *slog.Logger db *sql.DB src *source.Source drvr *driveri - // DEBUG: closeMu and closed exist while debugging close behavior + // DEBUG: closeMu and closed exist while debugging close behavior. + // We should be able to get rid of them eventually. closeMu sync.Mutex closed bool } -// DB implements driver.Pool. -func (p *pool) DB(context.Context) (*sql.DB, error) { - return p.db, nil +// DB implements driver.Grip. +func (g *grip) DB(context.Context) (*sql.DB, error) { + return g.db, nil } -// SQLDriver implements driver.Pool. -func (p *pool) SQLDriver() driver.SQLDriver { - return p.drvr +// SQLDriver implements driver.Grip. +func (g *grip) SQLDriver() driver.SQLDriver { + return g.drvr } -// Source implements driver.Pool. -func (p *pool) Source() *source.Source { - return p.src +// Source implements driver.Grip. +func (g *grip) Source() *source.Source { + return g.src } -// TableMetadata implements driver.Pool. -func (p *pool) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { - db, err := p.DB(ctx) +// TableMetadata implements driver.Grip. +func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { + db, err := g.DB(ctx) if err != nil { return nil, err } @@ -952,20 +953,20 @@ func (p *pool) TableMetadata(ctx context.Context, tblName string) (*metadata.Tab return getTableMetadata(ctx, db, tblName) } -// SourceMetadata implements driver.Pool. -func (p *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { +// SourceMetadata implements driver.Grip. 
+func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { // https://stackoverflow.com/questions/9646353/how-to-find-sqlite-database-file-version - md := &metadata.Source{Handle: p.src.Handle, Driver: Type, DBDriver: dbDrvr} + md := &metadata.Source{Handle: g.src.Handle, Driver: Type, DBDriver: dbDrvr} - dsn, err := PathFromLocation(p.src) + dsn, err := PathFromLocation(g.src) if err != nil { return nil, err } const q = "SELECT sqlite_version(), (SELECT name FROM pragma_database_list ORDER BY seq limit 1);" - err = p.db.QueryRowContext(ctx, q).Scan(&md.DBVersion, &md.Schema) + err = g.db.QueryRowContext(ctx, q).Scan(&md.DBVersion, &md.Schema) if err != nil { return nil, errw(err) } @@ -982,9 +983,9 @@ func (p *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou md.FQName = fi.Name() + "." + md.Schema // SQLite doesn't support catalog, but we conventionally set it to "default" md.Catalog = "default" - md.Location = p.src.Location + md.Location = g.src.Location - md.DBProperties, err = getDBProperties(ctx, p.db) + md.DBProperties, err = getDBProperties(ctx, g.db) if err != nil { return nil, err } @@ -993,7 +994,7 @@ func (p *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou return md, nil } - md.Tables, err = getAllTableMetadata(ctx, p.db, md.Schema) + md.Tables, err = getAllTableMetadata(ctx, g.db, md.Schema) if err != nil { return nil, err } @@ -1009,19 +1010,19 @@ func (p *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou return md, nil } -// Close implements driver.Pool. -func (p *pool) Close() error { - p.closeMu.Lock() - defer p.closeMu.Unlock() +// Close implements driver.Grip. 
+func (g *grip) Close() error { + g.closeMu.Lock() + defer g.closeMu.Unlock() - if p.closed { - p.log.Warn("SQLite DB already closed", lga.Src, p.src) + if g.closed { + g.log.Warn("SQLite DB already closed", lga.Src, g.src) return nil } - p.log.Debug(lgm.CloseDB, lga.Handle, p.src.Handle) - err := errw(p.db.Close()) - p.closed = true + g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) + err := errw(g.db.Close()) + g.closed = true return err } diff --git a/drivers/sqlserver/sqlserver.go b/drivers/sqlserver/sqlserver.go index 0503a0109..6e7ac72d9 100644 --- a/drivers/sqlserver/sqlserver.go +++ b/drivers/sqlserver/sqlserver.go @@ -156,8 +156,8 @@ func (d *driveri) Renderer() *render.Renderer { return r } -// Open implements driver.PoolOpener. -func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, error) { +// Open implements driver.GripOpener. +func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) db, err := d.doOpen(ctx, src) @@ -169,7 +169,7 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, er return nil, err } - return &pool{log: d.log, db: db, src: src, drvr: d}, nil + return &grip{log: d.log, db: db, src: src, drvr: d}, nil } func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, error) { @@ -670,58 +670,58 @@ func (d *driveri) getTableColsMeta(ctx context.Context, db sqlz.DB, tblName stri return destCols, nil } -// pool implements driver.Pool. -type pool struct { +// grip implements driver.Grip. +type grip struct { log *slog.Logger drvr *driveri db *sql.DB src *source.Source } -var _ driver.Pool = (*pool)(nil) +var _ driver.Grip = (*grip)(nil) -// DB implements driver.Pool. -func (d *pool) DB(context.Context) (*sql.DB, error) { - return d.db, nil +// DB implements driver.Grip. +func (g *grip) DB(context.Context) (*sql.DB, error) { + return g.db, nil } -// SQLDriver implements driver.Pool. 
-func (d *pool) SQLDriver() driver.SQLDriver { - return d.drvr +// SQLDriver implements driver.Grip. +func (g *grip) SQLDriver() driver.SQLDriver { + return g.drvr } -// Source implements driver.Pool. -func (d *pool) Source() *source.Source { - return d.src +// Source implements driver.Grip. +func (g *grip) Source() *source.Source { + return g.src } -// TableMetadata implements driver.Pool. -func (d *pool) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { +// TableMetadata implements driver.Grip. +func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { const query = `SELECT TABLE_CATALOG, TABLE_SCHEMA, TABLE_TYPE FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = @p1` var catalog, schema, tblType string - err := d.db.QueryRowContext(ctx, query, tblName).Scan(&catalog, &schema, &tblType) + err := g.db.QueryRowContext(ctx, query, tblName).Scan(&catalog, &schema, &tblType) if err != nil { return nil, errw(err) } // TODO: getTableMetadata can cause deadlock in the DB. Needs further investigation. // But a quick hack would be to use retry on a deadlock error. - return getTableMetadata(ctx, d.db, catalog, schema, tblName, tblType) + return getTableMetadata(ctx, g.db, catalog, schema, tblName, tblType) } -// SourceMetadata implements driver.Pool. -func (d *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { - return getSourceMetadata(ctx, d.src, d.db, noSchema) +// SourceMetadata implements driver.Grip. +func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { + return getSourceMetadata(ctx, g.src, g.db, noSchema) } -// Close implements driver.Pool. -func (d *pool) Close() error { - d.log.Debug(lgm.CloseDB, lga.Handle, d.src.Handle) +// Close implements driver.Grip. 
+func (g *grip) Close() error { + g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) - return errw(d.db.Close()) + return errw(g.db.Close()) } // newStmtExecFunc returns a StmtExecFunc that has logic to deal with diff --git a/drivers/userdriver/userdriver.go b/drivers/userdriver/userdriver.go index fef9e6abe..47839cef4 100644 --- a/drivers/userdriver/userdriver.go +++ b/drivers/userdriver/userdriver.go @@ -23,9 +23,9 @@ import ( ) // ImportFunc is a function that can import -// data (as defined in def) to destPool. +// data (as defined in def) to destGrip. type ImportFunc func(ctx context.Context, def *DriverDef, - data io.Reader, destPool driver.Pool) error + data io.Reader, destGrip driver.Grip) error // Provider implements driver.Provider for a DriverDef. type Provider struct { @@ -80,32 +80,32 @@ func (d *driveri) DriverMetadata() driver.Metadata { } } -// Open implements driver.PoolOpener. -func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Pool, error) { +// Open implements driver.GripOpener. +func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, error) { log := lg.FromContext(ctx).With(lga.Src, src) log.Debug(lgm.OpenSrc) - p := &pool{ + g := &grip{ log: d.log, src: src, } allowCache := driver.OptIngestCache.Get(options.FromContext(ctx)) - ingestFn := func(ctx context.Context, destPool driver.Pool) error { + ingestFn := func(ctx context.Context, destGrip driver.Grip) error { r, err := d.files.Open(ctx, src) if err != nil { return err } defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) - return d.ingestFn(ctx, d.def, r, destPool) + return d.ingestFn(ctx, d.def, r, destGrip) } var err error - if p.impl, err = d.ingester.OpenIngest(ctx, src, allowCache, ingestFn); err != nil { + if g.impl, err = d.ingester.OpenIngest(ctx, src, allowCache, ingestFn); err != nil { return nil, err } - return p, nil + return g, nil } // Truncate implements driver.Driver. 
@@ -140,43 +140,43 @@ func (d *driveri) Ping(ctx context.Context, src *source.Source) error { return r.Close() } -// pool implements driver.Pool. -type pool struct { +// grip implements driver.Grip. +type grip struct { log *slog.Logger src *source.Source - impl driver.Pool + impl driver.Grip } -// DB implements driver.Pool. -func (d *pool) DB(ctx context.Context) (*sql.DB, error) { - return d.impl.DB(ctx) +// DB implements driver.Grip. +func (g *grip) DB(ctx context.Context) (*sql.DB, error) { + return g.impl.DB(ctx) } -// SQLDriver implements driver.Pool. -func (d *pool) SQLDriver() driver.SQLDriver { - return d.impl.SQLDriver() +// SQLDriver implements driver.Grip. +func (g *grip) SQLDriver() driver.SQLDriver { + return g.impl.SQLDriver() } -// Source implements driver.Pool. -func (d *pool) Source() *source.Source { - return d.src +// Source implements driver.Grip. +func (g *grip) Source() *source.Source { + return g.src } -// TableMetadata implements driver.Pool. -func (d *pool) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { - return d.impl.TableMetadata(ctx, tblName) +// TableMetadata implements driver.Grip. +func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { + return g.impl.TableMetadata(ctx, tblName) } -// SourceMetadata implements driver.Pool. -func (d *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { - meta, err := d.impl.SourceMetadata(ctx, noSchema) +// SourceMetadata implements driver.Grip. 
+func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { + meta, err := g.impl.SourceMetadata(ctx, noSchema) if err != nil { return nil, err } - meta.Handle = d.src.Handle - meta.Location = d.src.Location - meta.Name, err = source.LocationFileName(d.src) + meta.Handle = g.src.Handle + meta.Location = g.src.Location + meta.Name, err = source.LocationFileName(g.src) if err != nil { return nil, err } @@ -185,9 +185,9 @@ func (d *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou return meta, nil } -// Close implements driver.Pool. -func (d *pool) Close() error { - d.log.Debug(lgm.CloseDB, lga.Handle, d.src.Handle) +// Close implements driver.Grip. +func (g *grip) Close() error { + g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) - return d.impl.Close() + return g.impl.Close() } diff --git a/drivers/userdriver/userdriver_test.go b/drivers/userdriver/userdriver_test.go index 2525bac7c..c04d13e37 100644 --- a/drivers/userdriver/userdriver_test.go +++ b/drivers/userdriver/userdriver_test.go @@ -34,13 +34,13 @@ func TestDriver(t *testing.T) { th := testh.New(t, testh.OptLongOpen()) src := th.Source(tc.handle) - pool := th.Open(src) + grip := th.Open(src) - srcMeta, err := pool.SourceMetadata(th.Context, false) + srcMeta, err := grip.SourceMetadata(th.Context, false) require.NoError(t, err) require.True(t, stringz.InSlice(srcMeta.TableNames(), tc.tbl)) - tblMeta, err := pool.TableMetadata(th.Context, tc.tbl) + tblMeta, err := grip.TableMetadata(th.Context, tc.tbl) require.NoError(t, err) require.Equal(t, tc.tbl, tblMeta.Name) diff --git a/drivers/userdriver/xmlud/xmlimport.go b/drivers/userdriver/xmlud/xmlimport.go index ecd7b699a..5d1cc8888 100644 --- a/drivers/userdriver/xmlud/xmlimport.go +++ b/drivers/userdriver/xmlud/xmlimport.go @@ -27,7 +27,7 @@ import ( const Genre = "xml" // Import implements userdriver.ImportFunc. 
-func Import(ctx context.Context, def *userdriver.DriverDef, data io.Reader, destPool driver.Pool) error { +func Import(ctx context.Context, def *userdriver.DriverDef, data io.Reader, destGrip driver.Grip) error { if def.Genre != Genre { return errz.Errorf("xmlud.Import does not support genre {%s}", def.Genre) } @@ -45,7 +45,7 @@ func Import(ctx context.Context, def *userdriver.DriverDef, data io.Reader, dest msgOnce: map[string]struct{}{}, } - err := im.execImport(ctx, data, destPool) + err := im.execImport(ctx, data, destGrip) err2 := im.clnup.Run() if err != nil { return errz.Wrap(err, "xml import") @@ -59,7 +59,7 @@ type importer struct { log *slog.Logger def *userdriver.DriverDef data io.Reader - destPool driver.Pool + destGrip driver.Grip selStack *selStack rowStack *rowStack tblDefs map[string]*sqlmodel.TableDef @@ -86,8 +86,8 @@ type importer struct { msgOnce map[string]struct{} } -func (im *importer) execImport(ctx context.Context, r io.Reader, destPool driver.Pool) error { //nolint:gocognit - im.data, im.destPool = r, destPool +func (im *importer) execImport(ctx context.Context, r io.Reader, destGrip driver.Grip) error { //nolint:gocognit + im.data, im.destGrip = r, destGrip err := im.createTables(ctx) if err != nil { @@ -429,13 +429,13 @@ func (im *importer) dbInsert(ctx context.Context, row *rowState) error { execInsertFn, ok := im.execInsertFns[cacheKey] if !ok { - db, err := im.destPool.DB(ctx) + db, err := im.destGrip.DB(ctx) if err != nil { return err } // Nothing cached, prepare the insert statement and insert munge func - stmtExecer, err := im.destPool.SQLDriver().PrepareInsertStmt(ctx, db, tblName, colNames, 1) + stmtExecer, err := im.destGrip.SQLDriver().PrepareInsertStmt(ctx, db, tblName, colNames, 1) if err != nil { return err } @@ -469,7 +469,7 @@ func (im *importer) dbInsert(ctx context.Context, row *rowState) error { // dbUpdate updates row's table with row's dirty values, using row's // primary key cols as the args to the WHERE clause. 
func (im *importer) dbUpdate(ctx context.Context, row *rowState) error { - drvr := im.destPool.SQLDriver() + drvr := im.destGrip.SQLDriver() tblName := row.tbl.Name pkColNames := row.tbl.PrimaryKey @@ -506,7 +506,7 @@ func (im *importer) dbUpdate(ctx context.Context, row *rowState) error { cacheKey := "##update_func__" + tblName + "__" + strings.Join(colNames, ",") + whereClause execUpdateFn, ok := im.execUpdateFns[cacheKey] if !ok { - db, err := im.destPool.DB(ctx) + db, err := im.destGrip.DB(ctx) if err != nil { return err } @@ -576,16 +576,16 @@ func (im *importer) createTables(ctx context.Context) error { im.tblDefs[tblDef.Name] = tblDef - db, err := im.destPool.DB(ctx) + db, err := im.destGrip.DB(ctx) if err != nil { return err } - err = im.destPool.SQLDriver().CreateTable(ctx, db, tblDef) + err = im.destGrip.SQLDriver().CreateTable(ctx, db, tblDef) if err != nil { return err } - im.log.Debug("Created table", lga.Target, source.Target(im.destPool.Source(), tblDef.Name)) + im.log.Debug("Created table", lga.Target, source.Target(im.destGrip.Source(), tblDef.Name)) } return nil diff --git a/drivers/xlsx/grip.go b/drivers/xlsx/grip.go new file mode 100644 index 000000000..1df7b0b67 --- /dev/null +++ b/drivers/xlsx/grip.go @@ -0,0 +1,73 @@ +package xlsx + +import ( + "context" + "database/sql" + "log/slog" + + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/driver" + "github.com/neilotoole/sq/libsq/source" + "github.com/neilotoole/sq/libsq/source/metadata" +) + +// grip implements driver.Grip. It implements a deferred ingest +// of the Excel data. +type grip struct { + // REVISIT: do we need grip.log, or can we use lg.FromContext? + log *slog.Logger + + src *source.Source + files *source.Files + dbGrip driver.Grip +} + +// DB implements driver.Grip. 
+func (g *grip) DB(ctx context.Context) (*sql.DB, error) { + return g.dbGrip.DB(ctx) +} + +// SQLDriver implements driver.Grip. +func (g *grip) SQLDriver() driver.SQLDriver { + return g.dbGrip.SQLDriver() +} + +// Source implements driver.Grip. +func (g *grip) Source() *source.Source { + return g.src +} + +// SourceMetadata implements driver.Grip. +func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { + md, err := g.dbGrip.SourceMetadata(ctx, noSchema) + if err != nil { + return nil, err + } + + md.Handle = g.src.Handle + md.Driver = Type + md.Location = g.src.Location + if md.Name, err = source.LocationFileName(g.src); err != nil { + return nil, err + } + md.FQName = md.Name + + if md.Size, err = g.files.Size(ctx, g.src); err != nil { + return nil, err + } + + return md, nil +} + +// TableMetadata implements driver.Grip. +func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { + return g.dbGrip.TableMetadata(ctx, tblName) +} + +// Close implements driver.Grip. +func (g *grip) Close() error { + g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) + + return g.dbGrip.Close() +} diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index 83143c9d9..41487e2bd 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -84,16 +84,17 @@ func (xs *xSheet) loadSampleRows(ctx context.Context, sampleSize int) error { return nil } -// ingestXLSX loads the data in xfile into destPool. +// ingestXLSX loads the data in xfile into destGrip. // If includeSheetNames is non-empty, only the named sheets are ingested. 
-func ingestXLSX(ctx context.Context, src *source.Source, destPool driver.Pool, +func ingestXLSX(ctx context.Context, src *source.Source, destGrip driver.Grip, xfile *excelize.File, includeSheetNames []string, ) error { + // FIXME: delete includeSheetNames log := lg.FromContext(ctx) start := time.Now() log.Debug("Beginning import from XLSX", lga.Src, src, - lga.Target, destPool.Source()) + lga.Target, destGrip.Source()) var sheets []*xSheet if len(includeSheetNames) > 0 { @@ -124,18 +125,18 @@ func ingestXLSX(ctx context.Context, src *source.Source, destPool driver.Pool, } var db *sql.DB - if db, err = destPool.DB(ctx); err != nil { + if db, err = destGrip.DB(ctx); err != nil { return err } - if err = destPool.SQLDriver().CreateTable(ctx, db, sheetTbl.def); err != nil { + if err = destGrip.SQLDriver().CreateTable(ctx, db, sheetTbl.def); err != nil { return err } } log.Debug("Tables created (but not yet populated)", lga.Count, len(sheetTbls), - lga.Target, destPool.Source(), + lga.Target, destGrip.Source(), lga.Elapsed, time.Since(start)) var imported, skipped int @@ -146,7 +147,7 @@ func ingestXLSX(ctx context.Context, src *source.Source, destPool driver.Pool, continue } - if err = ingestSheetToTable(ctx, destPool, sheetTbls[i]); err != nil { + if err = ingestSheetToTable(ctx, destGrip, sheetTbls[i]); err != nil { return err } imported++ @@ -156,7 +157,7 @@ func ingestXLSX(ctx context.Context, src *source.Source, destPool driver.Pool, lga.Count, imported, "skipped", skipped, lga.From, src, - lga.To, destPool.Source(), + lga.To, destGrip.Source(), lga.Elapsed, time.Since(start), ) @@ -164,8 +165,8 @@ func ingestXLSX(ctx context.Context, src *source.Source, destPool driver.Pool, } // ingestSheetToTable imports the sheet data into the appropriate table -// in scratchPool. The scratch table must already exist. -func ingestSheetToTable(ctx context.Context, destPool driver.Pool, sheetTbl *sheetTable) error { +// in destGrip. The scratch table must already exist. 
+func ingestSheetToTable(ctx context.Context, destGrip driver.Grip, sheetTbl *sheetTable) error { var ( log = lg.FromContext(ctx) startTime = time.Now() @@ -175,7 +176,7 @@ func ingestSheetToTable(ctx context.Context, destPool driver.Pool, sheetTbl *she destColKinds = tblDef.ColKinds() ) - db, err := destPool.DB(ctx) + db, err := destGrip.DB(ctx) if err != nil { return err } @@ -186,7 +187,7 @@ func ingestSheetToTable(ctx context.Context, destPool driver.Pool, sheetTbl *she } defer lg.WarnIfCloseError(log, lgm.CloseDB, conn) - drvr := destPool.SQLDriver() + drvr := destGrip.SQLDriver() batchSize := driver.MaxBatchRows(drvr, len(destColKinds)) bi, err := driver.NewBatchInsert( @@ -264,7 +265,7 @@ func ingestSheetToTable(ctx context.Context, destPool driver.Pool, sheetTbl *she log.Debug("Inserted rows from sheet into table", lga.Count, bi.Written(), laSheet, sheet.name, - lga.Target, source.Target(destPool.Source(), tblDef.Name), + lga.Target, source.Target(destGrip.Source(), tblDef.Name), lga.Elapsed, time.Since(startTime)) return nil diff --git a/drivers/xlsx/pool.go b/drivers/xlsx/pool.go deleted file mode 100644 index b3f398089..000000000 --- a/drivers/xlsx/pool.go +++ /dev/null @@ -1,73 +0,0 @@ -package xlsx - -import ( - "context" - "database/sql" - "log/slog" - - "github.com/neilotoole/sq/libsq/core/lg/lga" - "github.com/neilotoole/sq/libsq/core/lg/lgm" - "github.com/neilotoole/sq/libsq/driver" - "github.com/neilotoole/sq/libsq/source" - "github.com/neilotoole/sq/libsq/source/metadata" -) - -// pool implements driver.Pool. It implements a deferred ingest -// of the Excel data. -type pool struct { - // REVISIT: do we need pool.log, or can we use lg.FromContext? - log *slog.Logger - - src *source.Source - files *source.Files - backingPool driver.Pool -} - -// DB implements driver.Pool. -func (p *pool) DB(ctx context.Context) (*sql.DB, error) { - return p.backingPool.DB(ctx) -} - -// SQLDriver implements driver.Pool. 
-func (p *pool) SQLDriver() driver.SQLDriver { - return p.backingPool.SQLDriver() -} - -// Source implements driver.Pool. -func (p *pool) Source() *source.Source { - return p.src -} - -// SourceMetadata implements driver.Pool. -func (p *pool) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { - md, err := p.backingPool.SourceMetadata(ctx, noSchema) - if err != nil { - return nil, err - } - - md.Handle = p.src.Handle - md.Driver = Type - md.Location = p.src.Location - if md.Name, err = source.LocationFileName(p.src); err != nil { - return nil, err - } - md.FQName = md.Name - - if md.Size, err = p.files.Size(ctx, p.src); err != nil { - return nil, err - } - - return md, nil -} - -// TableMetadata implements driver.Pool. -func (p *pool) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { - return p.backingPool.TableMetadata(ctx, tblName) -} - -// Close implements driver.Pool. -func (p *pool) Close() error { - p.log.Debug(lgm.CloseDB, lga.Handle, p.src.Handle) - - return p.backingPool.Close() -} diff --git a/drivers/xlsx/xlsx.go b/drivers/xlsx/xlsx.go index b7512de5d..e8535b6e9 100644 --- a/drivers/xlsx/xlsx.go +++ b/drivers/xlsx/xlsx.go @@ -57,12 +57,12 @@ func (d *Driver) DriverMetadata() driver.Metadata { } } -// Open implements driver.PoolOpener. -func (d *Driver) Open(ctx context.Context, src *source.Source) (driver.Pool, error) { +// Open implements driver.GripOpener. 
+func (d *Driver) Open(ctx context.Context, src *source.Source) (driver.Grip, error) { log := lg.FromContext(ctx).With(lga.Src, src) log.Debug(lgm.OpenSrc, lga.Src, src) - p := &pool{ + p := &grip{ log: log, src: src, files: d.files, @@ -70,7 +70,7 @@ func (d *Driver) Open(ctx context.Context, src *source.Source) (driver.Pool, err allowCache := driver.OptIngestCache.Get(options.FromContext(ctx)) - ingestFn := func(ctx context.Context, destPool driver.Pool) error { + ingestFn := func(ctx context.Context, destGrip driver.Grip) error { log.Debug("Ingest XLSX", lga.Src, p.src) r, err := p.files.Open(ctx, p.src) if err != nil { @@ -85,14 +85,14 @@ func (d *Driver) Open(ctx context.Context, src *source.Source) (driver.Pool, err defer lg.WarnIfCloseError(log, lgm.CloseFileReader, xfile) - if err = ingestXLSX(ctx, p.src, destPool, xfile, nil); err != nil { + if err = ingestXLSX(ctx, p.src, destGrip, xfile, nil); err != nil { return err } return nil } var err error - if p.backingPool, err = d.ingester.OpenIngest(ctx, p.src, allowCache, ingestFn); err != nil { + if p.dbGrip, err = d.ingester.OpenIngest(ctx, p.src, allowCache, ingestFn); err != nil { return nil, err } diff --git a/drivers/xlsx/xlsx_test.go b/drivers/xlsx/xlsx_test.go index 5c0d393c5..50ef6b069 100644 --- a/drivers/xlsx/xlsx_test.go +++ b/drivers/xlsx/xlsx_test.go @@ -162,13 +162,13 @@ func TestOpenFileFormats(t *testing.T) { Location: filepath.Join("testdata", "file_formats", tc.filename), }) - pool, err := th.Sources().Open(th.Context, src) + grip, err := th.Sources().Open(th.Context, src) if tc.wantErr { require.Error(t, err) return } require.NoError(t, err) - db, err := pool.DB(th.Context) + db, err := grip.DB(th.Context) require.NoError(t, err) err = db.PingContext(th.Context) require.NoError(t, err) diff --git a/libsq/dbwriter.go b/libsq/dbwriter.go index 7fe60ed85..3aea29b93 100644 --- a/libsq/dbwriter.go +++ b/libsq/dbwriter.go @@ -26,7 +26,7 @@ type DBWriter struct { msg string wg *sync.WaitGroup 
cancelFn context.CancelFunc - destPool driver.Pool + destGrip driver.Grip destTbl string recordCh chan record.Record bi *driver.BatchInsert @@ -42,17 +42,17 @@ type DBWriter struct { // DBWriterPreWriteHook is a function that is invoked before DBWriter // begins writing. -type DBWriterPreWriteHook func(ctx context.Context, recMeta record.Meta, destPool driver.Pool, tx sqlz.DB) error +type DBWriterPreWriteHook func(ctx context.Context, recMeta record.Meta, destGrip driver.Grip, tx sqlz.DB) error // DBWriterCreateTableIfNotExistsHook returns a hook that // creates destTblName if it does not exist. func DBWriterCreateTableIfNotExistsHook(destTblName string) DBWriterPreWriteHook { - return func(ctx context.Context, recMeta record.Meta, destPool driver.Pool, tx sqlz.DB) error { - db, err := destPool.DB(ctx) + return func(ctx context.Context, recMeta record.Meta, destGrip driver.Grip, tx sqlz.DB) error { + db, err := destGrip.DB(ctx) if err != nil { return err } - tblExists, err := destPool.SQLDriver().TableExists(ctx, db, destTblName) + tblExists, err := destGrip.SQLDriver().TableExists(ctx, db, destTblName) if err != nil { return errz.Err(err) } @@ -65,9 +65,9 @@ func DBWriterCreateTableIfNotExistsHook(destTblName string) DBWriterPreWriteHook destColKinds := recMeta.Kinds() destTblDef := sqlmodel.NewTableDef(destTblName, destColNames, destColKinds) - err = destPool.SQLDriver().CreateTable(ctx, tx, destTblDef) + err = destGrip.SQLDriver().CreateTable(ctx, tx, destTblDef) if err != nil { - return errz.Wrapf(err, "failed to create dest table %s.%s", destPool.Source().Handle, destTblName) + return errz.Wrapf(err, "failed to create dest table %s.%s", destGrip.Source().Handle, destTblName) } return nil @@ -76,14 +76,14 @@ func DBWriterCreateTableIfNotExistsHook(destTblName string) DBWriterPreWriteHook // NewDBWriter returns a new writer than implements RecordWriter. // The writer writes records from recordCh to destTbl -// in destPool. 
The recChSize param controls the size of recordCh +// in destGrip. The recChSize param controls the size of recordCh // returned by the writer's Open method. -func NewDBWriter(msg string, destPool driver.Pool, destTbl string, recChSize int, +func NewDBWriter(msg string, destGrip driver.Grip, destTbl string, recChSize int, preWriteHooks ...DBWriterPreWriteHook, ) *DBWriter { return &DBWriter{ msg: msg, - destPool: destPool, + destGrip: destGrip, destTbl: destTbl, recordCh: make(chan record.Record, recChSize), errCh: make(chan error, 3), @@ -103,7 +103,7 @@ func (w *DBWriter) Open(ctx context.Context, cancelFn context.CancelFunc, recMet ) { w.cancelFn = cancelFn - db, err := w.destPool.DB(ctx) + db, err := w.destGrip.DB(ctx) if err != nil { return nil, nil, err } @@ -111,22 +111,22 @@ func (w *DBWriter) Open(ctx context.Context, cancelFn context.CancelFunc, recMet // REVISIT: tx could potentially be passed to NewDBWriter? tx, err := db.BeginTx(ctx, nil) if err != nil { - return nil, nil, errz.Wrapf(err, "failed to open tx for %s.%s", w.destPool.Source().Handle, w.destTbl) + return nil, nil, errz.Wrapf(err, "failed to open tx for %s.%s", w.destGrip.Source().Handle, w.destTbl) } for _, hook := range w.preWriteHooks { - err = hook(ctx, recMeta, w.destPool, tx) + err = hook(ctx, recMeta, w.destGrip, tx) if err != nil { w.rollback(ctx, tx, err) return nil, nil, err } } - batchSize := driver.MaxBatchRows(w.destPool.SQLDriver(), len(recMeta.Names())) + batchSize := driver.MaxBatchRows(w.destGrip.SQLDriver(), len(recMeta.Names())) w.bi, err = driver.NewBatchInsert( ctx, w.msg, - w.destPool.SQLDriver(), + w.destGrip.SQLDriver(), tx, w.destTbl, recMeta.Names(), @@ -178,7 +178,7 @@ func (w *DBWriter) Open(ctx context.Context, cancelFn context.CancelFunc, recMet w.addErrs(commitErr) } else { lg.FromContext(ctx).Debug("Tx commit success", - lga.Target, source.Target(w.destPool.Source(), w.destTbl)) + lga.Target, source.Target(w.destGrip.Source(), w.destTbl)) } return @@ -233,7 
+233,7 @@ func (w *DBWriter) addErrs(errs ...error) { func (w *DBWriter) rollback(ctx context.Context, tx *sql.Tx, causeErrs ...error) { // Guaranteed to be at least one causeErr lg.FromContext(ctx).Error("failed to insert data: tx will rollback", - lga.Target, w.destPool.Source().Handle+"."+w.destTbl, + lga.Target, w.destGrip.Source().Handle+"."+w.destTbl, lga.Err, causeErrs[0]) rollbackErr := errz.Err(tx.Rollback()) diff --git a/libsq/driver/driver.go b/libsq/driver/driver.go index ce41d7384..14b15e5a0 100644 --- a/libsq/driver/driver.go +++ b/libsq/driver/driver.go @@ -155,26 +155,26 @@ type Provider interface { DriverFor(typ drivertype.Type) (Driver, error) } -// PoolOpener opens a Pool. -type PoolOpener interface { - // Open returns a Pool instance for src. - Open(ctx context.Context, src *source.Source) (Pool, error) +// GripOpener opens a Grip. +type GripOpener interface { + // Open returns a Grip instance for src. + Open(ctx context.Context, src *source.Source) (Grip, error) } -// IngestOpener opens a pool for ingest use. +// IngestOpener opens a Grip for ingest use. type IngestOpener interface { - // OpenIngest opens a pool for src by executing ingestFn, which is + // OpenIngest opens a Grip for src by executing ingestFn, which is // responsible for ingesting data into dest. If allowCache is false, // ingest always occurs; if true, the cache is consulted first (and // ingestFn may not be invoked). OpenIngest(ctx context.Context, src *source.Source, allowCache bool, - ingestFn func(ctx context.Context, dest Pool) error) (Pool, error) + ingestFn func(ctx context.Context, dest Grip) error) (Grip, error) } // Driver is the core interface that must be implemented for each type // of data source. type Driver interface { - PoolOpener + GripOpener // DriverMetadata returns driver metadata. 
DriverMetadata() Metadata @@ -328,40 +328,35 @@ type SQLDriver interface { DBProperties(ctx context.Context, db sqlz.DB) (map[string]any, error) } -// Pool models a database handle representing a pool of underlying -// connections. It is conceptually equivalent to -// stdlib sql.DB, and in fact encapsulates a sql.DB instance. The -// realized sql.DB instance can be accessed via the DB method. +// Grip is the link between a source and its database connection. +// Why is it named Grip? TLDR: all the other names were taken, +// including Handle, Conn, DB, Source, etc. // -// REVISIT: Rename Pool to Grip or some such? -type Pool interface { - // DB returns the sql.DB object for this Pool. - // This operation can take a long time if opening the DB requires - // an ingest of data. - // For example, with file-based sources such as XLSX, invoking Open - // will ultimately read and import all CSV rows from the file. - // Thus, set a timeout on ctx as appropriate for the source. +// Grip is conceptually equivalent to stdlib sql.DB, and in fact +// encapsulates a sql.DB instance. The realized sql.DB instance can be +// accessed via the DB method. +type Grip interface { + // DB returns the sql.DB object for this Grip. + // This operation may take a long time if opening the DB requires + // an ingest of data (but note that when an ingest step occurs is + // driver-dependent). DB(ctx context.Context) (*sql.DB, error) // SQLDriver returns the underlying database driver. The type of the SQLDriver // may be different from the driver type reported by the Source. SQLDriver() SQLDriver - // Source returns the data source for which this connection was opened. + // FIXME: Add a method: SourceDriver() Driver. + + // Source returns the source for which this Grip was opened. Source() *source.Source - // SourceMetadata returns metadata about the data source. + // SourceMetadata returns metadata about the Grip. 
// If noSchema is true, schema details are not populated // on the returned metadata.Source. - // - // TODO: SourceMetadata doesn't really belong on driver.Pool? It - // should be moved to driver.Driver? SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) - // TableMetadata returns metadata for the specified table in the data source. - // - // TODO: TableMetadata doesn't really belong on driver.Pool? It - // should be moved to driver.Driver? + // TableMetadata returns metadata for the specified table in the Grip. TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) // Close is invoked to close and release any underlying resources. diff --git a/libsq/driver/driver_test.go b/libsq/driver/driver_test.go index bccf7db81..26a25c17e 100644 --- a/libsq/driver/driver_test.go +++ b/libsq/driver/driver_test.go @@ -266,12 +266,12 @@ func TestDriver_Open(t *testing.T) { th := testh.New(t) src := th.Source(handle) drvr := th.DriverFor(src) - pool, err := drvr.Open(th.Context, src) + grip, err := drvr.Open(th.Context, src) require.NoError(t, err) - db, err := pool.DB(th.Context) + db, err := grip.DB(th.Context) require.NoError(t, err) require.NoError(t, db.PingContext(th.Context)) - require.NoError(t, pool.Close()) + require.NoError(t, grip.Close()) }) } } @@ -443,9 +443,9 @@ func TestDatabase_TableMetadata(t *testing.T) { //nolint:tparallel t.Run(handle, func(t *testing.T) { t.Parallel() - th, _, _, pool, _ := testh.NewWith(t, handle) + th, _, _, grip, _ := testh.NewWith(t, handle) - tblMeta, err := pool.TableMetadata(th.Context, sakila.TblActor) + tblMeta, err := grip.TableMetadata(th.Context, sakila.TblActor) require.NoError(t, err) require.Equal(t, sakila.TblActor, tblMeta.Name) require.Equal(t, int64(sakila.TblActorCount), tblMeta.RowCount) @@ -462,9 +462,9 @@ func TestDatabase_SourceMetadata(t *testing.T) { t.Run(handle, func(t *testing.T) { t.Parallel() - th, _, _, pool, _ := testh.NewWith(t, handle) + th, _, _, grip, _ := 
testh.NewWith(t, handle) - md, err := pool.SourceMetadata(th.Context, false) + md, err := grip.SourceMetadata(th.Context, false) require.NoError(t, err) require.Equal(t, sakila.TblActor, md.Tables[0].Name) require.Equal(t, int64(sakila.TblActorCount), md.Tables[0].RowCount) @@ -484,11 +484,11 @@ func TestDatabase_SourceMetadata_concurrent(t *testing.T) { //nolint:tparallel t.Run(handle, func(t *testing.T) { t.Parallel() - th, _, _, pool, _ := testh.NewWith(t, handle) + th, _, _, grip, _ := testh.NewWith(t, handle) g, gCtx := errgroup.WithContext(th.Context) for i := 0; i < concurrency; i++ { g.Go(func() error { - md, err := pool.SourceMetadata(gCtx, false) + md, err := grip.SourceMetadata(gCtx, false) require.NoError(t, err) require.NotNil(t, md) gotTbl := md.Table(sakila.TblActor) @@ -541,7 +541,7 @@ func TestSQLDriver_AlterTableRename(t *testing.T) { handle := handle t.Run(handle, func(t *testing.T) { - th, src, drvr, pool, db := testh.NewWith(t, handle) + th, src, drvr, grip, db := testh.NewWith(t, handle) // Make a copy of the table to play with tbl := th.CopyTable(true, src, tablefq.From(sakila.TblActor), tablefq.T{}, true) @@ -552,7 +552,7 @@ func TestSQLDriver_AlterTableRename(t *testing.T) { require.NoError(t, err) defer th.DropTable(src, tablefq.From(newName)) - md, err := pool.TableMetadata(th.Context, newName) + md, err := grip.TableMetadata(th.Context, newName) require.NoError(t, err) require.Equal(t, newName, md.Name) sink, err := th.QuerySQL(src, nil, "SELECT * FROM "+newName) @@ -569,7 +569,7 @@ func TestSQLDriver_AlterTableRenameColumn(t *testing.T) { handle := handle t.Run(handle, func(t *testing.T) { - th, src, drvr, pool, db := testh.NewWith(t, handle) + th, src, drvr, grip, db := testh.NewWith(t, handle) // Make a copy of the table to play with tbl := th.CopyTable(true, src, tablefq.From(sakila.TblActor), tablefq.T{}, true) @@ -578,7 +578,7 @@ func TestSQLDriver_AlterTableRenameColumn(t *testing.T) { err := 
drvr.AlterTableRenameColumn(th.Context, db, tbl, "first_name", newName) require.NoError(t, err) - md, err := pool.TableMetadata(th.Context, tbl) + md, err := grip.TableMetadata(th.Context, tbl) require.NoError(t, err) require.NotNil(t, md.Column(newName)) sink, err := th.QuerySQL(src, nil, fmt.Sprintf("SELECT %s FROM %s", newName, tbl)) @@ -626,13 +626,13 @@ func TestSQLDriver_CurrentSchemaCatalog(t *testing.T) { tc := tc t.Run(tc.handle, func(t *testing.T) { - th, _, drvr, pool, db := testh.NewWith(t, tc.handle) + th, _, drvr, grip, db := testh.NewWith(t, tc.handle) gotSchema, err := drvr.CurrentSchema(th.Context, db) require.NoError(t, err) require.Equal(t, tc.wantSchema, gotSchema) - md, err := pool.SourceMetadata(th.Context, false) + md, err := grip.SourceMetadata(th.Context, false) require.NoError(t, err) require.NotNil(t, md) require.Equal(t, md.Schema, tc.wantSchema) diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 44ced2f31..74106a5b7 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -24,13 +24,13 @@ import ( "github.com/neilotoole/sq/libsq/source" ) -var _ PoolOpener = (*Sources)(nil) +var _ GripOpener = (*Sources)(nil) // ScratchSrcFunc is a function that returns a scratch source. // The caller is responsible for invoking cleanFn. type ScratchSrcFunc func(ctx context.Context, name string) (src *source.Source, cleanFn func() error, err error) -// Sources provides a mechanism for getting Pool instances. +// Sources provides a mechanism for getting Grip instances. // Note that at this time instances returned by Open are cached // and then closed by Close. This may be a bad approach. 
type Sources struct { @@ -39,7 +39,7 @@ type Sources struct { mu sync.Mutex scratchSrcFn ScratchSrcFunc files *source.Files - pools map[string]Pool + grips map[string]Grip clnup *cleanup.Cleanup } @@ -53,22 +53,22 @@ func NewSources(log *slog.Logger, drvrs Provider, mu: sync.Mutex{}, scratchSrcFn: scratchSrcFn, files: files, - pools: map[string]Pool{}, + grips: map[string]Grip{}, clnup: cleanup.New(), } } -// Open returns an opened Pool for src. The returned Pool +// Open returns an opened Grip for src. The returned Grip // may be cached and returned on future invocations for the // same source (where each source fields is identical). // Thus, the caller should typically not close -// the Pool: it will be closed via d.Close. +// the Grip: it will be closed via d.Close. // // NOTE: This entire logic re caching/not-closing is a bit sketchy, // and needs to be revisited. // -// Open implements PoolOpener. -func (ss *Sources) Open(ctx context.Context, src *source.Source) (Pool, error) { +// Open implements GripOpener. 
+func (ss *Sources) Open(ctx context.Context, src *source.Source) (Grip, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) ss.mu.Lock() defer ss.mu.Unlock() @@ -102,13 +102,13 @@ func (ss *Sources) getKey(src *source.Source) string { return src.Handle + "_" + src.Hash() } -func (ss *Sources) doOpen(ctx context.Context, src *source.Source) (Pool, error) { +func (ss *Sources) doOpen(ctx context.Context, src *source.Source) (Grip, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) key := ss.getKey(src) - pool, ok := ss.pools[key] + grip, ok := ss.grips[key] if ok { - return pool, nil + return grip, nil } drvr, err := ss.drvrs.DriverFor(src.Type) @@ -120,21 +120,21 @@ func (ss *Sources) doOpen(ctx context.Context, src *source.Source) (Pool, error) o := options.Merge(baseOptions, src.Options) ctx = options.NewContext(ctx, o) - pool, err = drvr.Open(ctx, src) + grip, err = drvr.Open(ctx, src) if err != nil { return nil, err } - ss.clnup.AddC(pool) + ss.clnup.AddC(grip) - ss.pools[key] = pool - return pool, nil + ss.grips[key] = grip + return grip, nil } // OpenScratch returns a scratch database instance. It is not -// necessary for the caller to close the returned Pool as +// necessary for the caller to close the returned Grip as // its Close method will be invoked by d.Close. 
-func (ss *Sources) OpenScratch(ctx context.Context, src *source.Source) (Pool, error) { +func (ss *Sources) OpenScratch(ctx context.Context, src *source.Source) (Grip, error) { const msgCloseScratch = "Close scratch db" cacheDir, srcCacheDBFilepath, _, err := ss.getCachePaths(src) @@ -159,8 +159,8 @@ func (ss *Sources) OpenScratch(ctx context.Context, src *source.Source) (Pool, e return nil, err } - var backingPool Pool - backingPool, err = backingDrvr.Open(ctx, scratchSrc) + var backingGrip Grip + backingGrip, err = backingDrvr.Open(ctx, scratchSrc) if err != nil { lg.WarnIfFuncError(ss.log, msgCloseScratch, cleanFn) return nil, err @@ -173,34 +173,33 @@ func (ss *Sources) OpenScratch(ctx context.Context, src *source.Source) (Pool, e ss.clnup.AddE(cleanFn) } - return backingPool, nil + return backingGrip, nil } -// OpenIngest opens a pool for src, using ingestFn to ingest -// the source data if necessary. +// OpenIngest implements driver.IngestOpener. func (ss *Sources) OpenIngest(ctx context.Context, src *source.Source, allowCache bool, - ingestFn func(ctx context.Context, dest Pool) error, -) (Pool, error) { - var pool Pool + ingestFn func(ctx context.Context, dest Grip) error, +) (Grip, error) { + var grip Grip var err error if !allowCache || src.Handle == source.StdinHandle { // We don't currently cache stdin. Probably we never will? 
- pool, err = ss.openIngestNoCache(ctx, src, ingestFn) + grip, err = ss.openIngestNoCache(ctx, src, ingestFn) } else { - pool, err = ss.openIngestCache(ctx, src, ingestFn) + grip, err = ss.openIngestCache(ctx, src, ingestFn) } if err != nil { return nil, err } - return pool, nil + return grip, nil } func (ss *Sources) openIngestNoCache(ctx context.Context, src *source.Source, - ingestFn func(ctx context.Context, destPool Pool) error, -) (Pool, error) { + ingestFn func(ctx context.Context, destGrip Grip) error, +) (Grip, error) { log := lg.FromContext(ctx) impl, err := ss.OpenScratch(ctx, src) if err != nil { @@ -227,8 +226,8 @@ func (ss *Sources) openIngestNoCache(ctx context.Context, src *source.Source, } func (ss *Sources) openIngestCache(ctx context.Context, src *source.Source, - ingestFn func(ctx context.Context, destPool Pool) error, -) (Pool, error) { + ingestFn func(ctx context.Context, destGrip Grip) error, +) (Grip, error) { log := lg.FromContext(ctx) lock, err := ss.acquireLock(ctx, src) @@ -261,7 +260,7 @@ func (ss *Sources) openIngestCache(ctx context.Context, src *source.Source, } var ( - impl Pool + impl Grip foundCached bool ) if impl, foundCached, err = ss.openCachedFor(ctx, src); err != nil { @@ -370,7 +369,7 @@ func (ss *Sources) getLockfileFor(src *source.Source) (lockfile.Lockfile, error) return lockfile.New(lockPath) } -func (ss *Sources) openCachedFor(ctx context.Context, src *source.Source) (Pool, bool, error) { +func (ss *Sources) openCachedFor(ctx context.Context, src *source.Source) (Grip, bool, error) { _, cacheDBPath, checksumsPath, err := ss.getCachePaths(src) if err != nil { return nil, false, err @@ -434,12 +433,12 @@ func (ss *Sources) openCachedFor(ctx context.Context, src *source.Source) (Pool, Type: backingType, } - backingPool, err := ss.doOpen(ctx, backingSrc) + backingGrip, err := ss.doOpen(ctx, backingSrc) if err != nil { return nil, false, errz.Wrapf(err, "open cached DB for source %s", src.Handle) } - return backingPool, 
true, nil + return backingGrip, true, nil } // OpenJoin opens an appropriate database for use as @@ -451,7 +450,7 @@ func (ss *Sources) openCachedFor(ctx context.Context, src *source.Source) (Pool, // location for the join to occur (to minimize copying of data for // the join etc.). Currently the implementation simply delegates // to OpenScratch. -func (ss *Sources) OpenJoin(ctx context.Context, srcs ...*source.Source) (Pool, error) { +func (ss *Sources) OpenJoin(ctx context.Context, srcs ...*source.Source) (Grip, error) { var names []string for _, src := range srcs { names = append(names, src.Handle[1:]) diff --git a/libsq/libsq.go b/libsq/libsq.go index d901ee02d..0e233cf69 100644 --- a/libsq/libsq.go +++ b/libsq/libsq.go @@ -122,26 +122,26 @@ func SLQ2SQL(ctx context.Context, qc *QueryContext, query string) (targetSQL str // QuerySQL executes the SQL query, writing the results to recw. If db is // non-nil, the query is executed against it. Otherwise, the connection is -// obtained from pool. +// obtained from grip. // Note that QuerySQL may return before recw has finished writing, thus the // caller may wish to wait for recw to complete. -// The caller is responsible for closing pool (and db, if non-nil). -func QuerySQL(ctx context.Context, pool driver.Pool, db sqlz.DB, //nolint:funlen +// The caller is responsible for closing grip (and db, if non-nil). +func QuerySQL(ctx context.Context, grip driver.Grip, db sqlz.DB, //nolint:funlen recw RecordWriter, query string, args ...any, ) error { log := lg.FromContext(ctx) - errw := pool.SQLDriver().ErrWrapFunc() + errw := grip.SQLDriver().ErrWrapFunc() if db == nil { var err error - if db, err = pool.DB(ctx); err != nil { + if db, err = grip.DB(ctx); err != nil { return err } } rows, err := db.QueryContext(ctx, query, args...) 
if err != nil { - err = errz.Wrapf(errw(err), `SQL query against %s failed: %s`, pool.Source().Handle, query) + err = errz.Wrapf(errw(err), `SQL query against %s failed: %s`, grip.Source().Handle, query) select { case <-ctx.Done(): // If the context was cancelled, it's probably more accurate @@ -196,7 +196,7 @@ func QuerySQL(ctx context.Context, pool driver.Pool, db sqlz.DB, //nolint:funlen } } - drvr := pool.SQLDriver() + drvr := grip.SQLDriver() recMeta, recFromScanRowFn, err := drvr.RecordMeta(ctx, colTypes) if err != nil { return errw(err) @@ -226,7 +226,7 @@ func QuerySQL(ctx context.Context, pool driver.Pool, db sqlz.DB, //nolint:funlen err = rows.Scan(scanRow...) if err != nil { cancelFn() - return errz.Wrapf(errw(err), "query against %s", pool.Source().Handle) + return errz.Wrapf(errw(err), "query against %s", grip.Source().Handle) } // recFromScanRowFn returns a new Record with appropriate diff --git a/libsq/pipeline.go b/libsq/pipeline.go index 264eceb90..cecf01e9a 100644 --- a/libsq/pipeline.go +++ b/libsq/pipeline.go @@ -41,16 +41,16 @@ type pipeline struct { rc *render.Context // tasks contains tasks that must be completed before targetSQL - // is executed against targetPool. Typically tasks is used to + // is executed against targetGrip. Typically tasks is used to // set up the joindb before it is queried. tasks []tasker - // targetSQL is the ultimate SQL query to be executed against targetPool. + // targetSQL is the ultimate SQL query to be executed against targetGrip. targetSQL string - // targetPool is the destination for the ultimate SQL query to + // targetGrip is the destination for the ultimate SQL query to // be executed against. 
- targetPool driver.Pool + targetGrip driver.Grip } // newPipeline parses query, returning a pipeline prepared for @@ -85,11 +85,11 @@ func (p *pipeline) execute(ctx context.Context, recw RecordWriter) error { log := lg.FromContext(ctx) log.Debug( "Execute SQL query", - lga.Src, p.targetPool.Source(), + lga.Src, p.targetGrip.Source(), lga.SQL, p.targetSQL, ) - errw := p.targetPool.SQLDriver().ErrWrapFunc() + errw := p.targetGrip.SQLDriver().ErrWrapFunc() // TODO: The tasks might like to be executed in parallel. However, // what happens if a task does something that is session/connection-dependent? @@ -105,7 +105,7 @@ func (p *pipeline) execute(ctx context.Context, recw RecordWriter) error { // If there's pre/post exec work to do, we need to // obtain a connection from the pool. We are responsible // for closing these resources. - db, err := p.targetPool.DB(ctx) + db, err := p.targetGrip.DB(ctx) if err != nil { return errw(err) } @@ -123,7 +123,7 @@ func (p *pipeline) execute(ctx context.Context, recw RecordWriter) error { } } - if err := QuerySQL(ctx, p.targetPool, conn, recw, p.targetSQL); err != nil { + if err := QuerySQL(ctx, p.targetGrip, conn, recw, p.targetSQL); err != nil { return err } @@ -187,7 +187,7 @@ func (p *pipeline) prepareNoTable(ctx context.Context, qm *queryModel) error { src = p.qc.Collection.Active() if src == nil || !p.qc.Sources.IsSQLSource(src) { log.Debug("No active SQL source, will use scratchdb.") - // REVISIT: ScratchPoolOpener needs a source, so we just make one up. + // REVISIT: Grips.OpenScratch needs a source, so we just make one up. ephemeralSrc := &source.Source{ Type: drivertype.None, Handle: "@scratch_" + stringz.Uniq8(), @@ -195,15 +195,15 @@ func (p *pipeline) prepareNoTable(ctx context.Context, qm *queryModel) error { // FIXME: We really want to change the signature of OpenScratch to // just need a name, not a source. 
- p.targetPool, err = p.qc.Sources.OpenScratch(ctx, ephemeralSrc) + p.targetGrip, err = p.qc.Sources.OpenScratch(ctx, ephemeralSrc) if err != nil { return err } p.rc = &render.Context{ - Renderer: p.targetPool.SQLDriver().Renderer(), + Renderer: p.targetGrip.SQLDriver().Renderer(), Args: p.qc.Args, - Dialect: p.targetPool.SQLDriver().Dialect(), + Dialect: p.targetGrip.SQLDriver().Dialect(), } return nil } @@ -214,14 +214,14 @@ func (p *pipeline) prepareNoTable(ctx context.Context, qm *queryModel) error { } // At this point, src is non-nil. - if p.targetPool, err = p.qc.Sources.Open(ctx, src); err != nil { + if p.targetGrip, err = p.qc.Sources.Open(ctx, src); err != nil { return err } p.rc = &render.Context{ - Renderer: p.targetPool.SQLDriver().Renderer(), + Renderer: p.targetGrip.SQLDriver().Renderer(), Args: p.qc.Args, - Dialect: p.targetPool.SQLDriver().Dialect(), + Dialect: p.targetGrip.SQLDriver().Dialect(), } return nil @@ -231,7 +231,7 @@ func (p *pipeline) prepareNoTable(ctx context.Context, qm *queryModel) error { // // When this function returns, pipeline.rc will be set. 
func (p *pipeline) prepareFromTable(ctx context.Context, tblSel *ast.TblSelectorNode) (fromClause string, - fromPool driver.Pool, err error, + fromGrip driver.Grip, err error, ) { handle := tblSel.Handle() if handle == "" { @@ -246,16 +246,16 @@ func (p *pipeline) prepareFromTable(ctx context.Context, tblSel *ast.TblSelector return "", nil, err } - fromPool, err = p.qc.Sources.Open(ctx, src) + fromGrip, err = p.qc.Sources.Open(ctx, src) if err != nil { return "", nil, err } - rndr := fromPool.SQLDriver().Renderer() + rndr := fromGrip.SQLDriver().Renderer() p.rc = &render.Context{ Renderer: rndr, Args: p.qc.Args, - Dialect: fromPool.SQLDriver().Dialect(), + Dialect: fromGrip.SQLDriver().Dialect(), } fromClause, err = rndr.FromTable(p.rc, tblSel) @@ -263,7 +263,7 @@ func (p *pipeline) prepareFromTable(ctx context.Context, tblSel *ast.TblSelector return "", nil, err } - return fromClause, fromPool, nil + return fromClause, fromGrip, nil } // joinClause models the SQL "JOIN" construct. @@ -319,7 +319,7 @@ func (jc *joinClause) isSingleSource() bool { // // When this function returns, pipeline.rc will be set. func (p *pipeline) prepareFromJoin(ctx context.Context, jc *joinClause) (fromClause string, - fromConn driver.Pool, err error, + fromConn driver.Grip, err error, ) { if jc.isSingleSource() { return p.joinSingleSource(ctx, jc) @@ -332,23 +332,23 @@ func (p *pipeline) prepareFromJoin(ctx context.Context, jc *joinClause) (fromCla // // On return, pipeline.rc will be set. 
func (p *pipeline) joinSingleSource(ctx context.Context, jc *joinClause) (fromClause string, - fromPool driver.Pool, err error, + fromGrip driver.Grip, err error, ) { src, err := p.qc.Collection.Get(jc.leftTbl.Handle()) if err != nil { return "", nil, err } - fromPool, err = p.qc.Sources.Open(ctx, src) + fromGrip, err = p.qc.Sources.Open(ctx, src) if err != nil { return "", nil, err } - rndr := fromPool.SQLDriver().Renderer() + rndr := fromGrip.SQLDriver().Renderer() p.rc = &render.Context{ Renderer: rndr, Args: p.qc.Args, - Dialect: fromPool.SQLDriver().Dialect(), + Dialect: fromGrip.SQLDriver().Dialect(), } fromClause, err = rndr.Join(p.rc, jc.leftTbl, jc.joins) @@ -356,7 +356,7 @@ func (p *pipeline) joinSingleSource(ctx context.Context, jc *joinClause) (fromCl return "", nil, err } - return fromClause, fromPool, nil + return fromClause, fromGrip, nil } // joinCrossSource returns a FROM clause that forms part of @@ -364,7 +364,7 @@ func (p *pipeline) joinSingleSource(ctx context.Context, jc *joinClause) (fromCl // // On return, pipeline.rc will be set. func (p *pipeline) joinCrossSource(ctx context.Context, jc *joinClause) (fromClause string, - fromDB driver.Pool, err error, + fromDB driver.Grip, err error, ) { handles := jc.handles() srcs := make([]*source.Source, 0, len(handles)) @@ -377,16 +377,16 @@ func (p *pipeline) joinCrossSource(ctx context.Context, jc *joinClause) (fromCla } // Open the join db - joinPool, err := p.qc.Sources.OpenJoin(ctx, srcs...) + joinGrip, err := p.qc.Sources.OpenJoin(ctx, srcs...) 
if err != nil { return "", nil, err } - rndr := joinPool.SQLDriver().Renderer() + rndr := joinGrip.SQLDriver().Renderer() p.rc = &render.Context{ Renderer: rndr, Args: p.qc.Args, - Dialect: joinPool.SQLDriver().Dialect(), + Dialect: joinGrip.SQLDriver().Dialect(), } leftHandle := jc.leftTbl.Handle() @@ -403,15 +403,15 @@ func (p *pipeline) joinCrossSource(ctx context.Context, jc *joinClause) (fromCla if src, err = p.qc.Collection.Get(handle); err != nil { return "", nil, err } - var db driver.Pool + var db driver.Grip if db, err = p.qc.Sources.Open(ctx, src); err != nil { return "", nil, err } task := &joinCopyTask{ - fromPool: db, + fromGrip: db, fromTbl: tbl.Table(), - toPool: joinPool, + toGrip: joinGrip, toTbl: tbl.TblAliasOrName(), } @@ -425,7 +425,7 @@ func (p *pipeline) joinCrossSource(ctx context.Context, jc *joinClause) (fromCla return "", nil, err } - return fromClause, joinPool, nil + return fromClause, joinGrip, nil } // tasker is the interface for executing a DB task. @@ -436,35 +436,35 @@ type tasker interface { // joinCopyTask is a specification of a table data copy task to be performed // for a cross-source join. That is, the data in fromDB.fromTblName will -// be copied to a table in toPool. If colNames is +// be copied to a table in toGrip. If colNames is // empty, all cols in fromTbl are to be copied. type joinCopyTask struct { - fromPool driver.Pool + fromGrip driver.Grip fromTbl tablefq.T - toPool driver.Pool + toGrip driver.Grip toTbl tablefq.T } func (jt *joinCopyTask) executeTask(ctx context.Context) error { - return execCopyTable(ctx, jt.fromPool, jt.fromTbl, jt.toPool, jt.toTbl) + return execCopyTable(ctx, jt.fromGrip, jt.fromTbl, jt.toGrip, jt.toTbl) } -// execCopyTable performs the work of copying fromDB.fromTbl to destPool.destTbl. -func execCopyTable(ctx context.Context, fromDB driver.Pool, fromTbl tablefq.T, - destPool driver.Pool, destTbl tablefq.T, +// execCopyTable performs the work of copying fromDB.fromTbl to destGrip.destTbl. 
+func execCopyTable(ctx context.Context, fromDB driver.Grip, fromTbl tablefq.T, + destGrip driver.Grip, destTbl tablefq.T, ) error { log := lg.FromContext(ctx) - createTblHook := func(ctx context.Context, originRecMeta record.Meta, destPool driver.Pool, + createTblHook := func(ctx context.Context, originRecMeta record.Meta, destGrip driver.Grip, tx sqlz.DB, ) error { destColNames := originRecMeta.Names() destColKinds := originRecMeta.Kinds() destTblDef := sqlmodel.NewTableDef(destTbl.Table, destColNames, destColKinds) - err := destPool.SQLDriver().CreateTable(ctx, tx, destTblDef) + err := destGrip.SQLDriver().CreateTable(ctx, tx, destTblDef) if err != nil { - return errz.Wrapf(err, "failed to create dest table %s.%s", destPool.Source().Handle, destTbl) + return errz.Wrapf(err, "failed to create dest table %s.%s", destGrip.Source().Handle, destTbl) } return nil @@ -472,24 +472,24 @@ func execCopyTable(ctx context.Context, fromDB driver.Pool, fromTbl tablefq.T, inserter := NewDBWriter( "Copy records", - destPool, + destGrip, destTbl.Table, - driver.OptTuningRecChanSize.Get(destPool.Source().Options), + driver.OptTuningRecChanSize.Get(destGrip.Source().Options), createTblHook, ) query := "SELECT * FROM " + fromTbl.Render(fromDB.SQLDriver().Dialect().Enquote) err := QuerySQL(ctx, fromDB, nil, inserter, query) if err != nil { - return errz.Wrapf(err, "insert %s.%s failed", destPool.Source().Handle, destTbl) + return errz.Wrapf(err, "insert %s.%s failed", destGrip.Source().Handle, destTbl) } affected, err := inserter.Wait() // Stop for the writer to finish processing if err != nil { - return errz.Wrapf(err, "insert %s.%s failed", destPool.Source().Handle, destTbl) + return errz.Wrapf(err, "insert %s.%s failed", destGrip.Source().Handle, destTbl) } log.Debug("Copied rows to dest", lga.Count, affected, lga.From, fmt.Sprintf("%s.%s", fromDB.Source().Handle, fromTbl), - lga.To, fmt.Sprintf("%s.%s", destPool.Source().Handle, destTbl)) + lga.To, fmt.Sprintf("%s.%s", 
destGrip.Source().Handle, destTbl)) return nil } diff --git a/libsq/prepare.go b/libsq/prepare.go index 006734437..4382363ee 100644 --- a/libsq/prepare.go +++ b/libsq/prepare.go @@ -7,9 +7,9 @@ import ( ) // prepare prepares the pipeline to execute queryModel. -// When this method returns, targetPool and targetSQL will be set, +// When this method returns, targetGrip and targetSQL will be set, // as will any tasks (which may be empty). The tasks must be executed -// against targetPool before targetSQL is executed (the pipeline.execute +// against targetGrip before targetSQL is executed (the pipeline.execute // method does this work). func (p *pipeline) prepare(ctx context.Context, qm *queryModel) error { var err error @@ -24,11 +24,11 @@ func (p *pipeline) prepare(ctx context.Context, qm *queryModel) error { } case len(qm.Joins) > 0: jc := &joinClause{leftTbl: qm.Table, joins: qm.Joins} - if frags.From, p.targetPool, err = p.prepareFromJoin(ctx, jc); err != nil { + if frags.From, p.targetGrip, err = p.prepareFromJoin(ctx, jc); err != nil { return err } default: - if frags.From, p.targetPool, err = p.prepareFromTable(ctx, qm.Table); err != nil { + if frags.From, p.targetGrip, err = p.prepareFromTable(ctx, qm.Table); err != nil { return err } } diff --git a/testh/testh.go b/testh/testh.go index 764293f20..c02cb6a11 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -142,18 +142,18 @@ func New(t testing.TB, opts ...Option) *Helper { return h } -// NewWith is a convenience wrapper for New that also returns -// a Source for handle, the driver.SQLDriver, driver.Pool, +// NewWith is a convenience wrapper for New, that also returns +// the source.Source for handle, the driver.SQLDriver, driver.Grip, // and the *sql.DB. 
-func NewWith(t testing.TB, handle string) (*Helper, *source.Source, driver.SQLDriver, driver.Pool, *sql.DB) { +func NewWith(t testing.TB, handle string) (*Helper, *source.Source, driver.SQLDriver, driver.Grip, *sql.DB) { th := New(t) src := th.Source(handle) - pool := th.Open(src) - drvr := pool.SQLDriver() - db, err := pool.DB(th.Context) + grip := th.Open(src) + drvr := grip.SQLDriver() + db, err := grip.DB(th.Context) require.NoError(t, err) - return th, src, drvr, pool, db + return th, src, drvr, grip, db } func (h *Helper) init() { @@ -363,49 +363,49 @@ func (h *Helper) NewCollection(handles ...string) *source.Collection { return coll } -// Open opens a driver.Pool for src via h's internal Sources +// Open opens a driver.Grip for src via h's internal Sources // instance: thus subsequent calls to Open may return the -// same Pool instance. The opened driver.Pool will be closed +// same driver.Grip instance. The opened driver.Grip will be closed // during h.Close. -func (h *Helper) Open(src *source.Source) driver.Pool { +func (h *Helper) Open(src *source.Source) driver.Grip { ctx, cancelFn := context.WithTimeout(h.Context, h.dbOpenTimeout) defer cancelFn() - pool, err := h.Sources().Open(ctx, src) + grip, err := h.Sources().Open(ctx, src) require.NoError(h.T, err) - db, err := pool.DB(ctx) + db, err := grip.DB(ctx) require.NoError(h.T, err) require.NoError(h.T, db.PingContext(ctx)) - return pool + return grip } // OpenDB is a convenience method for getting the sql.DB for src. // The returned sql.DB is closed during h.Close, via the closing -// of its parent driver.Pool. +// of its parent driver.Grip. func (h *Helper) OpenDB(src *source.Source) *sql.DB { - pool := h.Open(src) - db, err := pool.DB(h.Context) + grip := h.Open(src) + db, err := grip.DB(h.Context) require.NoError(h.T, err) return db } -// openNew opens a new driver.Pool. It is the caller's responsibility -// to close the returned Pool. 
Unlike method Open, this method +// openNew opens a new driver.Grip. It is the caller's responsibility +// to close the returned Grip. Unlike method Open, this method // will always invoke the driver's Open method. // // Some of Helper's methods (e.g. DropTable) need to use openNew rather -// than Open, as the Pool returned by Open can be closed by test code, +// than Open, as the Grip returned by Open can be closed by test code, // potentially causing problems during Cleanup. -func (h *Helper) openNew(src *source.Source) driver.Pool { +func (h *Helper) openNew(src *source.Source) driver.Grip { h.Log.Debug("openNew", lga.Src, src) reg := h.Registry() drvr, err := reg.DriverFor(src.Type) require.NoError(h.T, err) - pool, err := drvr.Open(h.Context, src) + grip, err := drvr.Open(h.Context, src) require.NoError(h.T, err) - return pool + return grip } // SQLDriverFor is a convenience method to get src's driver.SQLDriver. @@ -431,12 +431,12 @@ func (h *Helper) DriverFor(src *source.Source) driver.Driver { // RowCount returns the result of "SELECT COUNT(*) FROM tbl", // failing h's test on any error. 
func (h *Helper) RowCount(src *source.Source, tbl string) int64 { - pool := h.openNew(src) - defer lg.WarnIfCloseError(h.Log, lgm.CloseDB, pool) + grip := h.openNew(src) + defer lg.WarnIfCloseError(h.Log, lgm.CloseDB, grip) - query := "SELECT COUNT(*) FROM " + pool.SQLDriver().Dialect().Enquote(tbl) + query := "SELECT COUNT(*) FROM " + grip.SQLDriver().Dialect().Enquote(tbl) var count int64 - db, err := pool.DB(h.Context) + db, err := grip.DB(h.Context) require.NoError(h.T, err) require.NoError(h.T, db.QueryRowContext(h.Context, query).Scan(&count)) @@ -449,13 +449,13 @@ func (h *Helper) RowCount(src *source.Source, tbl string) int64 { func (h *Helper) CreateTable(dropAfter bool, src *source.Source, tblDef *sqlmodel.TableDef, data ...[]any, ) (affected int64) { - pool := h.openNew(src) - defer lg.WarnIfCloseError(h.Log, lgm.CloseDB, pool) + grip := h.openNew(src) + defer lg.WarnIfCloseError(h.Log, lgm.CloseDB, grip) - db, err := pool.DB(h.Context) + db, err := grip.DB(h.Context) require.NoError(h.T, err) - require.NoError(h.T, pool.SQLDriver().CreateTable(h.Context, db, tblDef)) + require.NoError(h.T, grip.SQLDriver().CreateTable(h.Context, db, tblDef)) h.T.Logf("Created table %s.%s", src.Handle, tblDef.Name) if dropAfter { @@ -477,11 +477,11 @@ func (h *Helper) Insert(src *source.Source, tbl string, cols []string, records . 
return 0 } - pool := h.openNew(src) - defer lg.WarnIfCloseError(h.Log, lgm.CloseDB, pool) + grip := h.openNew(src) + defer lg.WarnIfCloseError(h.Log, lgm.CloseDB, grip) - drvr := pool.SQLDriver() - db, err := pool.DB(h.Context) + drvr := grip.SQLDriver() + db, err := grip.DB(h.Context) require.NoError(h.T, err) conn, err := db.Conn(h.Context) @@ -550,13 +550,13 @@ func (h *Helper) CopyTable( toTable.Table = stringz.UniqTableName(fromTable.Table) } - pool := h.openNew(src) - defer lg.WarnIfCloseError(h.Log, lgm.CloseDB, pool) + grip := h.openNew(src) + defer lg.WarnIfCloseError(h.Log, lgm.CloseDB, grip) - db, err := pool.DB(h.Context) + db, err := grip.DB(h.Context) require.NoError(h.T, err) - copied, err := pool.SQLDriver().CopyTable( + copied, err := grip.SQLDriver().CopyTable( h.Context, db, fromTable, @@ -580,28 +580,28 @@ func (h *Helper) CopyTable( // DropTable drops tbl from src. func (h *Helper) DropTable(src *source.Source, tbl tablefq.T) { - pool := h.openNew(src) - defer lg.WarnIfCloseError(h.Log, lgm.CloseDB, pool) + grip := h.openNew(src) + defer lg.WarnIfCloseError(h.Log, lgm.CloseDB, grip) - db, err := pool.DB(h.Context) + db, err := grip.DB(h.Context) require.NoError(h.T, err) - require.NoError(h.T, pool.SQLDriver().DropTable(h.Context, db, tbl, true)) + require.NoError(h.T, grip.SQLDriver().DropTable(h.Context, db, tbl, true)) h.Log.Debug("Dropped table", lga.Target, source.Target(src, tbl.Table)) } // QuerySQL uses libsq.QuerySQL to execute SQL query // against src, returning a sink to which all records have // been written. Typically the db arg is nil, and QuerySQL uses the -// same driver.Pool instance as returned by Helper.Open. If db +// same driver.Grip instance as returned by Helper.Open. If db // is non-nil, it is passed to libsq.QuerySQL (e.g. the query needs to // execute against a sql.Tx), and the caller is responsible for closing db. 
func (h *Helper) QuerySQL(src *source.Source, db sqlz.DB, query string, args ...any) (*RecordSink, error) { - pool := h.Open(src) + grip := h.Open(src) sink := &RecordSink{} recw := output.NewRecordWriterAdapter(h.Context, sink) - err := libsq.QuerySQL(h.Context, pool, db, recw, query, args...) + err := libsq.QuerySQL(h.Context, grip, db, recw, query, args...) if err != nil { return nil, err } @@ -648,7 +648,7 @@ func (h *Helper) QuerySLQ(query string, args map[string]string) (*RecordSink, er // ExecSQL is a convenience wrapper for sql.DB.Exec that returns the // rows affected, failing on any error. Note that ExecSQL uses the -// same Pool instance as returned by h.Open. +// same Grip instance as returned by h.Open. func (h *Helper) ExecSQL(src *source.Source, query string, args ...any) (affected int64) { db := h.OpenDB(src) @@ -689,8 +689,8 @@ func (h *Helper) InsertDefaultRow(src *source.Source, tbl string) { // TruncateTable truncates tbl in src. func (h *Helper) TruncateTable(src *source.Source, tbl string) (affected int64) { - pool := h.openNew(src) - defer lg.WarnIfCloseError(h.Log, lgm.CloseDB, pool) + grip := h.openNew(src) + defer lg.WarnIfCloseError(h.Log, lgm.CloseDB, grip) affected, err := h.DriverFor(src).Truncate(h.Context, src, tbl, true) require.NoError(h.T, err) @@ -763,22 +763,22 @@ func (h *Helper) Files() *source.Files { // SourceMetadata returns metadata for src. func (h *Helper) SourceMetadata(src *source.Source) (*metadata.Source, error) { - pools, err := h.Sources().Open(h.Context, src) + grip, err := h.Sources().Open(h.Context, src) if err != nil { return nil, err } - return pools.SourceMetadata(h.Context, false) + return grip.SourceMetadata(h.Context, false) } // TableMetadata returns metadata for src's table. 
func (h *Helper) TableMetadata(src *source.Source, tbl string) (*metadata.Table, error) { - pools, err := h.Sources().Open(h.Context, src) + grip, err := h.Sources().Open(h.Context, src) if err != nil { return nil, err } - return pools.TableMetadata(h.Context, tbl) + return grip.TableMetadata(h.Context, tbl) } // DiffDB fails the test if src's metadata is substantially different From a32d23a362119c4cbf217ab7cac3bbb3fe6a7ee7 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 23:16:37 -0700 Subject: [PATCH 051/195] Grip rename almost done --- cli/cmd_inspect.go | 2 +- cli/cmd_slq.go | 4 +- cli/cmd_sql.go | 4 +- cli/cmd_tbl.go | 4 +- cli/complete.go | 6 +- cli/diff/source.go | 2 +- cli/diff/table.go | 2 +- cli/run.go | 12 +-- cli/run/run.go | 6 +- drivers/csv/csv.go | 4 +- drivers/json/json.go | 4 +- drivers/userdriver/userdriver.go | 4 +- drivers/xlsx/xlsx.go | 4 +- libsq/driver/driver.go | 51 ------------- libsq/driver/grip.go | 60 +++++++++++++++ libsq/driver/grips.go | 124 +++++++++++++++---------------- libsq/libsq.go | 2 +- testh/testh.go | 22 +++--- 18 files changed, 163 insertions(+), 154 deletions(-) create mode 100644 libsq/driver/grip.go diff --git a/cli/cmd_inspect.go b/cli/cmd_inspect.go index 05c543604..3a6b4789d 100644 --- a/cli/cmd_inspect.go +++ b/cli/cmd_inspect.go @@ -125,7 +125,7 @@ func execInspect(cmd *cobra.Command, args []string) error { return err } - grip, err := ru.Sources.Open(ctx, src) + grip, err := ru.Grips.Open(ctx, src) if err != nil { return errz.Wrapf(err, "failed to inspect %s", src.Handle) } diff --git a/cli/cmd_slq.go b/cli/cmd_slq.go index e7c7efec1..343117795 100644 --- a/cli/cmd_slq.go +++ b/cli/cmd_slq.go @@ -135,7 +135,7 @@ func execSLQInsert(ctx context.Context, ru *run.Run, mArgs map[string]string, ctx, cancelFn := context.WithCancel(ctx) defer cancelFn() - destGrip, err := ru.Sources.Open(ctx, destSrc) + destGrip, err := ru.Grips.Open(ctx, destSrc) if err != nil { return err } @@ -204,7 +204,7 @@ func 
execSLQPrint(ctx context.Context, ru *run.Run, mArgs map[string]string) err // // $ cat something.xlsx | sq @stdin.sheet1 func preprocessUserSLQ(ctx context.Context, ru *run.Run, args []string) (string, error) { - log, reg, grips, coll := lg.FromContext(ctx), ru.DriverRegistry, ru.Sources, ru.Config.Collection + log, reg, grips, coll := lg.FromContext(ctx), ru.DriverRegistry, ru.Grips, ru.Config.Collection activeSrc := coll.Active() if len(args) == 0 { diff --git a/cli/cmd_sql.go b/cli/cmd_sql.go index a14c58660..8bb2f4d46 100644 --- a/cli/cmd_sql.go +++ b/cli/cmd_sql.go @@ -120,7 +120,7 @@ func execSQL(cmd *cobra.Command, args []string) error { // to the configured writer. func execSQLPrint(ctx context.Context, ru *run.Run, fromSrc *source.Source) error { args := ru.Args - grip, err := ru.Sources.Open(ctx, fromSrc) + grip, err := ru.Grips.Open(ctx, fromSrc) if err != nil { return err } @@ -140,7 +140,7 @@ func execSQLInsert(ctx context.Context, ru *run.Run, fromSrc, destSrc *source.Source, destTbl string, ) error { args := ru.Args - grips := ru.Sources + grips := ru.Grips ctx, cancelFn := context.WithCancel(ctx) defer cancelFn() diff --git a/cli/cmd_tbl.go b/cli/cmd_tbl.go index c92b3f713..e1f012652 100644 --- a/cli/cmd_tbl.go +++ b/cli/cmd_tbl.go @@ -122,7 +122,7 @@ func execTblCopy(cmd *cobra.Command, args []string) error { } var grip driver.Grip - grip, err = ru.Sources.Open(ctx, tblHandles[0].src) + grip, err = ru.Grips.Open(ctx, tblHandles[0].src) if err != nil { return err } @@ -255,7 +255,7 @@ func execTblDrop(cmd *cobra.Command, args []string) (err error) { } var grip driver.Grip - if grip, err = ru.Sources.Open(ctx, tblH.src); err != nil { + if grip, err = ru.Grips.Open(ctx, tblH.src); err != nil { return err } diff --git a/cli/complete.go b/cli/complete.go index 0959cbb8e..4db48df42 100644 --- a/cli/complete.go +++ b/cli/complete.go @@ -140,7 +140,7 @@ func completeSLQ(cmd *cobra.Command, args []string, toComplete string) ([]string // completeDriverType 
is a completionFunc that suggests drivers. func completeDriverType(cmd *cobra.Command, _ []string, _ string) ([]string, cobra.ShellCompDirective) { ru := getRun(cmd) - if ru.Sources == nil { + if ru.Grips == nil { if err := preRun(cmd, ru); err != nil { lg.Unexpected(logFrom(cmd), err) return nil, cobra.ShellCompDirectiveError @@ -383,7 +383,7 @@ func (c activeSchemaCompleter) complete(cmd *cobra.Command, args []string, toCom ctx, cancelFn := context.WithTimeout(cmd.Context(), OptShellCompletionTimeout.Get(ru.Config.Options)) defer cancelFn() - grip, err := ru.Sources.Open(ctx, src) + grip, err := ru.Grips.Open(ctx, src) if err != nil { lg.Unexpected(log, err) return nil, cobra.ShellCompDirectiveError @@ -759,7 +759,7 @@ func getTableNamesForHandle(ctx context.Context, ru *run.Run, handle string) ([] return nil, err } - grip, err := ru.Sources.Open(ctx, src) + grip, err := ru.Grips.Open(ctx, src) if err != nil { return nil, err } diff --git a/cli/diff/source.go b/cli/diff/source.go index baba129e1..f05a5d25b 100644 --- a/cli/diff/source.go +++ b/cli/diff/source.go @@ -196,7 +196,7 @@ func fetchSourceMeta(ctx context.Context, ru *run.Run, handle string) (*source.S if err != nil { return nil, nil, err } - grip, err := ru.Sources.Open(ctx, src) + grip, err := ru.Grips.Open(ctx, src) if err != nil { return nil, nil, err } diff --git a/cli/diff/table.go b/cli/diff/table.go index 6e6531c15..320a97983 100644 --- a/cli/diff/table.go +++ b/cli/diff/table.go @@ -114,7 +114,7 @@ func buildTableStructureDiff(cfg *Config, showRowCounts bool, td1, td2 *tableDat func fetchTableMeta(ctx context.Context, ru *run.Run, src *source.Source, table string) ( *metadata.Table, error, ) { - grip, err := ru.Sources.Open(ctx, src) + grip, err := ru.Grips.Open(ctx, src) if err != nil { return nil, err } diff --git a/cli/run.go b/cli/run.go index 033d5a816..d77a02446 100644 --- a/cli/run.go +++ b/cli/run.go @@ -157,19 +157,19 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { 
ru.DriverRegistry = driver.NewRegistry(log) dr := ru.DriverRegistry - ru.Sources = driver.NewSources(log, dr, ru.Files, scratchSrcFunc) - ru.Cleanup.AddC(ru.Sources) + ru.Grips = driver.NewGrips(log, dr, ru.Files, scratchSrcFunc) + ru.Cleanup.AddC(ru.Grips) dr.AddProvider(sqlite3.Type, &sqlite3.Provider{Log: log}) dr.AddProvider(postgres.Type, &postgres.Provider{Log: log}) dr.AddProvider(sqlserver.Type, &sqlserver.Provider{Log: log}) dr.AddProvider(mysql.Type, &mysql.Provider{Log: log}) - csvp := &csv.Provider{Log: log, Ingester: ru.Sources, Files: ru.Files} + csvp := &csv.Provider{Log: log, Ingester: ru.Grips, Files: ru.Files} dr.AddProvider(csv.TypeCSV, csvp) dr.AddProvider(csv.TypeTSV, csvp) ru.Files.AddDriverDetectors(csv.DetectCSV, csv.DetectTSV) - jsonp := &json.Provider{Log: log, Ingester: ru.Sources, Files: ru.Files} + jsonp := &json.Provider{Log: log, Ingester: ru.Grips, Files: ru.Files} dr.AddProvider(json.TypeJSON, jsonp) dr.AddProvider(json.TypeJSONA, jsonp) dr.AddProvider(json.TypeJSONL, jsonp) @@ -180,7 +180,7 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { json.DetectJSONL(sampleSize), ) - dr.AddProvider(xlsx.Type, &xlsx.Provider{Log: log, Ingester: ru.Sources, Files: ru.Files}) + dr.AddProvider(xlsx.Type, &xlsx.Provider{Log: log, Ingester: ru.Grips, Files: ru.Files}) ru.Files.AddDriverDetectors(xlsx.DetectXLSX) // One day we may have more supported user driver genres. userDriverImporters := map[string]userdriver.ImportFunc{ @@ -210,7 +210,7 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { Log: log, DriverDef: userDriverDef, ImportFn: importFn, - Ingester: ru.Sources, + Ingester: ru.Grips, Files: ru.Files, } diff --git a/cli/run/run.go b/cli/run/run.go index e5c76ebbc..44ce55a92 100644 --- a/cli/run/run.go +++ b/cli/run/run.go @@ -73,8 +73,8 @@ type Run struct { // Files manages file access. Files *source.Files - // Sources mediates access to driver.Grip instances. 
- Sources *driver.Sources + // Grips mediates access to driver.Grip instances. + Grips *driver.Grips // Writers holds the various writer types that // the CLI uses to print output. @@ -100,7 +100,7 @@ func (ru *Run) Close() error { func NewQueryContext(ru *Run, args map[string]string) *libsq.QueryContext { return &libsq.QueryContext{ Collection: ru.Config.Collection, - Sources: ru.Sources, + Sources: ru.Grips, Args: args, } } diff --git a/drivers/csv/csv.go b/drivers/csv/csv.go index afe6eca04..8f2a1b270 100644 --- a/drivers/csv/csv.go +++ b/drivers/csv/csv.go @@ -28,7 +28,7 @@ const ( // Provider implements driver.Provider. type Provider struct { Log *slog.Logger - Ingester driver.IngestOpener + Ingester driver.GripOpenIngester Files *source.Files } @@ -48,7 +48,7 @@ func (d *Provider) DriverFor(typ drivertype.Type) (driver.Driver, error) { type driveri struct { log *slog.Logger typ drivertype.Type - ingester driver.IngestOpener + ingester driver.GripOpenIngester files *source.Files } diff --git a/drivers/json/json.go b/drivers/json/json.go index 240549be1..0b1971877 100644 --- a/drivers/json/json.go +++ b/drivers/json/json.go @@ -36,7 +36,7 @@ const ( // Provider implements driver.Provider. 
type Provider struct { Log *slog.Logger - Ingester driver.IngestOpener + Ingester driver.GripOpenIngester Files *source.Files } @@ -67,7 +67,7 @@ func (d *Provider) DriverFor(typ drivertype.Type) (driver.Driver, error) { type driveri struct { typ drivertype.Type ingestFn ingestFunc - ingester driver.IngestOpener + ingester driver.GripOpenIngester files *source.Files } diff --git a/drivers/userdriver/userdriver.go b/drivers/userdriver/userdriver.go index 47839cef4..7128b4487 100644 --- a/drivers/userdriver/userdriver.go +++ b/drivers/userdriver/userdriver.go @@ -31,7 +31,7 @@ type ImportFunc func(ctx context.Context, def *DriverDef, type Provider struct { Log *slog.Logger DriverDef *DriverDef - Ingester driver.IngestOpener + Ingester driver.GripOpenIngester Files *source.Files ImportFn ImportFunc } @@ -66,7 +66,7 @@ type driveri struct { typ drivertype.Type def *DriverDef files *source.Files - ingester driver.IngestOpener + ingester driver.GripOpenIngester ingestFn ImportFunc } diff --git a/drivers/xlsx/xlsx.go b/drivers/xlsx/xlsx.go index e8535b6e9..59666ac73 100644 --- a/drivers/xlsx/xlsx.go +++ b/drivers/xlsx/xlsx.go @@ -29,7 +29,7 @@ const ( type Provider struct { Log *slog.Logger Files *source.Files - Ingester driver.IngestOpener + Ingester driver.GripOpenIngester } // DriverFor implements driver.Provider. @@ -44,7 +44,7 @@ func (p *Provider) DriverFor(typ drivertype.Type) (driver.Driver, error) { // Driver implements driver.Driver. type Driver struct { log *slog.Logger - ingester driver.IngestOpener + ingester driver.GripOpenIngester files *source.Files } diff --git a/libsq/driver/driver.go b/libsq/driver/driver.go index 14b15e5a0..e3cf69910 100644 --- a/libsq/driver/driver.go +++ b/libsq/driver/driver.go @@ -155,22 +155,6 @@ type Provider interface { DriverFor(typ drivertype.Type) (Driver, error) } -// GripOpener opens a Grip. -type GripOpener interface { - // Open returns a Grip instance for src. 
- Open(ctx context.Context, src *source.Source) (Grip, error) -} - -// IngestOpener opens a Grip for ingest use. -type IngestOpener interface { - // OpenIngest opens a Grip for src by executing ingestFn, which is - // responsible for ingesting data into dest. If allowCache is false, - // ingest always occurs; if true, the cache is consulted first (and - // ingestFn may not be invoked). - OpenIngest(ctx context.Context, src *source.Source, allowCache bool, - ingestFn func(ctx context.Context, dest Grip) error) (Grip, error) -} - // Driver is the core interface that must be implemented for each type // of data source. type Driver interface { @@ -328,41 +312,6 @@ type SQLDriver interface { DBProperties(ctx context.Context, db sqlz.DB) (map[string]any, error) } -// Grip is the link between a source and its database connection. -// Why is it named Grip? TLDR: all the other names were taken, -// including Handle, Conn, DB, Source, etc. -// -// Grip is conceptually equivalent to stdlib sql.DB, and in fact -// encapsulates a sql.DB instance. The realized sql.DB instance can be -// accessed via the DB method. -type Grip interface { - // DB returns the sql.DB object for this Grip. - // This operation may take a long time if opening the DB requires - // an ingest of data (but note that when an ingest step occurs is - // driver-dependent). - DB(ctx context.Context) (*sql.DB, error) - - // SQLDriver returns the underlying database driver. The type of the SQLDriver - // may be different from the driver type reported by the Source. - SQLDriver() SQLDriver - - // FIXME: Add a method: SourceDriver() Driver. - - // Source returns the source for which this Grip was opened. - Source() *source.Source - - // SourceMetadata returns metadata about the Grip. - // If noSchema is true, schema details are not populated - // on the returned metadata.Source. 
- SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) - - // TableMetadata returns metadata for the specified table in the Grip. - TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) - - // Close is invoked to close and release any underlying resources. - Close() error -} - // Metadata holds driver metadata. // // TODO: Can driver.Metadata and dialect.Dialect be merged? diff --git a/libsq/driver/grip.go b/libsq/driver/grip.go new file mode 100644 index 000000000..a93846efb --- /dev/null +++ b/libsq/driver/grip.go @@ -0,0 +1,60 @@ +package driver + +import ( + "context" + "database/sql" + + "github.com/neilotoole/sq/libsq/source" + "github.com/neilotoole/sq/libsq/source/metadata" +) + +// Grip is the link between a source and its database connection. +// Why is it named Grip? TLDR: all the other names were taken, +// including Handle, Conn, DB, Source, etc. +// +// Grip is conceptually equivalent to stdlib sql.DB, and in fact +// encapsulates a sql.DB instance. The realized sql.DB instance can be +// accessed via the DB method. +type Grip interface { + // DB returns the sql.DB object for this Grip. + // This operation may take a long time if opening the DB requires + // an ingest of data (but note that when an ingest step occurs is + // driver-dependent). + DB(ctx context.Context) (*sql.DB, error) + + // SQLDriver returns the underlying database driver. The type of the SQLDriver + // may be different from the driver type reported by the Source. + SQLDriver() SQLDriver + + // FIXME: Add a method: SourceDriver() Driver. + + // Source returns the source for which this Grip was opened. + Source() *source.Source + + // SourceMetadata returns metadata about the Grip. + // If noSchema is true, schema details are not populated + // on the returned metadata.Source. + SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) + + // TableMetadata returns metadata for the specified table in the Grip. 
+ TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) + + // Close is invoked to close and release any underlying resources. + Close() error +} + +// GripOpener opens a Grip. +type GripOpener interface { + // Open returns a Grip instance for src. + Open(ctx context.Context, src *source.Source) (Grip, error) +} + +// GripOpenIngester opens a Grip via an ingest function. +type GripOpenIngester interface { + // OpenIngest opens a Grip for src by executing ingestFn, which is + // responsible for ingesting data into dest. If allowCache is false, + // ingest always occurs; if true, the cache is consulted first (and + // ingestFn may not be invoked). + OpenIngest(ctx context.Context, src *source.Source, allowCache bool, + ingestFn func(ctx context.Context, dest Grip) error) (Grip, error) +} diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 74106a5b7..8c8885299 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -24,16 +24,16 @@ import ( "github.com/neilotoole/sq/libsq/source" ) -var _ GripOpener = (*Sources)(nil) +var _ GripOpener = (*Grips)(nil) // ScratchSrcFunc is a function that returns a scratch source. // The caller is responsible for invoking cleanFn. type ScratchSrcFunc func(ctx context.Context, name string) (src *source.Source, cleanFn func() error, err error) -// Sources provides a mechanism for getting Grip instances. +// Grips provides a mechanism for getting Grip instances. // Note that at this time instances returned by Open are cached // and then closed by Close. This may be a bad approach. -type Sources struct { +type Grips struct { log *slog.Logger drvrs Provider mu sync.Mutex @@ -43,11 +43,11 @@ type Sources struct { clnup *cleanup.Cleanup } -// NewSources returns a Sources instances. -func NewSources(log *slog.Logger, drvrs Provider, +// NewGrips returns a Grips instances. 
+func NewGrips(log *slog.Logger, drvrs Provider, files *source.Files, scratchSrcFn ScratchSrcFunc, -) *Sources { - return &Sources{ +) *Grips { + return &Grips{ log: log, drvrs: drvrs, mu: sync.Mutex{}, @@ -68,25 +68,25 @@ func NewSources(log *slog.Logger, drvrs Provider, // and needs to be revisited. // // Open implements GripOpener. -func (ss *Sources) Open(ctx context.Context, src *source.Source) (Grip, error) { +func (gs *Grips) Open(ctx context.Context, src *source.Source) (Grip, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) - ss.mu.Lock() - defer ss.mu.Unlock() - return ss.doOpen(ctx, src) + gs.mu.Lock() + defer gs.mu.Unlock() + return gs.doOpen(ctx, src) } // DriverFor returns the driver for typ. -func (ss *Sources) DriverFor(typ drivertype.Type) (Driver, error) { - return ss.drvrs.DriverFor(typ) +func (gs *Grips) DriverFor(typ drivertype.Type) (Driver, error) { + return gs.drvrs.DriverFor(typ) } // IsSQLSource returns true if src's driver is a SQLDriver. -func (ss *Sources) IsSQLSource(src *source.Source) bool { +func (gs *Grips) IsSQLSource(src *source.Source) bool { if src == nil { return false } - drvr, err := ss.drvrs.DriverFor(src.Type) + drvr, err := gs.drvrs.DriverFor(src.Type) if err != nil { return false } @@ -98,20 +98,20 @@ func (ss *Sources) IsSQLSource(src *source.Source) bool { return false } -func (ss *Sources) getKey(src *source.Source) string { +func (gs *Grips) getKey(src *source.Source) string { return src.Handle + "_" + src.Hash() } -func (ss *Sources) doOpen(ctx context.Context, src *source.Source) (Grip, error) { +func (gs *Grips) doOpen(ctx context.Context, src *source.Source) (Grip, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) - key := ss.getKey(src) + key := gs.getKey(src) - grip, ok := ss.grips[key] + grip, ok := gs.grips[key] if ok { return grip, nil } - drvr, err := ss.drvrs.DriverFor(src.Type) + drvr, err := gs.drvrs.DriverFor(src.Type) if err != nil { return nil, err } @@ -125,19 +125,19 @@ func 
(ss *Sources) doOpen(ctx context.Context, src *source.Source) (Grip, error) return nil, err } - ss.clnup.AddC(grip) + gs.clnup.AddC(grip) - ss.grips[key] = grip + gs.grips[key] = grip return grip, nil } // OpenScratch returns a scratch database instance. It is not // necessary for the caller to close the returned Grip as // its Close method will be invoked by d.Close. -func (ss *Sources) OpenScratch(ctx context.Context, src *source.Source) (Grip, error) { +func (gs *Grips) OpenScratch(ctx context.Context, src *source.Source) (Grip, error) { const msgCloseScratch = "Close scratch db" - cacheDir, srcCacheDBFilepath, _, err := ss.getCachePaths(src) + cacheDir, srcCacheDBFilepath, _, err := gs.getCachePaths(src) if err != nil { return nil, err } @@ -146,23 +146,23 @@ func (ss *Sources) OpenScratch(ctx context.Context, src *source.Source) (Grip, e return nil, err } - scratchSrc, cleanFn, err := ss.scratchSrcFn(ctx, srcCacheDBFilepath) + scratchSrc, cleanFn, err := gs.scratchSrcFn(ctx, srcCacheDBFilepath) if err != nil { // if err is non-nil, cleanup is guaranteed to be nil return nil, err } - ss.log.Debug("Opening scratch src", lga.Src, scratchSrc) + gs.log.Debug("Opening scratch src", lga.Src, scratchSrc) - backingDrvr, err := ss.drvrs.DriverFor(scratchSrc.Type) + backingDrvr, err := gs.drvrs.DriverFor(scratchSrc.Type) if err != nil { - lg.WarnIfFuncError(ss.log, msgCloseScratch, cleanFn) + lg.WarnIfFuncError(gs.log, msgCloseScratch, cleanFn) return nil, err } var backingGrip Grip backingGrip, err = backingDrvr.Open(ctx, scratchSrc) if err != nil { - lg.WarnIfFuncError(ss.log, msgCloseScratch, cleanFn) + lg.WarnIfFuncError(gs.log, msgCloseScratch, cleanFn) return nil, err } @@ -170,14 +170,14 @@ func (ss *Sources) OpenScratch(ctx context.Context, src *source.Source) (Grip, e if !allowCache { // If the ingest cache is disabled, we add the cleanup func // so the scratch DB is deleted when the session ends. 
- ss.clnup.AddE(cleanFn) + gs.clnup.AddE(cleanFn) } return backingGrip, nil } -// OpenIngest implements driver.IngestOpener. -func (ss *Sources) OpenIngest(ctx context.Context, src *source.Source, allowCache bool, +// OpenIngest implements driver.GripOpenIngester. +func (gs *Grips) OpenIngest(ctx context.Context, src *source.Source, allowCache bool, ingestFn func(ctx context.Context, dest Grip) error, ) (Grip, error) { var grip Grip @@ -185,9 +185,9 @@ func (ss *Sources) OpenIngest(ctx context.Context, src *source.Source, allowCach if !allowCache || src.Handle == source.StdinHandle { // We don't currently cache stdin. Probably we never will? - grip, err = ss.openIngestNoCache(ctx, src, ingestFn) + grip, err = gs.openIngestNoCache(ctx, src, ingestFn) } else { - grip, err = ss.openIngestCache(ctx, src, ingestFn) + grip, err = gs.openIngestCache(ctx, src, ingestFn) } if err != nil { @@ -197,11 +197,11 @@ func (ss *Sources) OpenIngest(ctx context.Context, src *source.Source, allowCach return grip, nil } -func (ss *Sources) openIngestNoCache(ctx context.Context, src *source.Source, +func (gs *Grips) openIngestNoCache(ctx context.Context, src *source.Source, ingestFn func(ctx context.Context, destGrip Grip) error, ) (Grip, error) { log := lg.FromContext(ctx) - impl, err := ss.OpenScratch(ctx, src) + impl, err := gs.OpenScratch(ctx, src) if err != nil { return nil, err } @@ -219,18 +219,18 @@ func (ss *Sources) openIngestNoCache(ctx context.Context, src *source.Source, return nil, err } - ss.log.Info("Ingest completed", + gs.log.Info("Ingest completed", lga.Src, src, lga.Dest, impl.Source(), lga.Elapsed, elapsed) return impl, nil } -func (ss *Sources) openIngestCache(ctx context.Context, src *source.Source, +func (gs *Grips) openIngestCache(ctx context.Context, src *source.Source, ingestFn func(ctx context.Context, destGrip Grip) error, ) (Grip, error) { log := lg.FromContext(ctx) - lock, err := ss.acquireLock(ctx, src) + lock, err := gs.acquireLock(ctx, src) if err != 
nil { return nil, err } @@ -243,7 +243,7 @@ func (ss *Sources) openIngestCache(ctx context.Context, src *source.Source, } }() - cacheDir, _, checksumsPath, err := ss.getCachePaths(src) + cacheDir, _, checksumsPath, err := gs.getCachePaths(src) if err != nil { return nil, err } @@ -254,7 +254,7 @@ func (ss *Sources) openIngestCache(ctx context.Context, src *source.Source, log.Debug("Using cache dir", lga.Path, cacheDir) - ingestFilePath, err := ss.files.Filepath(ctx, src) + ingestFilePath, err := gs.files.Filepath(ctx, src) if err != nil { return nil, err } @@ -263,7 +263,7 @@ func (ss *Sources) openIngestCache(ctx context.Context, src *source.Source, impl Grip foundCached bool ) - if impl, foundCached, err = ss.openCachedFor(ctx, src); err != nil { + if impl, foundCached, err = gs.openCachedFor(ctx, src); err != nil { return nil, err } if foundCached { @@ -275,7 +275,7 @@ func (ss *Sources) openIngestCache(ctx context.Context, src *source.Source, log.Debug("Ingest cache MISS: no cache for source", lga.Src, src) - impl, err = ss.OpenScratch(ctx, src) + impl, err = gs.OpenScratch(ctx, src) if err != nil { return nil, err } @@ -314,8 +314,8 @@ func (ss *Sources) openIngestCache(ctx context.Context, src *source.Source, // getCachePaths returns the paths to the cache files for src. // There is no guarantee that these files exist, or are accessible. // It's just the paths. -func (ss *Sources) getCachePaths(src *source.Source) (srcCacheDir, cacheDB, checksums string, err error) { - if srcCacheDir, err = ss.files.CacheDirFor(src); err != nil { +func (gs *Grips) getCachePaths(src *source.Source) (srcCacheDir, cacheDB, checksums string, err error) { + if srcCacheDir, err = gs.files.CacheDirFor(src); err != nil { return "", "", "", err } @@ -330,8 +330,8 @@ func (ss *Sources) getCachePaths(src *source.Source) (srcCacheDir, cacheDB, chec // defer lg.WarnIfFuncError(d.log, "failed to unlock cache lock", lock.Unlock) // // The lock acquisition process is retried with backoff. 
-func (ss *Sources) acquireLock(ctx context.Context, src *source.Source) (lockfile.Lockfile, error) { - lock, err := ss.getLockfileFor(src) +func (gs *Grips) acquireLock(ctx context.Context, src *source.Source) (lockfile.Lockfile, error) { + lock, err := gs.getLockfileFor(src) if err != nil { return "", err } @@ -356,8 +356,8 @@ func (ss *Sources) acquireLock(ctx context.Context, src *source.Source) (lockfil // getLockfileFor returns a lockfile for src. It doesn't // actually acquire the lock. -func (ss *Sources) getLockfileFor(src *source.Source) (lockfile.Lockfile, error) { - srcCacheDir, _, _, err := ss.getCachePaths(src) +func (gs *Grips) getLockfileFor(src *source.Source) (lockfile.Lockfile, error) { + srcCacheDir, _, _, err := gs.getCachePaths(src) if err != nil { return "", err } @@ -369,8 +369,8 @@ func (ss *Sources) getLockfileFor(src *source.Source) (lockfile.Lockfile, error) return lockfile.New(lockPath) } -func (ss *Sources) openCachedFor(ctx context.Context, src *source.Source) (Grip, bool, error) { - _, cacheDBPath, checksumsPath, err := ss.getCachePaths(src) +func (gs *Grips) openCachedFor(ctx context.Context, src *source.Source) (Grip, bool, error) { + _, cacheDBPath, checksumsPath, err := gs.getCachePaths(src) if err != nil { return nil, false, err } @@ -384,7 +384,7 @@ func (ss *Sources) openCachedFor(ctx context.Context, src *source.Source) (Grip, return nil, false, err } - drvr, err := ss.drvrs.DriverFor(src.Type) + drvr, err := gs.drvrs.DriverFor(src.Type) if err != nil { return nil, false, err } @@ -395,11 +395,11 @@ func (ss *Sources) openCachedFor(ctx context.Context, src *source.Source) (Grip, } // FIXME: Not too sure invoking files.Filepath here is the right approach? 
- srcFilepath, err := ss.files.Filepath(ctx, src) + srcFilepath, err := gs.files.Filepath(ctx, src) if err != nil { return nil, false, err } - ss.log.Debug("Got srcFilepath for src", + gs.log.Debug("Got srcFilepath for src", lga.Src, src, lga.Path, srcFilepath) cachedChecksum, ok := mChecksums[srcFilepath] @@ -422,7 +422,7 @@ func (ss *Sources) openCachedFor(ctx context.Context, src *source.Source) (Grip, return nil, false, nil } - backingType, err := ss.files.DriverType(ctx, cacheDBPath) + backingType, err := gs.files.DriverType(ctx, cacheDBPath) if err != nil { return nil, false, err } @@ -433,7 +433,7 @@ func (ss *Sources) openCachedFor(ctx context.Context, src *source.Source) (Grip, Type: backingType, } - backingGrip, err := ss.doOpen(ctx, backingSrc) + backingGrip, err := gs.doOpen(ctx, backingSrc) if err != nil { return nil, false, errz.Wrapf(err, "open cached DB for source %s", src.Handle) } @@ -450,18 +450,18 @@ func (ss *Sources) openCachedFor(ctx context.Context, src *source.Source) (Grip, // location for the join to occur (to minimize copying of data for // the join etc.). Currently the implementation simply delegates // to OpenScratch. -func (ss *Sources) OpenJoin(ctx context.Context, srcs ...*source.Source) (Grip, error) { +func (gs *Grips) OpenJoin(ctx context.Context, srcs ...*source.Source) (Grip, error) { var names []string for _, src := range srcs { names = append(names, src.Handle[1:]) } - ss.log.Debug("OpenJoin", "sources", strings.Join(names, ",")) - return ss.OpenScratch(ctx, srcs[0]) + gs.log.Debug("OpenJoin", "sources", strings.Join(names, ",")) + return gs.OpenScratch(ctx, srcs[0]) } // Close closes d, invoking Close on any instances opened via d.Open. 
-func (ss *Sources) Close() error { - ss.log.Debug("Closing databases(s)...", lga.Count, ss.clnup.Len()) - return ss.clnup.Run() +func (gs *Grips) Close() error { + gs.log.Debug("Closing databases(s)...", lga.Count, gs.clnup.Len()) + return gs.clnup.Run() } diff --git a/libsq/libsq.go b/libsq/libsq.go index 0e233cf69..7f6477af8 100644 --- a/libsq/libsq.go +++ b/libsq/libsq.go @@ -28,7 +28,7 @@ type QueryContext struct { Collection *source.Collection // Sources bridges between source.Source and databases. - Sources *driver.Sources + Sources *driver.Grips // Args defines variables that are substituted into the query. // May be nil or empty. diff --git a/testh/testh.go b/testh/testh.go index c02cb6a11..fe06273a6 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -97,7 +97,7 @@ type Helper struct { registry *driver.Registry files *source.Files - sources *driver.Sources + grips *driver.Grips run *run.Run initOnce sync.Once @@ -173,20 +173,20 @@ func (h *Helper) init() { h.files.AddDriverDetectors(source.DetectMagicNumber) - h.sources = driver.NewSources(log, h.registry, h.files, sqlite3.NewScratchSource) - h.Cleanup.AddC(h.sources) + h.grips = driver.NewGrips(log, h.registry, h.files, sqlite3.NewScratchSource) + h.Cleanup.AddC(h.grips) h.registry.AddProvider(sqlite3.Type, &sqlite3.Provider{Log: log}) h.registry.AddProvider(postgres.Type, &postgres.Provider{Log: log}) h.registry.AddProvider(sqlserver.Type, &sqlserver.Provider{Log: log}) h.registry.AddProvider(mysql.Type, &mysql.Provider{Log: log}) - csvp := &csv.Provider{Log: log, Ingester: h.sources, Files: h.files} + csvp := &csv.Provider{Log: log, Ingester: h.grips, Files: h.files} h.registry.AddProvider(csv.TypeCSV, csvp) h.registry.AddProvider(csv.TypeTSV, csvp) h.files.AddDriverDetectors(csv.DetectCSV, csv.DetectTSV) - jsonp := &json.Provider{Log: log, Ingester: h.sources, Files: h.files} + jsonp := &json.Provider{Log: log, Ingester: h.grips, Files: h.files} h.registry.AddProvider(json.TypeJSON, jsonp) 
h.registry.AddProvider(json.TypeJSONA, jsonp) h.registry.AddProvider(json.TypeJSONL, jsonp) @@ -196,7 +196,7 @@ func (h *Helper) init() { json.DetectJSONL(driver.OptIngestSampleSize.Get(nil)), ) - h.registry.AddProvider(xlsx.Type, &xlsx.Provider{Log: log, Ingester: h.sources, Files: h.files}) + h.registry.AddProvider(xlsx.Type, &xlsx.Provider{Log: log, Ingester: h.grips, Files: h.files}) h.files.AddDriverDetectors(xlsx.DetectXLSX) h.addUserDrivers() @@ -627,7 +627,7 @@ func (h *Helper) QuerySLQ(query string, args map[string]string) (*RecordSink, er qc := &libsq.QueryContext{ Collection: h.coll, - Sources: h.sources, + Sources: h.grips, Args: args, } @@ -735,7 +735,7 @@ func (h *Helper) addUserDrivers() { Log: h.Log, DriverDef: userDriverDef, ImportFn: importFn, - Ingester: h.sources, + Ingester: h.grips, Files: h.files, } @@ -749,10 +749,10 @@ func (h *Helper) IsMonotable(src *source.Source) bool { return h.DriverFor(src).DriverMetadata().Monotable } -// Sources returns the helper's driver.Sources instance. -func (h *Helper) Sources() *driver.Sources { +// Sources returns the helper's driver.Grips instance. +func (h *Helper) Sources() *driver.Grips { h.init() - return h.sources + return h.grips } // Files returns the helper's Files instance. 
From 478676cbbe11d59decb31a76be0bd1f3080ae5f1 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 3 Dec 2023 23:26:21 -0700 Subject: [PATCH 052/195] Refactor: moved driver.Driver.Truncate -> driver.SQLDriver.Truncate --- cli/cmd_tbl.go | 7 +- drivers/csv/csv.go | 5 -- drivers/json/json.go | 5 -- drivers/userdriver/userdriver.go | 5 -- drivers/xlsx/xlsx.go | 10 --- libsq/driver/driver.go | 142 ++----------------------------- libsq/driver/opts.go | 137 +++++++++++++++++++++++++++++ testh/testh.go | 5 +- 8 files changed, 153 insertions(+), 163 deletions(-) create mode 100644 libsq/driver/opts.go diff --git a/cli/cmd_tbl.go b/cli/cmd_tbl.go index e1f012652..3ca605507 100644 --- a/cli/cmd_tbl.go +++ b/cli/cmd_tbl.go @@ -199,7 +199,12 @@ func execTblTruncate(cmd *cobra.Command, args []string) (err error) { for _, tblH := range tblHandles { var affected int64 - affected, err = tblH.drvr.Truncate(cmd.Context(), tblH.src, tblH.tbl, true) + if !tblH.drvr.DriverMetadata().IsSQL { + return errz.Errorf("driver {%s} for source %s doesn't support truncate", + tblH.drvr.DriverMetadata().Type, tblH.src.Handle) + } + + affected, err = tblH.drvr.(driver.SQLDriver).Truncate(cmd.Context(), tblH.src, tblH.tbl, true) if err != nil { return err } diff --git a/drivers/csv/csv.go b/drivers/csv/csv.go index 8f2a1b270..0b985e862 100644 --- a/drivers/csv/csv.go +++ b/drivers/csv/csv.go @@ -92,11 +92,6 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, er return g, nil } -// Truncate implements driver.Driver. -func (d *driveri) Truncate(_ context.Context, _ *source.Source, _ string, _ bool) (int64, error) { - return 0, errz.Errorf("truncate not supported for %s", d.DriverMetadata().Type) -} - // ValidateSource implements driver.Driver. 
func (d *driveri) ValidateSource(src *source.Source) (*source.Source, error) { if src.Type != d.typ { diff --git a/drivers/json/json.go b/drivers/json/json.go index 0b1971877..7db4decf2 100644 --- a/drivers/json/json.go +++ b/drivers/json/json.go @@ -124,11 +124,6 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, er return g, nil } -// Truncate implements driver.Driver. -func (d *driveri) Truncate(_ context.Context, _ *source.Source, _ string, _ bool) (int64, error) { - return 0, errz.Errorf("truncate not supported for %s", d.DriverMetadata().Type) -} - // ValidateSource implements driver.Driver. func (d *driveri) ValidateSource(src *source.Source) (*source.Source, error) { if src.Type != d.typ { diff --git a/drivers/userdriver/userdriver.go b/drivers/userdriver/userdriver.go index 7128b4487..3b029f93d 100644 --- a/drivers/userdriver/userdriver.go +++ b/drivers/userdriver/userdriver.go @@ -108,11 +108,6 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, er return g, nil } -// Truncate implements driver.Driver. -func (d *driveri) Truncate(_ context.Context, _ *source.Source, _ string, _ bool) (int64, error) { - return 0, errz.Errorf("truncate not supported for %s", d.DriverMetadata().Type) -} - // ValidateSource implements driver.Driver. func (d *driveri) ValidateSource(src *source.Source) (*source.Source, error) { d.log.Debug("Validating source", lga.Src, src) diff --git a/drivers/xlsx/xlsx.go b/drivers/xlsx/xlsx.go index 59666ac73..180fd7865 100644 --- a/drivers/xlsx/xlsx.go +++ b/drivers/xlsx/xlsx.go @@ -99,16 +99,6 @@ func (d *Driver) Open(ctx context.Context, src *source.Source) (driver.Grip, err return p, nil } -// Truncate implements driver.Driver. -func (d *Driver) Truncate(_ context.Context, src *source.Source, _ string, _ bool) (affected int64, err error) { - // NOTE: We could actually implement Truncate for xlsx. 
- // It would just mean deleting the rows from a sheet, and then - // saving the sheet. But that's probably not a game we want to - // get into, as sq doesn't currently make edits to any non-SQL - // source types. - return 0, errz.Errorf("driver type {%s} (%s) doesn't support dropping tables", Type, src.Handle) -} - // ValidateSource implements driver.Driver. func (d *Driver) ValidateSource(src *source.Source) (*source.Source, error) { d.log.Debug("Validating source", lga.Src, src) diff --git a/libsq/driver/driver.go b/libsq/driver/driver.go index e3cf69910..0bc1f5a56 100644 --- a/libsq/driver/driver.go +++ b/libsq/driver/driver.go @@ -3,7 +3,6 @@ package driver import ( "context" "database/sql" - "time" "github.com/neilotoole/sq/libsq/ast/render" "github.com/neilotoole/sq/libsq/core/errz" @@ -22,133 +21,6 @@ import ( "github.com/neilotoole/sq/libsq/source/metadata" ) -// ConfigureDB configures DB using o. It is no-op if o is nil. -func ConfigureDB(ctx context.Context, db *sql.DB, o options.Options) { - o2 := options.Effective(o, OptConnMaxOpen, OptConnMaxIdle, OptConnMaxIdleTime, OptConnMaxLifetime) - - lg.FromContext(ctx).Debug("Setting config on DB conn", "config", o2) - - db.SetMaxOpenConns(OptConnMaxOpen.Get(o2)) - db.SetMaxIdleConns(OptConnMaxIdle.Get(o2)) - db.SetConnMaxIdleTime(OptConnMaxIdleTime.Get(o2)) - db.SetConnMaxLifetime(OptConnMaxLifetime.Get(o2)) -} - -var ( - // OptConnMaxOpen controls sql.DB.SetMaxOpenConn. - OptConnMaxOpen = options.NewInt( - "conn.max-open", - "", - 0, - 0, - "Max open connections to DB", - `Maximum number of open connections to the database. -A value of zero indicates no limit.`, - options.TagSource, - options.TagSQL, - ) - - // OptConnMaxIdle controls sql.DB.SetMaxIdleConns. - OptConnMaxIdle = options.NewInt( - "conn.max-idle", - "", - 0, - 2, - "Max connections in idle connection pool", - `Set the maximum number of connections in the idle connection pool. 
-If conn.max-open is greater than 0 but less than the new conn.max-idle, -then the new conn.max-idle will be reduced to match the conn.max-open limit. -If n <= 0, no idle connections are retained.`, - options.TagSource, - options.TagSQL, - ) - - // OptConnMaxIdleTime controls sql.DB.SetConnMaxIdleTime. - OptConnMaxIdleTime = options.NewDuration( - "conn.max-idle-time", - "", - 0, - time.Second*2, - "Max connection idle time", - `Sets the maximum amount of time a connection may be idle. -Expired connections may be closed lazily before reuse. If n <= 0, -connections are not closed due to a connection's idle time.`, - options.TagSource, - options.TagSQL, - ) - - // OptConnMaxLifetime controls sql.DB.SetConnMaxLifetime. - OptConnMaxLifetime = options.NewDuration( - "conn.max-lifetime", - "", - 0, - time.Minute*10, - "Max connection lifetime", - `Set the maximum amount of time a connection may be reused. -Expired connections may be closed lazily before reuse. -If n <= 0, connections are not closed due to a connection's age.`, - options.TagSource, - options.TagSQL, - ) - - // OptConnOpenTimeout controls connection open timeout. - OptConnOpenTimeout = options.NewDuration( - "conn.open-timeout", - "", - 0, - time.Second*5, - "Connection open timeout", - "Max time to wait before a connection open timeout occurs.", - options.TagSource, - options.TagSQL, - ) - - // OptMaxRetryInterval is the maximum interval to wait - // between retries. - OptMaxRetryInterval = options.NewDuration( - "retry.max-interval", - "", - 0, - time.Second*3, - "Max interval between retries", - `The maximum interval to wait between retries. -If an operation is retryable (for example, if the DB has too many clients), -repeated retry operations back off, typically using a Fibonacci backoff.`, - options.TagSource, - ) - - // OptTuningErrgroupLimit controls the maximum number of goroutines that can be spawned - // by an errgroup. 
- OptTuningErrgroupLimit = options.NewInt( - "tuning.errgroup-limit", - "", - 0, - 16, - "Max goroutines in any one errgroup", - `Controls the maximum number of goroutines that can be spawned -by an errgroup. Note that this is the limit for any one errgroup, but not a -ceiling on the total number of goroutines spawned, as some errgroups may -themselves start an errgroup. - -This knob is primarily for internal use. Ultimately it should go away -in favor of dynamic errgroup limit setting based on availability -of additional DB conns, etc.`, - options.TagTuning, - ) - - // OptTuningRecChanSize is the size of the buffer chan for record - // insertion/writing. - OptTuningRecChanSize = options.NewInt( - "tuning.record-buffer", - "", - 0, - 1024, - "Size of record buffer", - `Controls the size of the buffer channel for record insertion/writing.`, - options.TagTuning, - ) -) - // Provider is a factory that returns Driver instances. type Provider interface { // DriverFor returns a driver instance for the given type. @@ -172,14 +44,6 @@ type Driver interface { // Ping verifies that the source is reachable, or returns an error if not. // The exact behavior of Ping() is driver-dependent. Ping(ctx context.Context, src *source.Source) error - - // Truncate truncates tbl in src. If arg reset is true, the - // identity counter for tbl should be reset, if supported - // by the driver. Some DB impls may reset the identity - // counter regardless of the val of reset. - // - // TODO: Maybe move Truncate to SQLDriver? - Truncate(ctx context.Context, src *source.Source, tbl string, reset bool) (affected int64, err error) } // SQLDriver is implemented by Driver instances for SQL databases. @@ -282,6 +146,12 @@ type SQLDriver interface { // DropSchema drops the named schema in db. DropSchema(ctx context.Context, db sqlz.DB, schemaName string) error + // Truncate truncates tbl in src. If arg reset is true, the + // identity counter for tbl should be reset, if supported + // by the driver. 
Some DB impls may reset the identity + // counter regardless of the val of reset. + Truncate(ctx context.Context, src *source.Source, tbl string, reset bool) (affected int64, err error) + // TableExists returns true if there's an existing table tbl in db. TableExists(ctx context.Context, db sqlz.DB, tbl string) (bool, error) diff --git a/libsq/driver/opts.go b/libsq/driver/opts.go new file mode 100644 index 000000000..a35b58fd2 --- /dev/null +++ b/libsq/driver/opts.go @@ -0,0 +1,137 @@ +package driver + +import ( + "context" + "database/sql" + "time" + + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/options" +) + +// ConfigureDB configures DB using o. It is no-op if o is nil. +func ConfigureDB(ctx context.Context, db *sql.DB, o options.Options) { + o2 := options.Effective(o, OptConnMaxOpen, OptConnMaxIdle, OptConnMaxIdleTime, OptConnMaxLifetime) + + lg.FromContext(ctx).Debug("Setting config on DB conn", "config", o2) + + db.SetMaxOpenConns(OptConnMaxOpen.Get(o2)) + db.SetMaxIdleConns(OptConnMaxIdle.Get(o2)) + db.SetConnMaxIdleTime(OptConnMaxIdleTime.Get(o2)) + db.SetConnMaxLifetime(OptConnMaxLifetime.Get(o2)) +} + +var ( + // OptConnMaxOpen controls sql.DB.SetMaxOpenConn. + OptConnMaxOpen = options.NewInt( + "conn.max-open", + "", + 0, + 0, + "Max open connections to DB", + `Maximum number of open connections to the database. +A value of zero indicates no limit.`, + options.TagSource, + options.TagSQL, + ) + + // OptConnMaxIdle controls sql.DB.SetMaxIdleConns. + OptConnMaxIdle = options.NewInt( + "conn.max-idle", + "", + 0, + 2, + "Max connections in idle connection pool", + `Set the maximum number of connections in the idle connection pool. +If conn.max-open is greater than 0 but less than the new conn.max-idle, +then the new conn.max-idle will be reduced to match the conn.max-open limit. 
+If n <= 0, no idle connections are retained.`, + options.TagSource, + options.TagSQL, + ) + + // OptConnMaxIdleTime controls sql.DB.SetConnMaxIdleTime. + OptConnMaxIdleTime = options.NewDuration( + "conn.max-idle-time", + "", + 0, + time.Second*2, + "Max connection idle time", + `Sets the maximum amount of time a connection may be idle. +Expired connections may be closed lazily before reuse. If n <= 0, +connections are not closed due to a connection's idle time.`, + options.TagSource, + options.TagSQL, + ) + + // OptConnMaxLifetime controls sql.DB.SetConnMaxLifetime. + OptConnMaxLifetime = options.NewDuration( + "conn.max-lifetime", + "", + 0, + time.Minute*10, + "Max connection lifetime", + `Set the maximum amount of time a connection may be reused. +Expired connections may be closed lazily before reuse. +If n <= 0, connections are not closed due to a connection's age.`, + options.TagSource, + options.TagSQL, + ) + + // OptConnOpenTimeout controls connection open timeout. + OptConnOpenTimeout = options.NewDuration( + "conn.open-timeout", + "", + 0, + time.Second*5, + "Connection open timeout", + "Max time to wait before a connection open timeout occurs.", + options.TagSource, + options.TagSQL, + ) + + // OptMaxRetryInterval is the maximum interval to wait + // between retries. + OptMaxRetryInterval = options.NewDuration( + "retry.max-interval", + "", + 0, + time.Second*3, + "Max interval between retries", + `The maximum interval to wait between retries. +If an operation is retryable (for example, if the DB has too many clients), +repeated retry operations back off, typically using a Fibonacci backoff.`, + options.TagSource, + ) + + // OptTuningErrgroupLimit controls the maximum number of goroutines that can be spawned + // by an errgroup. + OptTuningErrgroupLimit = options.NewInt( + "tuning.errgroup-limit", + "", + 0, + 16, + "Max goroutines in any one errgroup", + `Controls the maximum number of goroutines that can be spawned +by an errgroup. 
Note that this is the limit for any one errgroup, but not a +ceiling on the total number of goroutines spawned, as some errgroups may +themselves start an errgroup. + +This knob is primarily for internal use. Ultimately it should go away +in favor of dynamic errgroup limit setting based on availability +of additional DB conns, etc.`, + options.TagTuning, + ) + + // OptTuningRecChanSize is the size of the buffer chan for record + // insertion/writing. + OptTuningRecChanSize = options.NewInt( + "tuning.record-buffer", + "", + 0, + 1024, + "Size of record buffer", + `Controls the size of the buffer channel for record insertion/writing.`, + options.TagTuning, + ) +) diff --git a/testh/testh.go b/testh/testh.go index fe06273a6..43be46676 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -692,7 +692,10 @@ func (h *Helper) TruncateTable(src *source.Source, tbl string) (affected int64) grip := h.openNew(src) defer lg.WarnIfCloseError(h.Log, lgm.CloseDB, grip) - affected, err := h.DriverFor(src).Truncate(h.Context, src, tbl, true) + drvr := h.SQLDriverFor(src) + require.NotNil(h.T, drvr, "not a SQL driver") + + affected, err := drvr.Truncate(h.Context, src, tbl, true) require.NoError(h.T, err) return affected } From c9c2552337b944708be9c8260096cf719c1d4ac9 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Mon, 4 Dec 2023 00:44:49 -0700 Subject: [PATCH 053/195] Added ctx to RecordWriter methods --- cli/diff/record.go | 20 +++++++------ cli/output/adapter.go | 12 ++++---- cli/output/csvw/csvw.go | 16 +++++++---- cli/output/csvw/csvw_test.go | 8 ++++-- cli/output/htmlw/htmlw.go | 16 +++++++---- cli/output/htmlw/htmlw_test.go | 8 ++++-- cli/output/jsonw/jsonw_test.go | 8 ++++-- cli/output/jsonw/recordwriter.go | 32 ++++++++++++--------- cli/output/markdownw/markdownw.go | 16 +++++++---- cli/output/markdownw/markdownw_test.go | 8 ++++-- cli/output/raww/raww.go | 14 +++++++--- cli/output/raww/raww_test.go | 15 ++++++---- cli/output/tablew/configwriter.go | 19 ++++++------- 
cli/output/tablew/internal/texttable.go | 16 +++++++++-- cli/output/tablew/metadatawriter.go | 37 ++++++++++--------------- cli/output/tablew/recordwriter.go | 20 +++++++------ cli/output/tablew/sourcewriter.go | 16 ++++------- cli/output/tablew/tablew.go | 26 ++++++++++++----- cli/output/writers.go | 9 +++--- cli/output/xlsxw/xlsxw.go | 14 +++++++--- cli/output/xlsxw/xlsxw_test.go | 15 ++++++---- cli/output/xmlw/xmlw.go | 17 ++++++------ cli/output/xmlw/xmlw_test.go | 16 +++++++---- cli/output/yamlw/recordwriter.go | 16 +++++++---- cli/output/yamlw/yamlw_test.go | 8 +++--- testh/record.go | 9 +++--- 26 files changed, 243 insertions(+), 168 deletions(-) diff --git a/cli/diff/record.go b/cli/diff/record.go index 6501fc667..1a88ee5c1 100644 --- a/cli/diff/record.go +++ b/cli/diff/record.go @@ -158,7 +158,7 @@ func findRecordDiff(ctx context.Context, ru *run.Run, lines int, row: i, } - if err = populateRecordDiff(lines, ru.Writers.Printing, recDiff); err != nil { + if err = populateRecordDiff(ctx, lines, ru.Writers.Printing, recDiff); err != nil { return nil, err } @@ -166,7 +166,7 @@ func findRecordDiff(ctx context.Context, ru *run.Run, lines int, } //nolint:unused -func populateRecordDiff(lines int, pr *output.Printing, recDiff *recordDiff) error { +func populateRecordDiff(ctx context.Context, lines int, pr *output.Printing, recDiff *recordDiff) error { pr = pr.Clone() pr.EnableColor(false) @@ -178,10 +178,10 @@ func populateRecordDiff(lines int, pr *output.Printing, recDiff *recordDiff) err err error ) - if body1, err = renderRecord2YAML(pr, recDiff.recMeta1, recDiff.rec1); err != nil { + if body1, err = renderRecord2YAML(ctx, pr, recDiff.recMeta1, recDiff.rec1); err != nil { return err } - if body2, err = renderRecord2YAML(pr, recDiff.recMeta1, recDiff.rec2); err != nil { + if body2, err = renderRecord2YAML(ctx, pr, recDiff.recMeta1, recDiff.rec2); err != nil { return err } @@ -204,23 +204,25 @@ func populateRecordDiff(lines int, pr *output.Printing, recDiff 
*recordDiff) err } //nolint:unused -func renderRecord2YAML(pr *output.Printing, recMeta record.Meta, rec record.Record) (string, error) { +func renderRecord2YAML(ctx context.Context, pr *output.Printing, + recMeta record.Meta, rec record.Record, +) (string, error) { if rec == nil { return "", nil } buf := &bytes.Buffer{} yw := yamlw.NewRecordWriter(buf, pr) - if err := yw.Open(recMeta); err != nil { + if err := yw.Open(ctx, recMeta); err != nil { return "", err } - if err := yw.WriteRecords([]record.Record{rec}); err != nil { + if err := yw.WriteRecords(ctx, []record.Record{rec}); err != nil { return "", err } - if err := yw.Flush(); err != nil { + if err := yw.Flush(ctx); err != nil { return "", err } - if err := yw.Close(); err != nil { + if err := yw.Close(ctx); err != nil { return "", err } return buf.String(), nil diff --git a/cli/output/adapter.go b/cli/output/adapter.go index a810aa644..fa62f5ebf 100644 --- a/cli/output/adapter.go +++ b/cli/output/adapter.go @@ -63,7 +63,7 @@ func (w *RecordWriterAdapter) Open(ctx context.Context, cancelFn context.CancelF lg.FromContext(ctx).Debug("Open RecordWriterAdapter", "fields", recMeta) - err := w.rw.Open(recMeta) + err := w.rw.Open(ctx, recMeta) if err != nil { return nil, nil, err } @@ -94,12 +94,12 @@ func (w *RecordWriterAdapter) Open(ctx context.Context, cancelFn context.CancelF for { select { case <-ctx.Done(): - w.addErrs(ctx.Err(), w.rw.Close()) + w.addErrs(ctx.Err(), w.rw.Close(ctx)) return case <-flushCh: // The flushTimer has expired, time to flush. 
- err = w.rw.Flush() + err = w.rw.Flush(ctx) if err != nil { w.addErrs(err) return @@ -110,7 +110,7 @@ func (w *RecordWriterAdapter) Open(ctx context.Context, cancelFn context.CancelF case rec := <-w.recCh: if rec == nil { // no more results on recCh, it has been closed - err = w.rw.Close() + err = w.rw.Close(ctx) if err != nil { w.addErrs() } @@ -121,7 +121,7 @@ func (w *RecordWriterAdapter) Open(ctx context.Context, cancelFn context.CancelF // We could accumulate a bunch of recs into a slice here, // but we'll worry about that if benchmarking shows it'll matter. - writeErr := w.rw.WriteRecords([]record.Record{rec}) + writeErr := w.rw.WriteRecords(ctx, []record.Record{rec}) if writeErr != nil { w.addErrs(writeErr) return @@ -131,7 +131,7 @@ func (w *RecordWriterAdapter) Open(ctx context.Context, cancelFn context.CancelF // Check if we should flush if w.FlushAfterN >= 0 && (recN-lastFlushN >= w.FlushAfterN) { - err = w.rw.Flush() + err = w.rw.Flush(ctx) if err != nil { w.addErrs(err) return diff --git a/cli/output/csvw/csvw.go b/cli/output/csvw/csvw.go index a47c52d82..636542887 100644 --- a/cli/output/csvw/csvw.go +++ b/cli/output/csvw/csvw.go @@ -2,6 +2,7 @@ package csvw import ( + "context" "encoding/csv" "fmt" "io" @@ -63,25 +64,25 @@ func (w *RecordWriter) SetComma(c rune) { } // Open implements output.RecordWriter. -func (w *RecordWriter) Open(recMeta record.Meta) error { +func (w *RecordWriter) Open(_ context.Context, recMeta record.Meta) error { w.recMeta = recMeta return nil } // Flush implements output.RecordWriter. -func (w *RecordWriter) Flush() error { +func (w *RecordWriter) Flush(context.Context) error { w.cw.Flush() return nil } // Close implements output.RecordWriter. -func (w *RecordWriter) Close() error { +func (w *RecordWriter) Close(context.Context) error { w.cw.Flush() return w.cw.Error() } // WriteRecords implements output.RecordWriter. 
-func (w *RecordWriter) WriteRecords(recs []record.Record) error { +func (w *RecordWriter) WriteRecords(ctx context.Context, recs []record.Record) error { w.mu.Lock() defer w.mu.Unlock() @@ -98,6 +99,11 @@ func (w *RecordWriter) WriteRecords(recs []record.Record) error { } for _, rec := range recs { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } fields := make([]string, len(rec)) for i, val := range rec { @@ -142,5 +148,5 @@ func (w *RecordWriter) WriteRecords(recs []record.Record) error { } w.cw.Flush() - return w.cw.Error() + return errz.Err(w.cw.Error()) } diff --git a/cli/output/csvw/csvw_test.go b/cli/output/csvw/csvw_test.go index ba219623a..79c0fbbc8 100644 --- a/cli/output/csvw/csvw_test.go +++ b/cli/output/csvw/csvw_test.go @@ -2,6 +2,7 @@ package csvw_test import ( "bytes" + "context" "testing" "time" @@ -15,6 +16,7 @@ import ( ) func TestDateTimeHandling(t *testing.T) { + ctx := context.Background() var ( colNames = []string{"col_datetime", "col_date", "col_time"} kinds = []kind.Kind{kind.Datetime, kind.Date, kind.Time} @@ -30,11 +32,11 @@ func TestDateTimeHandling(t *testing.T) { pr.EnableColor(false) w := csvw.NewTabRecordWriter(buf, pr) - require.NoError(t, w.Open(recMeta)) + require.NoError(t, w.Open(ctx, recMeta)) rec := record.Record{when, when, when} - require.NoError(t, w.WriteRecords([]record.Record{rec})) - require.NoError(t, w.Close()) + require.NoError(t, w.WriteRecords(ctx, []record.Record{rec})) + require.NoError(t, w.Close(ctx)) require.Equal(t, want, buf.String()) t.Log(buf.String()) diff --git a/cli/output/htmlw/htmlw.go b/cli/output/htmlw/htmlw.go index 47b4e7562..783d45675 100644 --- a/cli/output/htmlw/htmlw.go +++ b/cli/output/htmlw/htmlw.go @@ -3,6 +3,7 @@ package htmlw import ( "bytes" + "context" "encoding/base64" "fmt" "html" @@ -37,7 +38,7 @@ func NewRecordWriter(out io.Writer, pr *output.Printing) output.RecordWriter { } // Open implements output.RecordWriter. 
-func (w *recordWriter) Open(recMeta record.Meta) error { +func (w *recordWriter) Open(_ context.Context, recMeta record.Meta) error { w.recMeta = recMeta w.buf = &bytes.Buffer{} @@ -69,7 +70,7 @@ func (w *recordWriter) Open(recMeta record.Meta) error { } // Flush implements output.RecordWriter. -func (w *recordWriter) Flush() error { +func (w *recordWriter) Flush(context.Context) error { w.mu.Lock() defer w.mu.Unlock() _, err := w.buf.WriteTo(w.out) // resets buf @@ -77,8 +78,8 @@ func (w *recordWriter) Flush() error { } // Close implements output.RecordWriter. -func (w *recordWriter) Close() error { - err := w.Flush() +func (w *recordWriter) Close(ctx context.Context) error { + err := w.Flush(ctx) if err != nil { return err } @@ -140,11 +141,16 @@ func (w *recordWriter) writeRecord(rec record.Record) error { } // WriteRecords implements output.RecordWriter. -func (w *recordWriter) WriteRecords(recs []record.Record) error { +func (w *recordWriter) WriteRecords(ctx context.Context, recs []record.Record) error { w.mu.Lock() defer w.mu.Unlock() var err error for _, rec := range recs { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } err = w.writeRecord(rec) if err != nil { return err diff --git a/cli/output/htmlw/htmlw_test.go b/cli/output/htmlw/htmlw_test.go index d27dfc09d..a71993366 100644 --- a/cli/output/htmlw/htmlw_test.go +++ b/cli/output/htmlw/htmlw_test.go @@ -2,6 +2,7 @@ package htmlw_test import ( "bytes" + "context" "os" "testing" @@ -27,16 +28,17 @@ func TestRecordWriter(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() recMeta, recs := testh.RecordsFromTbl(t, sakila.SL3, sakila.TblActor) recs = recs[0:tc.numRecs] buf := &bytes.Buffer{} pr := output.NewPrinting() w := htmlw.NewRecordWriter(buf, pr) - require.NoError(t, w.Open(recMeta)) + require.NoError(t, w.Open(ctx, recMeta)) - require.NoError(t, w.WriteRecords(recs)) - require.NoError(t, w.Close()) + require.NoError(t, w.WriteRecords(ctx, 
recs)) + require.NoError(t, w.Close(ctx)) want, err := os.ReadFile(tc.fixtPath) require.NoError(t, err) diff --git a/cli/output/jsonw/jsonw_test.go b/cli/output/jsonw/jsonw_test.go index 16fb1e2df..79ef9f12f 100644 --- a/cli/output/jsonw/jsonw_test.go +++ b/cli/output/jsonw/jsonw_test.go @@ -3,6 +3,7 @@ package jsonw_test import ( "bytes" + "context" "encoding/json" "io" "strings" @@ -139,6 +140,7 @@ func TestRecordWriters(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() colNames, kinds := fixt.ColNamePerKind(false, false, false) recMeta := testh.NewRecordMeta(colNames, kinds) @@ -164,9 +166,9 @@ func TestRecordWriters(t *testing.T) { w := tc.factoryFn(buf, pr) - require.NoError(t, w.Open(recMeta)) - require.NoError(t, w.WriteRecords(recs)) - require.NoError(t, w.Close()) + require.NoError(t, w.Open(ctx, recMeta)) + require.NoError(t, w.WriteRecords(ctx, recs)) + require.NoError(t, w.Close(ctx)) require.Equal(t, tc.want, buf.String()) if !tc.multiline { diff --git a/cli/output/jsonw/recordwriter.go b/cli/output/jsonw/recordwriter.go index 49d676f59..43e01ee76 100644 --- a/cli/output/jsonw/recordwriter.go +++ b/cli/output/jsonw/recordwriter.go @@ -2,6 +2,7 @@ package jsonw import ( "bytes" + "context" "io" "strings" "sync" @@ -60,7 +61,7 @@ type stdWriter struct { } // Open implements output.RecordWriter. -func (w *stdWriter) Open(recMeta record.Meta) error { +func (w *stdWriter) Open(_ context.Context, recMeta record.Meta) error { if w.err != nil { return w.err } @@ -92,7 +93,7 @@ func (w *stdWriter) Open(recMeta record.Meta) error { } // WriteRecords implements output.RecordWriter. -func (w *stdWriter) WriteRecords(recs []record.Record) error { +func (w *stdWriter) WriteRecords(_ context.Context, recs []record.Record) error { w.mu.Lock() defer w.mu.Unlock() if w.err != nil { @@ -141,7 +142,7 @@ func (w *stdWriter) writeRecord(rec record.Record) error { } // Flush implements output.RecordWriter. 
-func (w *stdWriter) Flush() error { +func (w *stdWriter) Flush(context.Context) error { w.mu.Lock() defer w.mu.Unlock() return w.doFlush() @@ -158,7 +159,7 @@ func (w *stdWriter) doFlush() error { } // Close implements output.RecordWriter. -func (w *stdWriter) Close() error { +func (w *stdWriter) Close(ctx context.Context) error { if w.err != nil { return w.err } @@ -168,7 +169,7 @@ func (w *stdWriter) Close() error { } w.outBuf.Write(w.tpl.footer) - return w.Flush() + return w.Flush(ctx) } // stdTemplate holds the various parts of the output template @@ -314,7 +315,7 @@ type lineRecordWriter struct { } // Open implements output.RecordWriter. -func (w *lineRecordWriter) Open(recMeta record.Meta) error { +func (w *lineRecordWriter) Open(_ context.Context, recMeta record.Meta) error { if w.err != nil { return w.err } @@ -331,7 +332,7 @@ func (w *lineRecordWriter) Open(recMeta record.Meta) error { } // WriteRecords implements output.RecordWriter. -func (w *lineRecordWriter) WriteRecords(recs []record.Record) error { +func (w *lineRecordWriter) WriteRecords(ctx context.Context, recs []record.Record) error { w.mu.Lock() defer w.mu.Unlock() if w.err != nil { @@ -340,7 +341,12 @@ func (w *lineRecordWriter) WriteRecords(recs []record.Record) error { var err error for i := range recs { - err = w.writeRecord(recs[i]) + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + err = w.writeRecord(ctx, recs[i]) if err != nil { w.err = err return err @@ -349,7 +355,7 @@ func (w *lineRecordWriter) WriteRecords(recs []record.Record) error { return nil } -func (w *lineRecordWriter) writeRecord(rec record.Record) error { +func (w *lineRecordWriter) writeRecord(ctx context.Context, rec record.Record) error { var err error b := make([]byte, 0, 10) @@ -365,14 +371,14 @@ func (w *lineRecordWriter) writeRecord(rec record.Record) error { w.outBuf.Write(b) if w.outBuf.Len() > w.pr.FlushThreshold { - return w.Flush() + return w.Flush(ctx) } return nil } // Flush implements 
output.RecordWriter. -func (w *lineRecordWriter) Flush() error { +func (w *lineRecordWriter) Flush(context.Context) error { if w.err != nil { return w.err } @@ -384,12 +390,12 @@ func (w *lineRecordWriter) Flush() error { } // Close implements output.RecordWriter. -func (w *lineRecordWriter) Close() error { +func (w *lineRecordWriter) Close(ctx context.Context) error { if w.err != nil { return w.err } - return w.Flush() + return w.Flush(ctx) } func newJSONObjectsTemplate(recMeta record.Meta, pr *output.Printing) ([][]byte, error) { diff --git a/cli/output/markdownw/markdownw.go b/cli/output/markdownw/markdownw.go index 17231f278..14c9b7100 100644 --- a/cli/output/markdownw/markdownw.go +++ b/cli/output/markdownw/markdownw.go @@ -3,6 +3,7 @@ package markdownw import ( "bytes" + "context" "encoding/base64" "fmt" "html" @@ -37,7 +38,7 @@ func NewRecordWriter(out io.Writer, pr *output.Printing) output.RecordWriter { } // Open implements output.RecordWriter. -func (w *RecordWriter) Open(recMeta record.Meta) error { +func (w *RecordWriter) Open(_ context.Context, recMeta record.Meta) error { w.recMeta = recMeta w.buf = &bytes.Buffer{} @@ -64,14 +65,14 @@ func (w *RecordWriter) Open(recMeta record.Meta) error { } // Flush implements output.RecordWriter. -func (w *RecordWriter) Flush() error { +func (w *RecordWriter) Flush(context.Context) error { _, err := w.buf.WriteTo(w.out) // resets buf return err } // Close implements output.RecordWriter. -func (w *RecordWriter) Close() error { - return w.Flush() +func (w *RecordWriter) Close(ctx context.Context) error { + return w.Flush(ctx) } func (w *RecordWriter) writeRecord(rec record.Record) error { @@ -117,12 +118,17 @@ func (w *RecordWriter) writeRecord(rec record.Record) error { } // WriteRecords implements output.RecordWriter. 
-func (w *RecordWriter) WriteRecords(recs []record.Record) error { +func (w *RecordWriter) WriteRecords(ctx context.Context, recs []record.Record) error { w.mu.Lock() defer w.mu.Unlock() var err error for _, rec := range recs { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } err = w.writeRecord(rec) if err != nil { return err diff --git a/cli/output/markdownw/markdownw_test.go b/cli/output/markdownw/markdownw_test.go index ac1ad7bc0..c3d69ba85 100644 --- a/cli/output/markdownw/markdownw_test.go +++ b/cli/output/markdownw/markdownw_test.go @@ -2,6 +2,7 @@ package markdownw_test import ( "bytes" + "context" "testing" "github.com/stretchr/testify/require" @@ -38,15 +39,16 @@ func TestRecordWriter(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() recMeta, recs := testh.RecordsFromTbl(t, sakila.SL3, sakila.TblActor) recs = recs[0:tc.numRecs] buf := &bytes.Buffer{} w := markdownw.NewRecordWriter(buf, output.NewPrinting()) - require.NoError(t, w.Open(recMeta)) + require.NoError(t, w.Open(ctx, recMeta)) - require.NoError(t, w.WriteRecords(recs)) - require.NoError(t, w.Close()) + require.NoError(t, w.WriteRecords(ctx, recs)) + require.NoError(t, w.Close(ctx)) require.Equal(t, tc.want, buf.String()) }) } diff --git a/cli/output/raww/raww.go b/cli/output/raww/raww.go index 7cea6dea5..f4217f22a 100644 --- a/cli/output/raww/raww.go +++ b/cli/output/raww/raww.go @@ -1,6 +1,7 @@ package raww import ( + "context" "fmt" "io" "strconv" @@ -39,13 +40,13 @@ func NewRecordWriter(out io.Writer, pr *output.Printing) output.RecordWriter { } // Open implements output.RecordWriter. -func (w *recordWriter) Open(recMeta record.Meta) error { +func (w *recordWriter) Open(_ context.Context, recMeta record.Meta) error { w.recMeta = recMeta return nil } // WriteRecords implements output.RecordWriter. 
-func (w *recordWriter) WriteRecords(recs []record.Record) error { +func (w *recordWriter) WriteRecords(ctx context.Context, recs []record.Record) error { w.mu.Lock() defer w.mu.Unlock() @@ -54,6 +55,11 @@ func (w *recordWriter) WriteRecords(recs []record.Record) error { } for _, rec := range recs { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } for i, val := range rec { switch val := val.(type) { case nil: @@ -89,11 +95,11 @@ func (w *recordWriter) WriteRecords(recs []record.Record) error { } // Flush implements output.RecordWriter. -func (w *recordWriter) Flush() error { +func (w *recordWriter) Flush(context.Context) error { return nil } // Close implements output.RecordWriter. -func (w *recordWriter) Close() error { +func (w *recordWriter) Close(context.Context) error { return nil } diff --git a/cli/output/raww/raww_test.go b/cli/output/raww/raww_test.go index 77ad2f559..676606788 100644 --- a/cli/output/raww/raww_test.go +++ b/cli/output/raww/raww_test.go @@ -2,6 +2,7 @@ package raww_test import ( "bytes" + "context" "image/gif" "testing" @@ -34,14 +35,16 @@ func TestRecordWriter_TblActor(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + recMeta, recs := testh.RecordsFromTbl(t, sakila.SL3, sakila.TblActor) recs = recs[0:tc.numRecs] buf := &bytes.Buffer{} w := raww.NewRecordWriter(buf, output.NewPrinting()) - require.NoError(t, w.Open(recMeta)) - require.NoError(t, w.WriteRecords(recs)) - require.NoError(t, w.Close()) + require.NoError(t, w.Open(ctx, recMeta)) + require.NoError(t, w.WriteRecords(ctx, recs)) + require.NoError(t, w.Close(ctx)) require.Equal(t, tc.want, buf.Bytes()) }) } @@ -59,9 +62,9 @@ func TestRecordWriter_TblBytes(t *testing.T) { buf := &bytes.Buffer{} w := raww.NewRecordWriter(buf, output.NewPrinting()) - require.NoError(t, w.Open(sink.RecMeta)) - require.NoError(t, w.WriteRecords(sink.Recs)) - require.NoError(t, w.Close()) + require.NoError(t, w.Open(th.Context, sink.RecMeta)) + 
require.NoError(t, w.WriteRecords(th.Context, sink.Recs)) + require.NoError(t, w.Close(th.Context)) require.Equal(t, fBytes, buf.Bytes()) _, err = gif.Decode(bytes.NewReader(buf.Bytes())) diff --git a/cli/output/tablew/configwriter.go b/cli/output/tablew/configwriter.go index 6dd7ae3bd..2035b164c 100644 --- a/cli/output/tablew/configwriter.go +++ b/cli/output/tablew/configwriter.go @@ -1,6 +1,7 @@ package tablew import ( + "context" "fmt" "io" @@ -96,16 +97,15 @@ func (w *configWriter) Options(reg *options.Registry, o options.Options) error { w.tbl.pr.ShowHeader = false } - w.doPrintOptions(reg, o, true) - return nil + return w.doPrintOptions(reg, o, true) } // Options implements output.ConfigWriter. // If printUnset is true and we're in verbose mode, unset options // are also printed. -func (w *configWriter) doPrintOptions(reg *options.Registry, o options.Options, printUnset bool) { +func (w *configWriter) doPrintOptions(reg *options.Registry, o options.Options, printUnset bool) error { if o == nil { - return + return nil } t, pr, verbose := w.tbl.tblImpl, w.tbl.pr, w.tbl.pr.Verbose @@ -166,8 +166,7 @@ func (w *configWriter) doPrintOptions(reg *options.Registry, o options.Options, } if !printUnset || !verbose { - w.tbl.appendRowsAndRenderAll(rows) - return + return w.tbl.appendRowsAndRenderAll(context.TODO(), rows) } // Also print the unset opts @@ -186,7 +185,7 @@ func (w *configWriter) doPrintOptions(reg *options.Registry, o options.Options, rows = append(rows, row) } - w.tbl.appendRowsAndRenderAll(rows) + return w.tbl.appendRowsAndRenderAll(context.TODO(), rows) } // SetOption implements output.ConfigWriter. @@ -203,8 +202,7 @@ func (w *configWriter) SetOption(o options.Options, opt options.Opt) error { // It's verbose o = options.Effective(o, opt) w.tbl.pr.ShowHeader = true - w.doPrintOptions(reg2, o, false) - return nil + return w.doPrintOptions(reg2, o, false) } // UnsetOption implements output.ConfigWriter. 
@@ -218,8 +216,7 @@ func (w *configWriter) UnsetOption(opt options.Opt) error { reg.Add(opt) o := options.Options{} - w.doPrintOptions(reg, o, true) - return nil + return w.doPrintOptions(reg, o, true) } func getOptColor(pr *output.Printing, opt options.Opt) *color.Color { diff --git a/cli/output/tablew/internal/texttable.go b/cli/output/tablew/internal/texttable.go index 01fc2f513..3d822b454 100644 --- a/cli/output/tablew/internal/texttable.go +++ b/cli/output/tablew/internal/texttable.go @@ -10,6 +10,7 @@ package internal import ( "bytes" + "context" "fmt" "io" "regexp" @@ -164,17 +165,20 @@ func (t *Table) getColTrans(col int) textTransFunc { } // RenderAll table output -func (t *Table) RenderAll() { +func (t *Table) RenderAll(ctx context.Context) error { if t.borders.Top { t.printLine(true) } t.printHeading() - t.printRows() + if err := t.printRows(ctx); err != nil { + return err + } if !t.rowLine && t.borders.Bottom { t.printLine(true) } t.printFooter() + return nil } // SetHeader sets table header @@ -461,10 +465,16 @@ func (t *Table) printFooter() { fmt.Fprintln(t.out) } -func (t *Table) printRows() { +func (t *Table) printRows(ctx context.Context) error { for i, lines := range t.lines { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } t.printRow(lines, i) } + return nil } // Print Row Information diff --git a/cli/output/tablew/metadatawriter.go b/cli/output/tablew/metadatawriter.go index e412f6770..46b948e10 100644 --- a/cli/output/tablew/metadatawriter.go +++ b/cli/output/tablew/metadatawriter.go @@ -2,6 +2,7 @@ package tablew import ( "cmp" + "context" "fmt" "io" "slices" @@ -45,8 +46,7 @@ func (w *mdWriter) DriverMetadata(drvrs []driver.Metadata) error { row := []string{string(md.Type), md.Description, strconv.FormatBool(md.UserDefined), md.Doc} rows = append(rows, row) } - w.tbl.appendRowsAndRenderAll(rows) - return nil + return w.tbl.appendRowsAndRenderAll(context.TODO(), rows) } // TableMetadata implements output.MetadataWriter. 
@@ -86,8 +86,7 @@ func (w *mdWriter) doTableMeta(md *metadata.Table) error { } rows = append(rows, row) - w.tbl.appendRowsAndRenderAll(rows) - return nil + return w.tbl.appendRowsAndRenderAll(context.TODO(), rows) } func (w *mdWriter) doTableMetaVerbose(tblMeta *metadata.Table) error { @@ -129,8 +128,7 @@ func (w *mdWriter) doSourceMetaNoSchema(md *metadata.Source) error { } w.tbl.tblImpl.SetHeader(headers) - w.tbl.renderRow(row) - return nil + return w.tbl.renderRow(context.TODO(), row) } func (w *mdWriter) printTablesVerbose(tbls []*metadata.Table) error { @@ -192,8 +190,7 @@ func (w *mdWriter) printTablesVerbose(tbls []*metadata.Table) error { } } - w.tbl.appendRowsAndRenderAll(rows) - return nil + return w.tbl.appendRowsAndRenderAll(context.TODO(), rows) } func (w *mdWriter) printTables(tables []*metadata.Table) error { @@ -226,8 +223,7 @@ func (w *mdWriter) printTables(tables []*metadata.Table) error { rows = append(rows, row) } - w.tbl.appendRowsAndRenderAll(rows) - return nil + return w.tbl.appendRowsAndRenderAll(context.TODO(), rows) } func (w *mdWriter) doSourceMetaFull(md *metadata.Source) error { @@ -265,7 +261,9 @@ func (w *mdWriter) doSourceMetaFull(md *metadata.Source) error { } w.tbl.tblImpl.SetHeader(headers) - w.tbl.renderRow(row) + if err := w.tbl.renderRow(context.TODO(), row); err != nil { + return err + } if len(md.Tables) == 0 { return nil @@ -355,8 +353,7 @@ func (w *mdWriter) DBProperties(props map[string]any) error { rows = append(rows, row) } - w.tbl.appendRowsAndRenderAll(rows) - return nil + return w.tbl.appendRowsAndRenderAll(context.TODO(), rows) } // Catalogs implements output.MetadataWriter. 
@@ -380,8 +377,7 @@ func (w *mdWriter) Catalogs(currentCatalog string, catalogs []string) error { } rows = append(rows, []string{catalog}) } - w.tbl.appendRowsAndRenderAll(rows) - return nil + return w.tbl.appendRowsAndRenderAll(context.TODO(), rows) } // Verbose mode @@ -402,9 +398,7 @@ func (w *mdWriter) Catalogs(currentCatalog string, catalogs []string) error { } rows = append(rows, []string{catalog, active}) } - w.tbl.appendRowsAndRenderAll(rows) - - return nil + return w.tbl.appendRowsAndRenderAll(context.TODO(), rows) } // Schemata implements output.MetadataWriter. @@ -427,8 +421,7 @@ func (w *mdWriter) Schemata(currentSchema string, schemas []*metadata.Schema) er } rows = append(rows, []string{s}) } - w.tbl.appendRowsAndRenderAll(rows) - return nil + return w.tbl.appendRowsAndRenderAll(context.TODO(), rows) } // Verbose mode @@ -452,7 +445,5 @@ func (w *mdWriter) Schemata(currentSchema string, schemas []*metadata.Schema) er } rows = append(rows, row) } - w.tbl.appendRowsAndRenderAll(rows) - - return nil + return w.tbl.appendRowsAndRenderAll(context.TODO(), rows) } diff --git a/cli/output/tablew/recordwriter.go b/cli/output/tablew/recordwriter.go index 7df06d11d..a468ce509 100644 --- a/cli/output/tablew/recordwriter.go +++ b/cli/output/tablew/recordwriter.go @@ -1,6 +1,7 @@ package tablew import ( + "context" "io" "sync" @@ -24,18 +25,18 @@ func NewRecordWriter(out io.Writer, pr *output.Printing) output.RecordWriter { } // Open implements output.RecordWriter. -func (w *recordWriter) Open(recMeta record.Meta) error { +func (w *recordWriter) Open(_ context.Context, recMeta record.Meta) error { w.recMeta = recMeta return nil } // Flush implements output.RecordWriter. -func (w *recordWriter) Flush() error { +func (w *recordWriter) Flush(context.Context) error { return nil } // Close implements output.RecordWriter. 
-func (w *recordWriter) Close() error { +func (w *recordWriter) Close(ctx context.Context) error { if w.rowCount == 0 { // no data to write return nil @@ -45,18 +46,22 @@ func (w *recordWriter) Close() error { header := w.recMeta.MungedNames() w.tbl.tblImpl.SetHeader(header) - w.tbl.renderAll() - return nil + return w.tbl.renderAll(ctx) } // WriteRecords implements output.RecordWriter. -func (w *recordWriter) WriteRecords(recs []record.Record) error { +func (w *recordWriter) WriteRecords(ctx context.Context, recs []record.Record) error { w.mu.Lock() defer w.mu.Unlock() kinds := w.recMeta.Kinds() var tblRows [][]string for _, rec := range recs { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } tblRow := make([]string, len(rec)) for i, val := range rec { @@ -67,6 +72,5 @@ func (w *recordWriter) WriteRecords(recs []record.Record) error { w.rowCount++ } - w.tbl.appendRows(tblRows) - return nil + return w.tbl.appendRows(ctx, tblRows) } diff --git a/cli/output/tablew/sourcewriter.go b/cli/output/tablew/sourcewriter.go index d29906893..dc978f638 100644 --- a/cli/output/tablew/sourcewriter.go +++ b/cli/output/tablew/sourcewriter.go @@ -1,6 +1,7 @@ package tablew import ( + "context" "io" "strconv" "strings" @@ -54,8 +55,7 @@ func (w *sourceWriter) Collection(coll *source.Collection) error { w.tbl.tblImpl.SetHeaderDisable(true) w.tbl.tblImpl.SetColTrans(0, pr.Handle.SprintFunc()) w.tbl.tblImpl.SetColTrans(2, pr.Location.SprintFunc()) - w.tbl.appendRowsAndRenderAll(rows) - return nil + return w.tbl.appendRowsAndRenderAll(context.TODO(), rows) } // Else print verbose @@ -83,8 +83,7 @@ func (w *sourceWriter) Collection(coll *source.Collection) error { w.tbl.tblImpl.SetColTrans(0, pr.Handle.SprintFunc()) w.tbl.tblImpl.SetColTrans(3, pr.Location.SprintFunc()) w.tbl.tblImpl.SetHeader([]string{"HANDLE", "ACTIVE", "DRIVER", "LOCATION", "OPTIONS"}) - w.tbl.appendRowsAndRenderAll(rows) - return nil + return w.tbl.appendRowsAndRenderAll(context.TODO(), rows) } // 
Added implements output.SourceWriter. @@ -144,8 +143,7 @@ func (w *sourceWriter) doSource(coll *source.Collection, src *source.Source) err w.tbl.tblImpl.SetColTrans(2, w.tbl.pr.Location.SprintFunc()) w.tbl.tblImpl.SetHeaderDisable(true) - w.tbl.appendRowsAndRenderAll(rows) - return nil + return w.tbl.appendRowsAndRenderAll(context.TODO(), rows) } var rows [][]string @@ -165,8 +163,7 @@ func (w *sourceWriter) doSource(coll *source.Collection, src *source.Source) err w.tbl.tblImpl.SetColTrans(2, w.tbl.pr.Location.SprintFunc()) w.tbl.tblImpl.SetHeaderDisable(true) - w.tbl.appendRowsAndRenderAll(rows) - return nil + return w.tbl.appendRowsAndRenderAll(context.TODO(), rows) } // Removed implements output.SourceWriter. @@ -292,8 +289,7 @@ func (w *sourceWriter) renderGroups(groups []*source.Group) error { w.tbl.tblImpl.SetColTrans(4, pr.Number.SprintFunc()) w.tbl.tblImpl.SetColTrans(5, pr.Bool.SprintFunc()) - w.tbl.appendRowsAndRenderAll(rows) - return nil + return w.tbl.appendRowsAndRenderAll(context.TODO(), rows) } // Groups implements output.SourceWriter. 
diff --git a/cli/output/tablew/tablew.go b/cli/output/tablew/tablew.go index 119ee8ebe..237761830 100644 --- a/cli/output/tablew/tablew.go +++ b/cli/output/tablew/tablew.go @@ -14,6 +14,7 @@ package tablew import ( + "context" "fmt" "io" "strconv" @@ -147,26 +148,37 @@ func (t *table) setTableWriterOptions() { t.tblImpl.SetHeaderTrans(t.pr.Header.SprintFunc()) } -func (t *table) appendRowsAndRenderAll(rows [][]string) { +func (t *table) appendRowsAndRenderAll(ctx context.Context, rows [][]string) error { for _, v := range rows { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } t.tblImpl.Append(v) } - t.tblImpl.RenderAll() + return t.tblImpl.RenderAll(ctx) } -func (t *table) appendRows(rows [][]string) { +func (t *table) appendRows(ctx context.Context, rows [][]string) error { for _, v := range rows { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } t.tblImpl.Append(v) } + return nil } -func (t *table) renderAll() { - t.tblImpl.RenderAll() +func (t *table) renderAll(ctx context.Context) error { + return t.tblImpl.RenderAll(ctx) } -func (t *table) renderRow(row []string) { +func (t *table) renderRow(ctx context.Context, row []string) error { t.tblImpl.Append(row) - t.tblImpl.RenderAll() // Send output + return t.tblImpl.RenderAll(ctx) // Send output } func getColorForVal(pr *output.Printing, v any) *color.Color { diff --git a/cli/output/writers.go b/cli/output/writers.go index bddff1dc4..345b1690c 100644 --- a/cli/output/writers.go +++ b/cli/output/writers.go @@ -7,6 +7,7 @@ package output import ( + "context" "io" "time" @@ -31,18 +32,18 @@ import ( type RecordWriter interface { // Open instructs the writer to prepare to write records // described by recMeta. - Open(recMeta record.Meta) error + Open(ctx context.Context, recMeta record.Meta) error // WriteRecords writes rec to the destination. 
- WriteRecords(recs []record.Record) error + WriteRecords(ctx context.Context, recs []record.Record) error // Flush advises the writer to flush any internal // buffer. Note that the writer may implement an independent // flushing strategy, or may not buffer at all. - Flush() error + Flush(ctx context.Context) error // Close closes the writer after flushing any internal buffer. - Close() error + Close(ctx context.Context) error } // MetadataWriter can output metadata. diff --git a/cli/output/xlsxw/xlsxw.go b/cli/output/xlsxw/xlsxw.go index d7a73c15d..52cf0fbfb 100644 --- a/cli/output/xlsxw/xlsxw.go +++ b/cli/output/xlsxw/xlsxw.go @@ -4,6 +4,7 @@ package xlsxw import ( + "context" "encoding/base64" "fmt" "io" @@ -143,7 +144,7 @@ func (w *recordWriter) getDecimalStyle(dec decimal.Decimal) (int, error) { } // Open implements output.RecordWriter. -func (w *recordWriter) Open(recMeta record.Meta) error { +func (w *recordWriter) Open(_ context.Context, recMeta record.Meta) error { w.mu.Lock() defer w.mu.Unlock() @@ -206,12 +207,12 @@ func (w *recordWriter) setColWidth(col, width int) error { } // Flush implements output.RecordWriter. -func (w *recordWriter) Flush() error { +func (w *recordWriter) Flush(context.Context) error { return nil } // Close implements output.RecordWriter. -func (w *recordWriter) Close() error { +func (w *recordWriter) Close(context.Context) error { w.mu.Lock() defer w.mu.Unlock() @@ -223,11 +224,16 @@ func (w *recordWriter) Close() error { } // WriteRecords implements output.RecordWriter. 
-func (w *recordWriter) WriteRecords(recs []record.Record) error { //nolint:gocognit +func (w *recordWriter) WriteRecords(ctx context.Context, recs []record.Record) error { //nolint:gocognit w.mu.Lock() defer w.mu.Unlock() for _, rec := range recs { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } rowi := w.nextRow for j, val := range rec { diff --git a/cli/output/xlsxw/xlsxw_test.go b/cli/output/xlsxw/xlsxw_test.go index ac96ac5d1..69c014f14 100644 --- a/cli/output/xlsxw/xlsxw_test.go +++ b/cli/output/xlsxw/xlsxw_test.go @@ -2,6 +2,7 @@ package xlsxw_test import ( "bytes" + "context" "encoding/base64" "fmt" "image" @@ -59,6 +60,7 @@ func TestRecordWriter(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() recMeta, recs := testh.RecordsFromTbl(t, tc.handle, tc.tbl) if tc.numRecs >= 0 { recs = recs[0:tc.numRecs] @@ -70,10 +72,10 @@ func TestRecordWriter(t *testing.T) { pr.ExcelDateFormat = xlsxw.OptDateFormat.Default() pr.ExcelTimeFormat = xlsxw.OptTimeFormat.Default() w := xlsxw.NewRecordWriter(buf, pr) - require.NoError(t, w.Open(recMeta)) + require.NoError(t, w.Open(ctx, recMeta)) - require.NoError(t, w.WriteRecords(recs)) - require.NoError(t, w.Close()) + require.NoError(t, w.WriteRecords(ctx, recs)) + require.NoError(t, w.Close(ctx)) _ = tu.WriteTemp(t, fmt.Sprintf("*.%s.test.xlsx", tc.name), buf.Bytes(), false) @@ -85,15 +87,16 @@ func TestRecordWriter(t *testing.T) { } func TestBytesEncodedAsBase64(t *testing.T) { + ctx := context.Background() recMeta, recs := testh.RecordsFromTbl(t, testsrc.BlobDB, "blobs") buf := &bytes.Buffer{} pr := output.NewPrinting() w := xlsxw.NewRecordWriter(buf, pr) - require.NoError(t, w.Open(recMeta)) + require.NoError(t, w.Open(ctx, recMeta)) - require.NoError(t, w.WriteRecords(recs)) - require.NoError(t, w.Close()) + require.NoError(t, w.WriteRecords(ctx, recs)) + require.NoError(t, w.Close(ctx)) xl, err := excelize.OpenReader(buf) require.NoError(t, err) diff --git 
a/cli/output/xmlw/xmlw.go b/cli/output/xmlw/xmlw.go index d657fca7e..bc77eaa69 100644 --- a/cli/output/xmlw/xmlw.go +++ b/cli/output/xmlw/xmlw.go @@ -3,6 +3,7 @@ package xmlw import ( "bytes" + "context" "encoding/base64" "encoding/xml" "fmt" @@ -58,7 +59,7 @@ func NewRecordWriter(out io.Writer, pr *output.Printing) output.RecordWriter { } // Open implements output.RecordWriter. -func (w *recordWriter) Open(recMeta record.Meta) error { +func (w *recordWriter) Open(_ context.Context, recMeta record.Meta) error { w.recMeta = recMeta var indent, newline string @@ -113,13 +114,13 @@ func monoPrint(w io.Writer, a ...any) { } // Flush implements output.RecordWriter. -func (w *recordWriter) Flush() error { +func (w *recordWriter) Flush(context.Context) error { _, err := w.outBuf.WriteTo(w.out) // resets buf return errz.Err(err) } // Close implements output.RecordWriter. -func (w *recordWriter) Close() error { +func (w *recordWriter) Close(ctx context.Context) error { w.outBuf.WriteByte('\n') if w.recsWritten { @@ -131,13 +132,13 @@ func (w *recordWriter) Close() error { w.outBuf.WriteByte('\n') - return w.Flush() + return w.Flush(ctx) } // WriteRecords implements output.RecordWriter. // Note that (by design) the XML element is omitted for any nil value // in a record. 
-func (w *recordWriter) WriteRecords(recs []record.Record) error { +func (w *recordWriter) WriteRecords(ctx context.Context, recs []record.Record) error { if len(recs) == 0 { return nil } @@ -150,7 +151,7 @@ func (w *recordWriter) WriteRecords(recs []record.Record) error { var err error for _, rec := range recs { - err = w.writeRecord(rec) + err = w.writeRecord(ctx, rec) if err != nil { return err } @@ -159,7 +160,7 @@ func (w *recordWriter) WriteRecords(recs []record.Record) error { return nil } -func (w *recordWriter) writeRecord(rec record.Record) error { +func (w *recordWriter) writeRecord(ctx context.Context, rec record.Record) error { var err error tmpBuf := &bytes.Buffer{} @@ -215,7 +216,7 @@ func (w *recordWriter) writeRecord(rec record.Record) error { w.outBuf.WriteString(w.tplRecEnd) if w.outBuf.Len() > w.pr.FlushThreshold { - return w.Flush() + return w.Flush(ctx) } return nil diff --git a/cli/output/xmlw/xmlw_test.go b/cli/output/xmlw/xmlw_test.go index 2efa6d59b..be632ead8 100644 --- a/cli/output/xmlw/xmlw_test.go +++ b/cli/output/xmlw/xmlw_test.go @@ -3,6 +3,7 @@ package xmlw_test import ( "bytes" + "context" "os" "testing" @@ -69,6 +70,8 @@ func TestRecordWriter_Actor(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + pr := output.NewPrinting() pr.EnableColor(tc.color) pr.Compact = !tc.pretty @@ -79,9 +82,9 @@ func TestRecordWriter_Actor(t *testing.T) { buf := &bytes.Buffer{} w := xmlw.NewRecordWriter(buf, pr) - require.NoError(t, w.Open(recMeta)) - require.NoError(t, w.WriteRecords(recs)) - require.NoError(t, w.Close()) + require.NoError(t, w.Open(ctx, recMeta)) + require.NoError(t, w.WriteRecords(ctx, recs)) + require.NoError(t, w.Close(ctx)) require.Equal(t, tc.want, buf.String()) }) @@ -89,6 +92,7 @@ func TestRecordWriter_Actor(t *testing.T) { } func TestRecordWriter_TblTypes(t *testing.T) { + ctx := context.Background() pr := output.NewPrinting() pr.EnableColor(false) @@ -96,9 +100,9 @@ func 
TestRecordWriter_TblTypes(t *testing.T) { buf := &bytes.Buffer{} w := xmlw.NewRecordWriter(buf, pr) - require.NoError(t, w.Open(recMeta)) - require.NoError(t, w.WriteRecords(recs)) - require.NoError(t, w.Close()) + require.NoError(t, w.Open(ctx, recMeta)) + require.NoError(t, w.WriteRecords(ctx, recs)) + require.NoError(t, w.Close(ctx)) want, err := os.ReadFile("testdata/tbl_types.xml") require.NoError(t, err) diff --git a/cli/output/yamlw/recordwriter.go b/cli/output/yamlw/recordwriter.go index cfb41f190..c6545a34c 100644 --- a/cli/output/yamlw/recordwriter.go +++ b/cli/output/yamlw/recordwriter.go @@ -2,6 +2,7 @@ package yamlw import ( "bytes" + "context" "io" "strconv" "sync" @@ -41,7 +42,7 @@ type recordWriter struct { } // Open implements output.RecordWriter. -func (w *recordWriter) Open(recMeta record.Meta) error { +func (w *recordWriter) Open(_ context.Context, recMeta record.Meta) error { w.recMeta = recMeta w.fieldNames = w.recMeta.MungedNames() w.buf = &bytes.Buffer{} @@ -94,7 +95,7 @@ func (w *recordWriter) Open(recMeta record.Meta) error { } // WriteRecords implements output.RecordWriter. -func (w *recordWriter) WriteRecords(recs []record.Record) error { +func (w *recordWriter) WriteRecords(ctx context.Context, recs []record.Record) error { w.mu.Lock() defer w.mu.Unlock() @@ -112,6 +113,11 @@ func (w *recordWriter) WriteRecords(recs []record.Record) error { ) for i, rec := range recs { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } buf.WriteString("- ") for j := range rec { @@ -163,7 +169,7 @@ func (w *recordWriter) WriteRecords(recs []record.Record) error { } // Flush implements output.RecordWriter. -func (w *recordWriter) Flush() error { +func (w *recordWriter) Flush(context.Context) error { w.mu.Lock() defer w.mu.Unlock() _, err := w.buf.WriteTo(w.out) @@ -171,8 +177,8 @@ func (w *recordWriter) Flush() error { } // Close implements output.RecordWriter. 
-func (w *recordWriter) Close() error { - return w.Flush() +func (w *recordWriter) Close(context.Context) error { + return w.Flush(context.TODO()) } // renderTime renders the *time.Time val into a fully-rendered string diff --git a/cli/output/yamlw/yamlw_test.go b/cli/output/yamlw/yamlw_test.go index e56b5ebdb..efd38eead 100644 --- a/cli/output/yamlw/yamlw_test.go +++ b/cli/output/yamlw/yamlw_test.go @@ -34,12 +34,12 @@ func TestRecordWriter(t *testing.T) { pr := output.NewPrinting() pr.EnableColor(false) recw := yamlw.NewRecordWriter(buf, pr) - require.NoError(t, recw.Open(sink.RecMeta)) + require.NoError(t, recw.Open(th.Context, sink.RecMeta)) - err = recw.WriteRecords(sink.Recs) + err = recw.WriteRecords(th.Context, sink.Recs) require.NoError(t, err) - require.NoError(t, recw.Flush()) - require.NoError(t, recw.Close()) + require.NoError(t, recw.Flush(th.Context)) + require.NoError(t, recw.Close(th.Context)) want2 := want _ = want2 diff --git a/testh/record.go b/testh/record.go index 4906f92a9..c489170aa 100644 --- a/testh/record.go +++ b/testh/record.go @@ -1,6 +1,7 @@ package testh import ( + "context" "fmt" "reflect" "sync" @@ -148,7 +149,7 @@ func (r *RecordSink) Result() any { } // Open implements libsq.RecordWriter. -func (r *RecordSink) Open(recMeta record.Meta) error { +func (r *RecordSink) Open(_ context.Context, recMeta record.Meta) error { r.mu.Lock() defer r.mu.Unlock() @@ -157,7 +158,7 @@ func (r *RecordSink) Open(recMeta record.Meta) error { } // WriteRecords implements libsq.RecordWriter. -func (r *RecordSink) WriteRecords(recs []record.Record) error { +func (r *RecordSink) WriteRecords(_ context.Context, recs []record.Record) error { r.mu.Lock() defer r.mu.Unlock() @@ -166,7 +167,7 @@ func (r *RecordSink) WriteRecords(recs []record.Record) error { } // Flush implements libsq.RecordWriter. 
-func (r *RecordSink) Flush() error { +func (r *RecordSink) Flush(context.Context) error { r.mu.Lock() defer r.mu.Unlock() r.Flushed = append(r.Flushed, time.Now()) @@ -174,7 +175,7 @@ func (r *RecordSink) Flush() error { } // Close implements libsq.RecordWriter. -func (r *RecordSink) Close() error { +func (r *RecordSink) Close(context.Context) error { r.mu.Lock() defer r.mu.Unlock() r.Closed = append(r.Closed, time.Now()) From 8cfbef84ec14ca7bca41f168b978af97021dcab5 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 5 Dec 2023 20:57:18 -0700 Subject: [PATCH 054/195] deferred render delay; NewWriter panics --- go.mod | 22 +++++++++++++++++----- go.sum | 6 ++---- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 03570c7ff..370414471 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.21 require ( github.com/Masterminds/sprig/v3 v3.2.3 + github.com/a8m/tree v0.0.0-20230208161321-36ae24ddad15 github.com/alessio/shellescape v1.4.2 github.com/antlr4-go/antlr/v4 v4.13.0 github.com/c2h5oh/datasize v0.0.0-20220606134207-859f65c6625b @@ -16,7 +17,6 @@ require ( github.com/google/uuid v1.4.0 github.com/h2non/filetype v1.1.3 github.com/jackc/pgx/v5 v5.5.0 - github.com/lmittmann/tint v1.0.3 github.com/mattn/go-colorable v0.1.13 github.com/mattn/go-isatty v0.0.20 github.com/mattn/go-runewidth v0.0.15 @@ -26,7 +26,6 @@ require ( github.com/muesli/mango-cobra v1.2.0 github.com/muesli/roff v0.1.0 github.com/ncruces/go-strftime v0.1.9 - github.com/neilotoole/fscache v0.0.0-20231203162946-c9808f16552e github.com/neilotoole/shelleditor v0.4.1 github.com/neilotoole/slogt v1.1.0 github.com/nightlyone/lockfile v1.0.0 @@ -39,7 +38,6 @@ require ( github.com/spf13/cobra v1.8.0 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.8.4 - github.com/vbauerster/mpb/v8 v8.7.0 github.com/xo/dburl v0.19.1 github.com/xuri/excelize/v2 v2.8.0 go.uber.org/atomic v1.11.0 @@ -57,11 
+55,9 @@ require ( github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.2.0 // indirect github.com/VividCortex/ewma v1.2.0 // indirect - github.com/a8m/tree v0.0.0-20230208161321-36ae24ddad15 // indirect github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/djherbis/atime v1.1.0 // indirect - github.com/djherbis/stream v1.4.0 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/huandu/xstrings v1.3.3 // indirect @@ -93,4 +89,20 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) +// See: https://github.com/vbauerster/mpb/issues/136 +require github.com/vbauerster/mpb/v8 v8.7.1-0.20231205062852-da3162c67234 // indirect +//require github.com/vbauerster/mpb/v8 v8.7.0 +// +//// See: https://github.com/vbauerster/mpb/issues/136 +//replace github.com/vbauerster/mpb/v8 v8.7.0 => ../sq-mpb + +// See: https://github.com/djherbis/fscache/pull/21 +require github.com/neilotoole/fscache v0.0.0-20231203162946-c9808f16552e + +require ( + github.com/djherbis/stream v1.4.0 // indirect + +) + +// See: https://github.com/djherbis/stream/pull/11 replace github.com/djherbis/stream v1.4.0 => github.com/neilotoole/djherbis-stream v0.0.0-20231203160853-609f47afedda diff --git a/go.sum b/go.sum index 65ddf81a5..190ecd368 100644 --- a/go.sum +++ b/go.sum @@ -92,8 +92,6 @@ github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0 github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= -github.com/lmittmann/tint v1.0.3 h1:W5PHeA2D8bBJVvabNfQD/XW9HPLZK1XoPZH0cq8NouQ= 
-github.com/lmittmann/tint v1.0.3/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= @@ -186,8 +184,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/vbauerster/mpb/v8 v8.7.0 h1:n2LTGyol7qqNBcLQn8FL5Bga2O8CGF75OOYsJVFsfMg= -github.com/vbauerster/mpb/v8 v8.7.0/go.mod h1:0RgdqeTpu6cDbdWeSaDvEvfgm9O598rBnRZ09HKaV0k= +github.com/vbauerster/mpb/v8 v8.7.1-0.20231205062852-da3162c67234 h1:ZsOQFNOwxbDqlxHc9wUW2skA4QMXMZyCOVngFdbrzJE= +github.com/vbauerster/mpb/v8 v8.7.1-0.20231205062852-da3162c67234/go.mod h1:0RgdqeTpu6cDbdWeSaDvEvfgm9O598rBnRZ09HKaV0k= github.com/xo/dburl v0.19.1 h1:z/K2i8zVf6aRwQ8Szz7MGEUw0VC2472D9SlBqdHDQCU= github.com/xo/dburl v0.19.1/go.mod h1:B7/G9FGungw6ighV8xJNwWYQPMfn3gsi2sn5SE8Bzco= github.com/xuri/efp v0.0.0-20230802181842-ad255f2331ca h1:uvPMDVyP7PXMMioYdyPH+0O+Ta/UO1WFfNYMO3Wz0eg= From 3135f63c5b5cc461611d1c4264634502a7d36c60 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 6 Dec 2023 08:08:44 -0700 Subject: [PATCH 055/195] pkg progress seems to be working --- .gitignore | 1 - cli/cmd_xtest.go | 18 +-- drivers/csv/csv_test.go | 78 +---------- drivers/csv/testdata/.gitignore | 1 + libsq/core/lg/lga/lga.go | 1 + libsq/core/progress/{progressio.go => io.go} | 130 +++++++++---------- libsq/core/progress/progress.go | 26 ++-- libsq/driver/grips.go | 3 +- libsq/source/files.go | 1 - testh/gen.go | 92 +++++++++++++ testh/testh.go 
| 3 +- 11 files changed, 177 insertions(+), 177 deletions(-) create mode 100644 drivers/csv/testdata/.gitignore rename libsq/core/progress/{progressio.go => io.go} (70%) create mode 100644 testh/gen.go diff --git a/.gitignore b/.gitignore index d6dca93d5..ef74c33fd 100644 --- a/.gitignore +++ b/.gitignore @@ -55,4 +55,3 @@ goreleaser-test.sh /cli/test.db /*.db /.CHANGELOG.delta.md -/drivers/csv/testdata/payment-large.csv diff --git a/cli/cmd_xtest.go b/cli/cmd_xtest.go index 38254f076..3f12cfab1 100644 --- a/cli/cmd_xtest.go +++ b/cli/cmd_xtest.go @@ -40,39 +40,39 @@ func execXTestMbp(cmd *cobra.Command, _ []string) error { pb := progress.New(ctx, ru.ErrOut, 1*time.Millisecond, progress.DefaultColors()) ctx = progress.NewContext(ctx, pb) - if err := doBigRead2(ctx); err != nil { + if err := doProgressByteCounterRead(ctx); err != nil { return err } return ru.Writers.Version.Version(buildinfo.Get(), buildinfo.Get().Version, hostinfo.Get()) } -func doBigRead2(ctx context.Context) error { +func doProgressByteCounterRead(ctx context.Context) error { pb := progress.FromContext(ctx) - spinner := pb.NewByteCounter("Ingest data test", -1) - defer spinner.Stop() + bar := pb.NewByteCounter("Ingest data test", -1) + defer bar.Stop() maxSleep := 100 * time.Millisecond - jr := ioz.LimitRandReader(100000) + lr := ioz.LimitRandReader(100000) b := make([]byte, 1024) LOOP: for { select { case <-ctx.Done(): - spinner.Stop() + bar.Stop() break LOOP default: } - n, err := jr.Read(b) + n, err := lr.Read(b) if err != nil { - spinner.Stop() + bar.Stop() break } - spinner.IncrBy(n) + bar.IncrBy(n) time.Sleep(time.Duration(rand.Intn(10)+1) * maxSleep / 10) //nolint:gosec } diff --git a/drivers/csv/csv_test.go b/drivers/csv/csv_test.go index 924b9e538..b16d5e8be 100644 --- a/drivers/csv/csv_test.go +++ b/drivers/csv/csv_test.go @@ -2,12 +2,7 @@ package csv_test import ( "context" - stdcsv "encoding/csv" - "fmt" - "math/rand" - "os" "path/filepath" - "strconv" "testing" "time" @@ -15,8 +10,6 
@@ import ( "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/text/language" - "golang.org/x/text/message" "github.com/neilotoole/sq/cli/testrun" "github.com/neilotoole/sq/drivers/csv" @@ -340,73 +333,10 @@ func TestDatetime(t *testing.T) { } // TestIngestLargeCSV generates a large CSV file. -// At count = 5000000, the generated file is ~500MB. +// At count = 5,000,000, the generated file is ~500MB. +// This test is skipped by default. +// FIXME: Delete TestGenerateLargeCSV. func TestGenerateLargeCSV(t *testing.T) { t.Skip() - const count = 5000000 // Generates ~500MB file - start := time.Now() - header := []string{ - "payment_id", - "customer_id", - "name", - "staff_id", - "rental_id", - "amount", - "payment_date", - "last_update", - } - - f, err := os.OpenFile( - "testdata/payment-large.csv", - os.O_CREATE|os.O_WRONLY|os.O_TRUNC, - 0o600, - ) - require.NoError(t, err) - t.Cleanup(func() { _ = f.Close() }) - - w := stdcsv.NewWriter(f) - require.NoError(t, w.Write(header)) - - rec := make([]string, len(header)) - amount := decimal.New(50000, -2) - paymentUTC := time.Now().UTC() - lastUpdateUTC := time.Now().UTC() - p := message.NewPrinter(language.English) - for i := 0; i < count; i++ { - if i%100000 == 0 { - // Flush occasionally - w.Flush() - } - - rec[0] = strconv.Itoa(i + 1) // payment id, always unique - rec[1] = strconv.Itoa(rand.Intn(100)) // customer_id, one of 100 customers - rec[2] = "Alice " + rec[1] // name - rec[3] = strconv.Itoa(rand.Intn(10)) // staff_id - rec[4] = strconv.Itoa(i + 3) // rental_id, always unique - f64 := amount.InexactFloat64() - // rec[5] = p.Sprintf("%.2f", f64) // amount - rec[5] = fmt.Sprintf("%.2f", f64) // amount - amount = amount.Add(decimal.New(33, -2)) - rec[6] = timez.TimestampUTC(paymentUTC) // payment_date - paymentUTC = paymentUTC.Add(time.Minute) - rec[7] = timez.TimestampUTC(lastUpdateUTC) // last_update - lastUpdateUTC = 
lastUpdateUTC.Add(time.Minute + time.Second) - err = w.Write(rec) - require.NoError(t, err) - } - - w.Flush() - require.NoError(t, w.Error()) - require.NoError(t, f.Close()) - - fi, err := os.Stat(f.Name()) - require.NoError(t, err) - - t.Logf( - "Wrote %s records in %s, total size %s, to: %s", - p.Sprintf("%d", count), - time.Since(start).Round(time.Millisecond), - stringz.ByteSized(fi.Size(), 1, ""), - f.Name(), - ) + testh.GenerateLargeCSV(t, "testdata/payment-large.csv") } diff --git a/drivers/csv/testdata/.gitignore b/drivers/csv/testdata/.gitignore new file mode 100644 index 000000000..55a90cfc6 --- /dev/null +++ b/drivers/csv/testdata/.gitignore @@ -0,0 +1 @@ +*.gen.csv diff --git a/libsq/core/lg/lga/lga.go b/libsq/core/lg/lga/lga.go index fa31dca30..fafba9873 100644 --- a/libsq/core/lg/lga/lga.go +++ b/libsq/core/lg/lga/lga.go @@ -33,6 +33,7 @@ const ( Kind = "kind" Loc = "loc" Lock = "lock" + Name = "name" New = "new" Old = "old" Opts = "opts" diff --git a/libsq/core/progress/progressio.go b/libsq/core/progress/io.go similarity index 70% rename from libsq/core/progress/progressio.go rename to libsq/core/progress/io.go index 7e880d3ea..240fdde6b 100644 --- a/libsq/core/progress/progressio.go +++ b/libsq/core/progress/io.go @@ -1,23 +1,8 @@ -/* -Copyright 2018 Olivier Mengué - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -// This code is derived from github.com/dolmen-go/contextio. 
- package progress +// Acknowledgement: The reader & writer implementations were originally +// adapted from github.com/dolmen-go/contextio. + import ( "context" "io" @@ -30,7 +15,7 @@ import ( // generates a progress bar as bytes are written to w. It is expected that ctx // contains a *progress.Progress, as returned by progress.FromContext. If not, // this function delegates to contextio.NewWriter: the returned writer will -// still be context-ware. See the contextio package for more details. +// still be context-aware. See the contextio package for more details. // // Context state is checked BEFORE every Write. // @@ -57,14 +42,16 @@ func NewWriter(ctx context.Context, msg string, size int64, w io.Writer) Writer pb := FromContext(ctx) if pb == nil { - return writerWrapper{contextio.NewWriter(ctx, w)} + // No progress bar in context, so we delegate to contextio. + return writerAdapter{contextio.NewWriter(ctx, w)} } - spinner := pb.NewByteCounter(msg, size) + bar := pb.NewByteCounter(msg, size) return &progCopier{progWriter{ ctx: ctx, - w: spinner.bar.ProxyWriter(w), - spinner: spinner, + delayCh: pb.delayCh, + w: w, + b: bar, }} } @@ -73,31 +60,33 @@ var _ io.WriteCloser = (*progWriter)(nil) type progWriter struct { ctx context.Context w io.Writer - spinner *Bar + delayCh <-chan struct{} + b *Bar } -// Write implements [io.Writer], but with context awareness. +// Write implements [io.Writer], but with context and progress interaction. func (w *progWriter) Write(p []byte) (n int, err error) { select { case <-w.ctx.Done(): - w.spinner.Stop() + w.b.Stop() return 0, w.ctx.Err() + case <-w.delayCh: + w.b.initBarOnce.Do(w.b.initBar) default: - n, err = w.w.Write(p) - if err != nil { - w.spinner.Stop() - } - return n, err } -} -// Close implements [io.WriteCloser], but with context awareness. 
-func (w *progWriter) Close() error { - if w == nil { - return nil + n, err = w.w.Write(p) + w.b.IncrBy(n) + if err != nil { + w.b.Stop() } + return n, err +} - w.spinner.Stop() +// Close implements [io.WriteCloser], but with context and +// progress interaction. +func (w *progWriter) Close() error { + w.b.Stop() var closeErr error if c, ok := w.w.(io.Closer); ok { @@ -135,11 +124,12 @@ func NewReader(ctx context.Context, msg string, size int64, r io.Reader) io.Read return contextio.NewReader(ctx, r) } - spinner := pb.NewByteCounter(msg, size) + b := pb.NewByteCounter(msg, size) pr := &progReader{ ctx: ctx, - r: spinner.bar.ProxyReader(r), - spinner: spinner, + delayCh: pb.delayCh, + r: r, + b: b, } return pr } @@ -149,16 +139,13 @@ var _ io.ReadCloser = (*progReader)(nil) type progReader struct { ctx context.Context r io.Reader - spinner *Bar + delayCh <-chan struct{} + b *Bar } // Close implements [io.ReadCloser], but with context awareness. func (r *progReader) Close() error { - if r == nil { - return nil - } - - r.spinner.Stop() + r.b.Stop() var closeErr error if c, ok := r.r.(io.ReadCloser); ok { @@ -173,19 +160,23 @@ func (r *progReader) Close() error { } } -// Read implements [io.Reader], but with context awareness. +// Read implements [io.Reader], but with context and progress interaction. func (r *progReader) Read(p []byte) (n int, err error) { select { case <-r.ctx.Done(): - r.spinner.Stop() + r.b.Stop() return 0, r.ctx.Err() + case <-r.delayCh: + r.b.initBarOnce.Do(r.b.initBar) default: - n, err = r.r.Read(p) - if err != nil { - r.spinner.Stop() - } - return n, err } + + n, err = r.r.Read(p) + r.b.IncrBy(n) + if err != nil { + r.b.Stop() + } + return n, err } var _ io.ReaderFrom = (*progCopier)(nil) @@ -201,16 +192,19 @@ type Writer interface { Stop() } -var _ Writer = (*writerWrapper)(nil) +var _ Writer = (*writerAdapter)(nil) -// writerWrapper wraps an io.Writer to implement [progress.Writer]. 
-type writerWrapper struct { +// writerAdapter wraps an io.Writer to implement [progress.Writer]. +// This is only used, by [NewWriter], when there is no progress bar +// in the context, and thus [NewWriter] delegates to contextio.NewWriter, +// but we still need to implement [progress.Writer]. +type writerAdapter struct { io.Writer } // Close implements [io.WriteCloser]. If the underlying // writer implements [io.Closer], it will be closed. -func (w writerWrapper) Close() error { +func (w writerAdapter) Close() error { if c, ok := w.Writer.(io.Closer); ok { return c.Close() } @@ -218,7 +212,7 @@ func (w writerWrapper) Close() error { } // Stop implements [Writer] and is no-op. -func (w writerWrapper) Stop() { +func (w writerAdapter) Stop() { } var _ Writer = (*progCopier)(nil) @@ -229,37 +223,33 @@ type progCopier struct { // Stop implements [progress.Writer]. func (w *progCopier) Stop() { - if w == nil || w.spinner == nil { - return - } - - w.spinner.Stop() + w.b.Stop() } -// ReadFrom implements interface [io.ReaderFrom], but with context awareness. -// -// This should allow efficient copying allowing writer or reader to define the chunk size. +// ReadFrom implements [io.ReaderFrom], but with context and +// progress interaction. func (w *progCopier) ReadFrom(r io.Reader) (n int64, err error) { if _, ok := w.w.(io.ReaderFrom); ok { // Let the original Writer decide the chunk size. rdr := &progReader{ ctx: w.ctx, - r: w.spinner.bar.ProxyReader(r), - spinner: w.spinner, + delayCh: w.delayCh, + r: r, + b: w.b, } return io.Copy(w.progWriter.w, rdr) } select { case <-w.ctx.Done(): - w.spinner.Stop() + w.b.Stop() return 0, w.ctx.Err() default: // The original Writer is not a ReaderFrom. // Let the Reader decide the chunk size. 
n, err = io.Copy(&w.progWriter, r) if err != nil { - w.spinner.Stop() + w.b.Stop() } return n, err } diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 85d0d1d6a..06b887cef 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -77,8 +77,8 @@ const ( // is that both the Progress.pc and Bar.bar are lazily initialized. // The Progress.pc (progress container) is initialized on the first // call to one of the Progress.NewX methods. The Bar.bar is initialized -// only after the render delay has expired. The details are ugly. Hopefully -// this can all be simplified once the mpb bug is fixed. +// only after the render delay has expired. The details are ugly. +// Hopefully this can all be simplified once the mpb bug is fixed. // New returns a new Progress instance, which is a container for progress bars. // The returned Progress instance is safe for concurrent use, and all of its @@ -91,8 +91,6 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors lg.FromContext(ctx).Debug("New progress widget", "delay", delay) var cancelFn context.CancelFunc - ogCtx := ctx - _ = ogCtx ctx, cancelFn = context.WithCancel(ctx) if colors == nil { @@ -108,7 +106,6 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors } p.pcInit = func() { - lg.FromContext(ctx).Debug("Initializing progress widget") opts := []mpb.ContainerOption{ mpb.WithOutput(out), mpb.WithWidth(boxWidth), @@ -116,8 +113,7 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors mpb.WithAutoRefresh(), // Needed for color in Windows, apparently } if delay > 0 { - delayCh := renderDelay(ctx, p, delay) - opts = append(opts, mpb.WithRenderDelay(delayCh)) + delayCh := renderDelay(p, delay) p.delayCh = delayCh } else { delayCh := make(chan struct{}) @@ -154,7 +150,6 @@ type Progress struct { delayCh <-chan struct{} // stopped is set to true when Stop is called. 
- // REVISIT: Do we really need stopped, or can we rely on ctx.Done()? stopped bool colors *Colors @@ -436,20 +431,16 @@ func (b *Bar) Stop() { } // renderDelay returns a channel that will be closed after d, -// or if ctx is done. Arg callback is invoked after the delay. -func renderDelay(ctx context.Context, p *Progress, d time.Duration) <-chan struct{} { +// at which point p.InitBars will be called. +func renderDelay(p *Progress, d time.Duration) <-chan struct{} { ch := make(chan struct{}) t := time.NewTimer(d) go func() { defer close(ch) defer t.Stop() - select { - case <-ctx.Done(): - lg.FromContext(ctx).Debug("Render delay via ctx.Done") - case <-t.C: - lg.FromContext(ctx).Debug("Render delay via timer") - p.initBars() - } + + <-t.C + p.initBars() }() return ch } @@ -516,7 +507,6 @@ func barStyle(c *color.Color) mpb.BarStyleComposer { } frames := []string{"∙", "●", "●", "●", "∙"} - return mpb.BarStyle(). Lbound(" ").Rbound(" "). Filler("∙").FillerMeta(clr). diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 8c8885299..558e02c4c 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -9,8 +9,6 @@ import ( "sync" "time" - "github.com/neilotoole/sq/libsq/source/drivertype" - "github.com/nightlyone/lockfile" "github.com/neilotoole/sq/libsq/core/cleanup" @@ -22,6 +20,7 @@ import ( "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/core/retry" "github.com/neilotoole/sq/libsq/source" + "github.com/neilotoole/sq/libsq/source/drivertype" ) var _ GripOpener = (*Grips)(nil) diff --git a/libsq/source/files.go b/libsq/source/files.go index f5a903095..6c5186f50 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -159,7 +159,6 @@ func (fs *Files) addStdin(ctx context.Context, f *os.File) error { lw := ioz.NewWrittenWriter(w) fs.stdinLength = lw.Written - // df := ioz.DelayReader(f, time.Microsecond*500, true) // FIXME: Delete cr := contextio.NewReader(ctx, f) pw := progress.NewWriter(ctx, 
"Reading stdin", -1, lw) diff --git a/testh/gen.go b/testh/gen.go new file mode 100644 index 000000000..50b9153f3 --- /dev/null +++ b/testh/gen.go @@ -0,0 +1,92 @@ +package testh + +import ( + stdcsv "encoding/csv" + "fmt" + "math/rand" + "os" + "strconv" + "testing" + "time" + + "github.com/shopspring/decimal" + "github.com/stretchr/testify/require" + "golang.org/x/text/language" + "golang.org/x/text/message" + + "github.com/neilotoole/sq/libsq/core/stringz" + "github.com/neilotoole/sq/libsq/core/timez" +) + +// GenerateLargeCSV generates a large CSV file. +// At count = 5000000, the generated file is ~500MB. +// +//nolint:gosec +func GenerateLargeCSV(t *testing.T, fp string) { + const count = 5000000 // Generates ~500MB file + start := time.Now() + header := []string{ + "payment_id", + "customer_id", + "name", + "staff_id", + "rental_id", + "amount", + "payment_date", + "last_update", + } + + f, err := os.OpenFile( + fp, + os.O_CREATE|os.O_WRONLY|os.O_TRUNC, + 0o600, + ) + require.NoError(t, err) + t.Cleanup(func() { _ = f.Close() }) + + w := stdcsv.NewWriter(f) + require.NoError(t, w.Write(header)) + + rec := make([]string, len(header)) + amount := decimal.New(50000, -2) + paymentUTC := time.Now().UTC() + lastUpdateUTC := time.Now().UTC() + p := message.NewPrinter(language.English) + for i := 0; i < count; i++ { + if i%100000 == 0 { + // Flush occasionally + w.Flush() + } + + rec[0] = strconv.Itoa(i + 1) // payment id, always unique + rec[1] = strconv.Itoa(rand.Intn(100)) // customer_id, one of 100 customers + rec[2] = "Alice " + rec[1] // name + rec[3] = strconv.Itoa(rand.Intn(10)) // staff_id + rec[4] = strconv.Itoa(i + 3) // rental_id, always unique + f64 := amount.InexactFloat64() + // rec[5] = p.Sprintf("%.2f", f64) // amount + rec[5] = fmt.Sprintf("%.2f", f64) // amount + amount = amount.Add(decimal.New(33, -2)) + rec[6] = timez.TimestampUTC(paymentUTC) // payment_date + paymentUTC = paymentUTC.Add(time.Minute) + rec[7] = 
timez.TimestampUTC(lastUpdateUTC) // last_update + lastUpdateUTC = lastUpdateUTC.Add(time.Minute + time.Second) + err = w.Write(rec) + require.NoError(t, err) + } + + w.Flush() + require.NoError(t, w.Error()) + require.NoError(t, f.Close()) + + fi, err := os.Stat(f.Name()) + require.NoError(t, err) + + t.Logf( + "Wrote %s records in %s, total size %s, to: %s", + p.Sprintf("%d", count), + time.Since(start).Round(time.Millisecond), + stringz.ByteSized(fi.Size(), 1, ""), + f.Name(), + ) +} diff --git a/testh/testh.go b/testh/testh.go index 43be46676..b55bca402 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -14,8 +14,6 @@ import ( "testing" "time" - "github.com/neilotoole/sq/libsq/core/lg/devlog" - "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -43,6 +41,7 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/devlog" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" From f0167af0814d61bb68df51ddaf1d82e904de8662 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 6 Dec 2023 09:03:13 -0700 Subject: [PATCH 056/195] Add more progress to excel --- drivers/xlsx/ingest.go | 49 ++++++++++++++++----------------- drivers/xlsx/xlsx.go | 2 +- libsq/core/progress/progress.go | 30 ++++++++++++++++++++ 3 files changed, 54 insertions(+), 27 deletions(-) diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index 41487e2bd..935eecebb 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -4,10 +4,11 @@ import ( "context" "database/sql" "fmt" - "slices" "strings" "time" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/samber/lo" excelize "github.com/xuri/excelize/v2" "golang.org/x/sync/errgroup" @@ -27,10 +28,6 @@ import ( const msgCloseRowIter 
= "Close Excel row iterator" -func hasSheet(xfile *excelize.File, sheetName string) bool { - return slices.Contains(xfile.GetSheetList(), sheetName) -} - // sheetTable maps a sheet to a database table. type sheetTable struct { sheet *xSheet @@ -86,10 +83,7 @@ func (xs *xSheet) loadSampleRows(ctx context.Context, sampleSize int) error { // ingestXLSX loads the data in xfile into destGrip. // If includeSheetNames is non-empty, only the named sheets are ingested. -func ingestXLSX(ctx context.Context, src *source.Source, destGrip driver.Grip, - xfile *excelize.File, includeSheetNames []string, -) error { - // FIXME: delete includeSheetNames +func ingestXLSX(ctx context.Context, src *source.Source, destGrip driver.Grip, xfile *excelize.File) error { log := lg.FromContext(ctx) start := time.Now() log.Debug("Beginning import from XLSX", @@ -97,19 +91,11 @@ func ingestXLSX(ctx context.Context, src *source.Source, destGrip driver.Grip, lga.Target, destGrip.Source()) var sheets []*xSheet - if len(includeSheetNames) > 0 { - for _, sheetName := range includeSheetNames { - if !hasSheet(xfile, sheetName) { - return errz.Errorf("sheet {%s} not found", sheetName) - } - sheets = append(sheets, &xSheet{file: xfile, name: sheetName}) - } - } else { - sheetNames := xfile.GetSheetList() - sheets = make([]*xSheet, len(sheetNames)) - for i := range sheetNames { - sheets[i] = &xSheet{file: xfile, name: sheetNames[i]} - } + + sheetNames := xfile.GetSheetList() + sheets = make([]*xSheet, len(sheetNames)) + for i := range sheetNames { + sheets[i] = &xSheet{file: xfile, name: sheetNames[i]} } srcIngestHeader := getSrcIngestHeader(src.Options) @@ -118,6 +104,14 @@ func ingestXLSX(ctx context.Context, src *source.Source, destGrip driver.Grip, return err } + lg.FromContext(ctx).Error("count is woah", lga.Count, len(sheetTbls)) + bar := progress.FromContext(ctx).NewUnitTotalCounter( + "Ingesting sheets", + "sheet", + int64(len(sheetTbls)), + ) + defer bar.Stop() + for _, sheetTbl := range 
sheetTbls { if sheetTbl == nil { // tblDef can be nil if its sheet is empty (has no data). @@ -139,22 +133,25 @@ func ingestXLSX(ctx context.Context, src *source.Source, destGrip driver.Grip, lga.Target, destGrip.Source(), lga.Elapsed, time.Since(start)) - var imported, skipped int + var ingestCount, skipped int for i := range sheetTbls { if sheetTbls[i] == nil { // tblDef can be nil if its sheet is empty (has no data). skipped++ + bar.IncrBy(1) continue } + time.Sleep(time.Millisecond * 100) if err = ingestSheetToTable(ctx, destGrip, sheetTbls[i]); err != nil { return err } - imported++ + ingestCount++ + bar.IncrBy(1) } - log.Debug("Sheets imported", - lga.Count, imported, + log.Debug("Sheets ingested", + lga.Count, ingestCount, "skipped", skipped, lga.From, src, lga.To, destGrip.Source(), diff --git a/drivers/xlsx/xlsx.go b/drivers/xlsx/xlsx.go index 180fd7865..fef18e316 100644 --- a/drivers/xlsx/xlsx.go +++ b/drivers/xlsx/xlsx.go @@ -85,7 +85,7 @@ func (d *Driver) Open(ctx context.Context, src *source.Source) (driver.Grip, err defer lg.WarnIfCloseError(log, lgm.CloseFileReader, xfile) - if err = ingestXLSX(ctx, p.src, destGrip, xfile, nil); err != nil { + if err = ingestXLSX(ctx, p.src, destGrip, xfile); err != nil { return err } return nil diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 06b887cef..686e384e4 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -263,6 +263,36 @@ func (p *Progress) NewUnitCounter(msg, unit string) *Bar { return p.newBar(msg, -1, style, decorator) } +func (p *Progress) NewUnitTotalCounter(msg, unit string, total int64) *Bar { + if p == nil { + return nil + } + + if total <= 0 { + return p.NewUnitCounter(msg, unit) + } + + p.mu.Lock() + defer p.mu.Unlock() + + style := barStyle(p.colors.Filler) + // counter := decor.CountersNoUnit("%d / %d") + // counter = decor.Counters(decor.SizeB1024(0), "% .1f / % .1f") + + decorator := decor.Any(func(statistics 
decor.Statistics) string { + s := humanize.Comma(statistics.Current) + " / " + humanize.Comma(statistics.Total) + if unit != "" { + s += " " + english.PluralWord(int(statistics.Current), unit, "") + } + return s + }) + decorator = colorize(decorator, p.colors.Size) + + // style := spinnerStyle(p.colors.Filler) + + return p.newBar(msg, total, style, decorator) +} + // NewByteCounter returns a new progress bar whose metric is the count // of bytes processed. If the size is unknown, set arg size to -1. The caller // is ultimately responsible for calling [Bar.Stop] on the returned Bar. From dfbf3b513588d5d35640bf6e2294c25749d6b323 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 6 Dec 2023 09:15:41 -0700 Subject: [PATCH 057/195] Dialing in xlsx progress --- drivers/xlsx/ingest.go | 2 +- libsq/core/progress/progress.go | 29 +++++++++++++++++++++++------ 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index 935eecebb..b7b508ad0 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -107,7 +107,7 @@ func ingestXLSX(ctx context.Context, src *source.Source, destGrip driver.Grip, x lg.FromContext(ctx).Error("count is woah", lga.Count, len(sheetTbls)) bar := progress.FromContext(ctx).NewUnitTotalCounter( "Ingesting sheets", - "sheet", + "", int64(len(sheetTbls)), ) defer bar.Stop() diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 686e384e4..9137f691e 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -17,10 +17,13 @@ package progress import ( "context" "io" + "strings" "sync" "sync/atomic" "time" + "github.com/neilotoole/sq/libsq/core/stringz" + humanize "github.com/dustin/go-humanize" "github.com/dustin/go-humanize/english" "github.com/fatih/color" @@ -62,6 +65,7 @@ func FromContext(ctx context.Context) *Progress { } const ( + msgLength = 22 barWidth = 28 boxWidth = 64 refreshRate = 150 * time.Millisecond @@ -263,6 
+267,18 @@ func (p *Progress) NewUnitCounter(msg, unit string) *Bar { return p.newBar(msg, -1, style, decorator) } +// NewUnitTotalCounter returns a new determinate bar whose label +// metric is the plural of the provided unit. The caller is ultimately +// responsible for calling [Bar.Stop] on the returned Bar. However, +// the returned Bar is also added to the Progress's cleanup list, so +// it will be called automatically when the Progress is shut down, but that +// may be later than the actual conclusion of the Bar's work. +// +// This produces output similar to: +// +// Ingesting sheets ∙∙∙∙∙● 4 / 16 sheets +// +// Note that the unit arg is automatically pluralized. func (p *Progress) NewUnitTotalCounter(msg, unit string, total int64) *Bar { if p == nil { return nil @@ -276,9 +292,6 @@ func (p *Progress) NewUnitTotalCounter(msg, unit string, total int64) *Bar { defer p.mu.Unlock() style := barStyle(p.colors.Filler) - // counter := decor.CountersNoUnit("%d / %d") - // counter = decor.Counters(decor.SizeB1024(0), "% .1f / % .1f") - decorator := decor.Any(func(statistics decor.Statistics) string { s := humanize.Comma(statistics.Current) + " / " + humanize.Comma(statistics.Total) if unit != "" { @@ -287,9 +300,6 @@ func (p *Progress) NewUnitTotalCounter(msg, unit string, total int64) *Bar { return s }) decorator = colorize(decorator, p.colors.Size) - - // style := spinnerStyle(p.colors.Filler) - return p.newBar(msg, total, style, decorator) } @@ -349,6 +359,13 @@ func (p *Progress) newBar(msg string, total int64, total = 0 } + switch { + case len(msg) < msgLength: + msg += strings.Repeat(" ", msgLength-len(msg)) + case len(msg) > msgLength: + msg = stringz.TrimLenMiddle(msg, msgLength) + } + b := &Bar{ p: p, incrStash: &atomic.Int64{}, From be294566ff4ebce27d7545a88248d352f4d25ddf Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 6 Dec 2023 09:17:36 -0700 Subject: [PATCH 058/195] Preparing to switch to bar-based delay --- libsq/core/progress/progress.go | 1 + 1 
file changed, 1 insertion(+) diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 9137f691e..abf01d762 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -151,6 +151,7 @@ type Progress struct { // delayCh controls the rendering delay: rendering can // start as soon as delayCh is closed. + // TODO: Should delayCh be on Bar instead of Progress? delayCh <-chan struct{} // stopped is set to true when Stop is called. From abc0ac44c3486abbc75487fb54c3ca4c1f98d1a8 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 6 Dec 2023 13:34:41 -0700 Subject: [PATCH 059/195] wip --- cli/source.go | 16 +++++++--- drivers/xlsx/ingest.go | 2 +- libsq/core/lg/lga/lga.go | 1 + libsq/core/progress/io.go | 8 ++--- libsq/core/progress/progress.go | 56 +++++++++++++++++++++++---------- 5 files changed, 57 insertions(+), 26 deletions(-) diff --git a/cli/source.go b/cli/source.go index e96da1a4f..369558547 100644 --- a/cli/source.go +++ b/cli/source.go @@ -2,6 +2,7 @@ package cli import ( "context" + "os" "strings" "github.com/spf13/cobra" @@ -157,20 +158,25 @@ func processFlagActiveSchema(cmd *cobra.Command, activeSrc *source.Source) error // and returned. If the pipe has no data (size is zero), // then (nil,nil) is returned. func checkStdinSource(ctx context.Context, ru *run.Run) (*source.Source, error) { + log := lg.FromContext(ctx) cmd := ru.Cmd - f := ru.Stdin - info, err := f.Stat() + fi, err := f.Stat() if err != nil { return nil, errz.Wrap(err, "failed to get stat on stdin") } - if info.Size() <= 0 { - // Doesn't make sense to have zero-data pipe? just ignore. 
+ switch { + case os.ModeNamedPipe&fi.Mode() > 0: + log.Info("Detected stdin pipe via os.ModeNamedPipe") + case fi.Size() > 0: + log.Info("Detected stdin redirect via size > 0", lga.Size, fi.Size()) + default: + log.Info("No stdin data detected") return nil, nil //nolint:nilnil } - // If we got this far, we have pipe input + // If we got this far, we have input from pipe or redirect. typ := drivertype.None if cmd.Flags().Changed(flag.IngestDriver) { diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index b7b508ad0..95c573869 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -142,7 +142,7 @@ func ingestXLSX(ctx context.Context, src *source.Source, destGrip driver.Grip, x continue } - time.Sleep(time.Millisecond * 100) + time.Sleep(time.Millisecond * 500) // FIXME: delete if err = ingestSheetToTable(ctx, destGrip, sheetTbls[i]); err != nil { return err } diff --git a/libsq/core/lg/lga/lga.go b/libsq/core/lg/lga/lga.go index fafba9873..b15d27d39 100644 --- a/libsq/core/lg/lga/lga.go +++ b/libsq/core/lg/lga/lga.go @@ -40,6 +40,7 @@ const ( Path = "path" Pid = "pid" Query = "query" + Size = "size" SLQ = "slq" SQL = "sql" Src = "src" diff --git a/libsq/core/progress/io.go b/libsq/core/progress/io.go index 240fdde6b..ba1832ee5 100644 --- a/libsq/core/progress/io.go +++ b/libsq/core/progress/io.go @@ -46,12 +46,12 @@ func NewWriter(ctx context.Context, msg string, size int64, w io.Writer) Writer return writerAdapter{contextio.NewWriter(ctx, w)} } - bar := pb.NewByteCounter(msg, size) + b := pb.NewByteCounter(msg, size) return &progCopier{progWriter{ ctx: ctx, - delayCh: pb.delayCh, + delayCh: b.delayCh, w: w, - b: bar, + b: b, }} } @@ -127,7 +127,7 @@ func NewReader(ctx context.Context, msg string, size int64, r io.Reader) io.Read b := pb.NewByteCounter(msg, size) pr := &progReader{ ctx: ctx, - delayCh: pb.delayCh, + delayCh: b.delayCh, r: r, b: b, } diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 
abf01d762..7cf9b0bdd 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -107,6 +107,7 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors colors: colors, cancelFn: cancelFn, bars: make([]*Bar, 0), + delay: delay, } p.pcInit = func() { @@ -116,18 +117,20 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors mpb.WithRefreshRate(refreshRate), mpb.WithAutoRefresh(), // Needed for color in Windows, apparently } - if delay > 0 { - delayCh := renderDelay(p, delay) - p.delayCh = delayCh - } else { - delayCh := make(chan struct{}) - close(delayCh) - p.delayCh = delayCh - } + //if delay > 0 { + // delayCh := renderDelay(p, delay) + // p.delayCh = delayCh + //} else { + // delayCh := make(chan struct{}) + // close(delayCh) + // p.delayCh = delayCh + //} p.pc = mpb.NewWithContext(ctx, opts...) p.pcInit = nil } + + p.pcInit() return p } @@ -152,7 +155,9 @@ type Progress struct { // delayCh controls the rendering delay: rendering can // start as soon as delayCh is closed. // TODO: Should delayCh be on Bar instead of Progress? - delayCh <-chan struct{} + //delayCh <-chan struct{} + + delay time.Duration // stopped is set to true when Stop is called. stopped bool @@ -353,7 +358,7 @@ func (p *Progress) newBar(msg string, total int64, } if p.pc == nil { - p.pcInit() + p.pcInit() // FIXME: delete this } if total < 0 { @@ -389,12 +394,14 @@ func (p *Progress) newBar(msg string, total int64, b.incrStash.Store(0) } + b.delayCh = renderDelayBar(b, p.delay) + p.bars = append(p.bars, b) - select { - case <-p.delayCh: - b.initBarOnce.Do(b.initBar) - default: - } + //select { + //case <-b.delayCh: + // b.initBarOnce.Do(b.initBar) + //default: + //} return b } @@ -419,6 +426,8 @@ type Bar struct { initBarOnce *sync.Once initBar func() + delayCh <-chan struct{} + // incrStash holds the increment count until the // bar is fully initialized. 
incrStash *atomic.Int64 @@ -443,7 +452,7 @@ func (b *Bar) IncrBy(n int) { select { case <-b.p.ctx.Done(): return - case <-b.p.delayCh: + case <-b.delayCh: b.initBarOnce.Do(b.initBar) if b.bar != nil { b.bar.IncrBy(n) @@ -493,6 +502,21 @@ func renderDelay(p *Progress, d time.Duration) <-chan struct{} { return ch } +// renderDelay returns a channel that will be closed after d, +// at which point b will be initialized. +func renderDelayBar(b *Bar, d time.Duration) <-chan struct{} { + ch := make(chan struct{}) + t := time.NewTimer(d) + go func() { + defer close(ch) + defer t.Stop() + + <-t.C + b.initBarOnce.Do(b.initBar) + }() + return ch +} + func colorize(decorator decor.Decorator, c *color.Color) decor.Decorator { return decor.Meta(decorator, func(s string) string { return c.Sprint(s) From ce19a1c7d4ad69a314600fdb38f805a2d88a0770 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 6 Dec 2023 13:54:22 -0700 Subject: [PATCH 060/195] .gitignore --- .gitignore | 2 -- 1 file changed, 2 deletions(-) diff --git a/.gitignore b/.gitignore index f1b544f4a..ef74c33fd 100644 --- a/.gitignore +++ b/.gitignore @@ -55,5 +55,3 @@ goreleaser-test.sh /cli/test.db /*.db /.CHANGELOG.delta.md -go.work -go.work.sum From 78f4d86c512c5a7784b4aee18ae034c270aef2ed Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 6 Dec 2023 15:05:34 -0700 Subject: [PATCH 061/195] Progress cleanup --- cli/source.go | 13 +++----- libsq/core/progress/progress.go | 59 ++------------------------------- 2 files changed, 7 insertions(+), 65 deletions(-) diff --git a/cli/source.go b/cli/source.go index 369558547..8fd3fbb10 100644 --- a/cli/source.go +++ b/cli/source.go @@ -159,7 +159,6 @@ func processFlagActiveSchema(cmd *cobra.Command, activeSrc *source.Source) error // then (nil,nil) is returned. 
func checkStdinSource(ctx context.Context, ru *run.Run) (*source.Source, error) { log := lg.FromContext(ctx) - cmd := ru.Cmd f := ru.Stdin fi, err := f.Stat() if err != nil { @@ -179,22 +178,20 @@ func checkStdinSource(ctx context.Context, ru *run.Run) (*source.Source, error) // If we got this far, we have input from pipe or redirect. typ := drivertype.None - if cmd.Flags().Changed(flag.IngestDriver) { - val, _ := cmd.Flags().GetString(flag.IngestDriver) + if ru.Cmd.Flags().Changed(flag.IngestDriver) { + val, _ := ru.Cmd.Flags().GetString(flag.IngestDriver) typ = drivertype.Type(val) if ru.DriverRegistry.ProviderFor(typ) == nil { return nil, errz.Errorf("unknown driver type: %s", typ) } } - err = ru.Files.AddStdin(ctx, f) - if err != nil { + if err = ru.Files.AddStdin(ctx, f); err != nil { return nil, err } if typ == drivertype.None { - typ, err = ru.Files.DetectStdinType(ctx) - if err != nil { + if typ, err = ru.Files.DetectStdinType(ctx); err != nil { return nil, err } if typ == drivertype.None { @@ -230,7 +227,6 @@ func newSource(ctx context.Context, dp driver.Provider, typ drivertype.Type, han lga.Handle, handle, lga.Driver, typ, lga.Loc, source.RedactLocation(loc), - // lga.Opts, opts.Encode(), // FIXME: encode opts for debugging ) } @@ -246,7 +242,6 @@ func newSource(ctx context.Context, dp driver.Provider, typ drivertype.Type, han src := &source.Source{Handle: handle, Location: loc, Type: typ, Options: opts} - log.Debug("Validating provisional new data source", lga.Src, src) canonicalSrc, err := drvr.ValidateSource(src) if err != nil { return nil, err diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 7cf9b0bdd..1420c6058 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -117,14 +117,6 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors mpb.WithRefreshRate(refreshRate), mpb.WithAutoRefresh(), // Needed for color in Windows, apparently } - //if delay > 0 { - 
// delayCh := renderDelay(p, delay) - // p.delayCh = delayCh - //} else { - // delayCh := make(chan struct{}) - // close(delayCh) - // p.delayCh = delayCh - //} p.pc = mpb.NewWithContext(ctx, opts...) p.pcInit = nil @@ -152,11 +144,8 @@ type Progress struct { // pcInit is the func that lazily initializes pc. pcInit func() - // delayCh controls the rendering delay: rendering can - // start as soon as delayCh is closed. - // TODO: Should delayCh be on Bar instead of Progress? - //delayCh <-chan struct{} - + // delay is the duration to wait before rendering a progress bar. + // This value is used for each bar created by this Progress. delay time.Duration // stopped is set to true when Stop is called. @@ -209,28 +198,6 @@ func (p *Progress) Stop() { p.pc.Wait() } -// initBars lazily initializes all bars in p.bars. -func (p *Progress) initBars() { - p.mu.Lock() - defer p.mu.Unlock() - - select { - case <-p.ctx.Done(): - return - default: - } - - if p.stopped { - return - } - - for _, b := range p.bars { - if !b.stopped { - b.initBarOnce.Do(b.initBar) - } - } -} - // NewUnitCounter returns a new indeterminate bar whose label // metric is the plural of the provided unit. The caller is ultimately // responsible for calling [Bar.Stop] on the returned Bar. However, @@ -395,13 +362,8 @@ func (p *Progress) newBar(msg string, total int64, } b.delayCh = renderDelayBar(b, p.delay) - p.bars = append(p.bars, b) - //select { - //case <-b.delayCh: - // b.initBarOnce.Do(b.initBar) - //default: - //} + return b } @@ -487,21 +449,6 @@ func (b *Bar) Stop() { b.bar.Wait() } -// renderDelay returns a channel that will be closed after d, -// at which point p.InitBars will be called. 
-func renderDelay(p *Progress, d time.Duration) <-chan struct{} { - ch := make(chan struct{}) - t := time.NewTimer(d) - go func() { - defer close(ch) - defer t.Stop() - - <-t.C - p.initBars() - }() - return ch -} - // renderDelay returns a channel that will be closed after d, // at which point b will be initialized. func renderDelayBar(b *Bar, d time.Duration) <-chan struct{} { From e49b922719ba5d2ded12fc7cc25f8dcb83719bf7 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 6 Dec 2023 15:08:50 -0700 Subject: [PATCH 062/195] go.mod cleanup --- go.mod | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 370414471..c05d62f48 100644 --- a/go.mod +++ b/go.mod @@ -90,19 +90,12 @@ require ( ) // See: https://github.com/vbauerster/mpb/issues/136 -require github.com/vbauerster/mpb/v8 v8.7.1-0.20231205062852-da3162c67234 // indirect -//require github.com/vbauerster/mpb/v8 v8.7.0 -// -//// See: https://github.com/vbauerster/mpb/issues/136 -//replace github.com/vbauerster/mpb/v8 v8.7.0 => ../sq-mpb +require github.com/vbauerster/mpb/v8 v8.7.1-0.20231205062852-da3162c67234 // See: https://github.com/djherbis/fscache/pull/21 require github.com/neilotoole/fscache v0.0.0-20231203162946-c9808f16552e -require ( - github.com/djherbis/stream v1.4.0 // indirect - -) +require github.com/djherbis/stream v1.4.0 // indirect // See: https://github.com/djherbis/stream/pull/11 replace github.com/djherbis/stream v1.4.0 => github.com/neilotoole/djherbis-stream v0.0.0-20231203160853-609f47afedda From 031a8bc13f9cfb4966c9edc871495a0ca3dadcb8 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 6 Dec 2023 20:54:22 -0700 Subject: [PATCH 063/195] Improved terminal detection --- cli/{term.go => terminal.go} | 22 ++++++++----- cli/terminal_windows.go | 63 ++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 8 deletions(-) rename cli/{term.go => terminal.go} (64%) create mode 100644 
cli/terminal_windows.go diff --git a/cli/term.go b/cli/terminal.go similarity index 64% rename from cli/term.go rename to cli/terminal.go index 91f9cb120..3e32ff608 100644 --- a/cli/term.go +++ b/cli/terminal.go @@ -1,13 +1,15 @@ +//go:build !windows + package cli import ( "io" "os" - isatty "github.com/mattn/go-isatty" "golang.org/x/term" ) + // isTerminal returns true if w is a terminal. func isTerminal(w io.Writer) bool { switch v := w.(type) { @@ -19,18 +21,22 @@ func isTerminal(w io.Writer) bool { } // isColorTerminal returns true if w is a colorable terminal. +// It respects [NO_COLOR], [FORCE_COLOR] and TERM=dumb environment variables. +// +// [NO_COLOR]: https://no-color.org/ +// [FORCE_COLOR]: https://force-color.org/ func isColorTerminal(w io.Writer) bool { - if w == nil { + if os.Getenv("NO_COLOR") != "" { return false } - - // TODO: Add the improvements from jsoncolor: - // https://github.com/neilotoole/jsoncolor/pull/27 - if !isTerminal(w) { + if os.Getenv("FORCE_COLOR") != "" { + return true + } + if os.Getenv("TERM") == "dumb" { return false } - if os.Getenv("TERM") == "dumb" { + if w == nil { return false } @@ -39,7 +45,7 @@ func isColorTerminal(w io.Writer) bool { return false } - if isatty.IsCygwinTerminal(f.Fd()) { + if !term.IsTerminal(int(f.Fd())) { return false } diff --git a/cli/terminal_windows.go b/cli/terminal_windows.go new file mode 100644 index 000000000..75f42e1bc --- /dev/null +++ b/cli/terminal_windows.go @@ -0,0 +1,63 @@ +package terminal + +import ( + "io" + "os" + + "golang.org/x/sys/windows" +) + +// isTerminal returns true if w is a terminal. +func isTerminal(w io.Writer) bool { + switch v := w.(type) { + case *os.File: + return term.IsTerminal(int(v.Fd())) + default: + return false + } +} + +// isColorTerminal returns true if w is a colorable terminal. +// It respects [NO_COLOR], [FORCE_COLOR] and TERM=dumb environment variables. 
+// +// [NO_COLOR]: https://no-color.org/ +// [FORCE_COLOR]: https://force-color.org/ +func isColorTerminal(w io.Writer) bool { + if os.Getenv("NO_COLOR") != "" { + return false + } + if os.Getenv("FORCE_COLOR") != "" { + return true + } + if os.Getenv("TERM") == "dumb" { + return false + } + + if w == nil { + return false + } + + f, ok := w.(*os.File) + if !ok { + return false + } + fd := f.Fd() + + console := windows.Handle(fd) + var mode uint32 + if err := windows.GetConsoleMode(console, &mode); err != nil { + return false + } + + var want uint32 = windows.ENABLE_PROCESSED_OUTPUT | windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING + if (mode & want) == want { + return true + } + + mode |= want + if err := windows.SetConsoleMode(console, mode); err != nil { + return false + } + + return true +} From 91f19fbe0be6e6614565a870a1b9dd73d760cdb7 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 6 Dec 2023 23:17:44 -0700 Subject: [PATCH 064/195] Yet more tidying of source.Files --- cli/cmd_inspect.go | 4 + cli/source.go | 7 +- cli/terminal.go | 1 - cli/terminal_windows.go | 2 - drivers/csv/csv.go | 7 +- drivers/csv/detect_type.go | 3 +- drivers/json/json.go | 2 +- drivers/xlsx/grip.go | 2 +- drivers/xlsx/ingest.go | 3 +- libsq/core/ioz/ioz.go | 51 ------- libsq/core/lg/lga/lga.go | 1 + libsq/core/progress/progress.go | 3 +- libsq/source/detect.go | 78 ++++++----- libsq/source/files.go | 240 ++++++++++++++++++-------------- libsq/source/files_test.go | 4 +- 15 files changed, 202 insertions(+), 206 deletions(-) diff --git a/cli/cmd_inspect.go b/cli/cmd_inspect.go index 3a6b4789d..26aae5deb 100644 --- a/cli/cmd_inspect.go +++ b/cli/cmd_inspect.go @@ -125,7 +125,9 @@ func execInspect(cmd *cobra.Command, args []string) error { return err } + log.Error("before open") grip, err := ru.Grips.Open(ctx, src) + log.Error("after open") if err != nil { return errz.Wrapf(err, "failed to inspect %s", src.Handle) } @@ -203,6 +205,8 @@ func execInspect(cmd *cobra.Command, args []string) 
error { overviewOnly := cmdFlagIsSetTrue(cmd, flag.InspectOverview) + log.Debug("get source metadata") + srcMeta, err := grip.SourceMetadata(ctx, overviewOnly) if err != nil { return errz.Wrapf(err, "failed to read %s source metadata", src.Handle) diff --git a/cli/source.go b/cli/source.go index 8fd3fbb10..d7cc3623b 100644 --- a/cli/source.go +++ b/cli/source.go @@ -158,18 +158,19 @@ func processFlagActiveSchema(cmd *cobra.Command, activeSrc *source.Source) error // and returned. If the pipe has no data (size is zero), // then (nil,nil) is returned. func checkStdinSource(ctx context.Context, ru *run.Run) (*source.Source, error) { - log := lg.FromContext(ctx) f := ru.Stdin fi, err := f.Stat() if err != nil { return nil, errz.Wrap(err, "failed to get stat on stdin") } + mode := fi.Mode() + log := lg.FromContext(ctx).With(lga.File, fi.Name(), lga.Size, fi.Size(), "mode", mode.String()) switch { - case os.ModeNamedPipe&fi.Mode() > 0: + case os.ModeNamedPipe&mode > 0: log.Info("Detected stdin pipe via os.ModeNamedPipe") case fi.Size() > 0: - log.Info("Detected stdin redirect via size > 0", lga.Size, fi.Size()) + log.Info("Detected stdin redirect via size > 0") default: log.Info("No stdin data detected") return nil, nil //nolint:nilnil diff --git a/cli/terminal.go b/cli/terminal.go index 3e32ff608..24b3cbe5b 100644 --- a/cli/terminal.go +++ b/cli/terminal.go @@ -9,7 +9,6 @@ import ( "golang.org/x/term" ) - // isTerminal returns true if w is a terminal. func isTerminal(w io.Writer) bool { switch v := w.(type) { diff --git a/cli/terminal_windows.go b/cli/terminal_windows.go index 75f42e1bc..439009d04 100644 --- a/cli/terminal_windows.go +++ b/cli/terminal_windows.go @@ -3,8 +3,6 @@ package terminal import ( "io" "os" - - "golang.org/x/sys/windows" ) // isTerminal returns true if w is a terminal. 
diff --git a/drivers/csv/csv.go b/drivers/csv/csv.go index 0b985e862..756caf267 100644 --- a/drivers/csv/csv.go +++ b/drivers/csv/csv.go @@ -89,6 +89,7 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, er return nil, err } + log.Error("open ingest done", lga.Err, err) return g, nil } @@ -156,7 +157,11 @@ func (p *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Tab // SourceMetadata implements driver.Grip. func (p *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { + log := lg.FromContext(ctx) + + log.Debug("before impl.SourceMetadata") md, err := p.impl.SourceMetadata(ctx, noSchema) + log.Debug("after impl.SourceMetadata", lga.Err, err) if err != nil { return nil, err } @@ -170,7 +175,7 @@ func (p *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou return nil, err } - md.Size, err = p.files.Size(ctx, p.src) + md.Size, err = p.files.Filesize(ctx, p.src) if err != nil { return nil, err } diff --git a/drivers/csv/detect_type.go b/drivers/csv/detect_type.go index a569a2f9a..bd1f72002 100644 --- a/drivers/csv/detect_type.go +++ b/drivers/csv/detect_type.go @@ -76,9 +76,8 @@ const ( // legitimate CSV, where a score <= 0 is not CSV, a score >= 1 is definitely CSV. 
func isCSV(ctx context.Context, cr *csv.Reader) (score float32) { start := time.Now() - lg.FromContext(ctx).Debug("isCSV invoked", lga.Timestamp, start) defer func() { - lg.FromContext(ctx).Debug("isCSV complete", "elapsed", time.Since(start), "score", score) + lg.FromContext(ctx).Debug("CSV detection complete", lga.Elapsed, time.Since(start), lga.Score, score) }() const ( maxRecords int = 100 diff --git a/drivers/json/json.go b/drivers/json/json.go index 7db4decf2..a2f229720 100644 --- a/drivers/json/json.go +++ b/drivers/json/json.go @@ -204,7 +204,7 @@ func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou return nil, err } - md.Size, err = g.files.Size(ctx, g.src) + md.Size, err = g.files.Filesize(ctx, g.src) if err != nil { return nil, err } diff --git a/drivers/xlsx/grip.go b/drivers/xlsx/grip.go index 1df7b0b67..074a1616e 100644 --- a/drivers/xlsx/grip.go +++ b/drivers/xlsx/grip.go @@ -53,7 +53,7 @@ func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou } md.FQName = md.Name - if md.Size, err = g.files.Size(ctx, g.src); err != nil { + if md.Size, err = g.files.Filesize(ctx, g.src); err != nil { return nil, err } diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index 95c573869..b1c5c741b 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -7,8 +7,6 @@ import ( "strings" "time" - "github.com/neilotoole/sq/libsq/core/progress" - "github.com/samber/lo" excelize "github.com/xuri/excelize/v2" "golang.org/x/sync/errgroup" @@ -20,6 +18,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/loz" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/sqlmodel" "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/driver" diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 21cf50cbe..43eed24b2 100644 --- 
a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -11,7 +11,6 @@ import ( "path/filepath" "strings" "sync" - "sync/atomic" "time" "github.com/a8m/tree" @@ -330,56 +329,6 @@ func (c nopWriteCloserReaderFrom) ReadFrom(r io.Reader) (int64, error) { return c.Writer.(io.ReaderFrom).ReadFrom(r) } -// NewWrittenWriter returns a writer that counts the number of bytes -// written to the underlying writer. The number of bytes written can -// be obtained via [WrittenWriter.Written], which blocks until writing -// has concluded. -func NewWrittenWriter(w io.WriteCloser) *WrittenWriter { - return &WrittenWriter{ - w: w, - c: &atomic.Int64{}, - done: make(chan struct{}), - } -} - -var _ io.Writer = (*WrittenWriter)(nil) - -// WrittenWriter is an io.WriteCloser that counts the number of bytes -// written to the underlying writer. The number of bytes written can -// be obtained via [WrittenWriter.Written], which blocks until writing -// has concluded. -type WrittenWriter struct { - c *atomic.Int64 - w io.WriteCloser - doneOnce sync.Once - done chan struct{} -} - -// Written returns the number of bytes written to the underlying -// writer, blocking until writing concludes, either via invocation of -// Close, or via an error in Write. -func (w *WrittenWriter) Written() int64 { - <-w.done - return w.c.Load() -} - -// Close implements io.WriteCloser. -func (w *WrittenWriter) Close() error { - closeErr := w.w.Close() - w.doneOnce.Do(func() { close(w.done) }) - return closeErr -} - -// Write implements io.WriterCloser. -func (w *WrittenWriter) Write(p []byte) (n int, err error) { - n, err = w.w.Write(p) - w.c.Add(int64(n)) - if err != nil { - w.doneOnce.Do(func() { close(w.done) }) - } - return n, err -} - // DirSize returns total size of all regular files in path. 
func DirSize(path string) (int64, error) { var size int64 diff --git a/libsq/core/lg/lga/lga.go b/libsq/core/lg/lga/lga.go index b15d27d39..27f721b02 100644 --- a/libsq/core/lg/lga/lga.go +++ b/libsq/core/lg/lga/lga.go @@ -40,6 +40,7 @@ const ( Path = "path" Pid = "pid" Query = "query" + Score = "score" Size = "size" SLQ = "slq" SQL = "sql" diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 1420c6058..3d2840c13 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -22,8 +22,6 @@ import ( "sync/atomic" "time" - "github.com/neilotoole/sq/libsq/core/stringz" - humanize "github.com/dustin/go-humanize" "github.com/dustin/go-humanize/english" "github.com/fatih/color" @@ -31,6 +29,7 @@ import ( "github.com/vbauerster/mpb/v8/decor" "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/stringz" ) type ctxKey struct{} diff --git a/libsq/source/detect.go b/libsq/source/detect.go index d7b85ef36..84b86d100 100644 --- a/libsq/source/detect.go +++ b/libsq/source/detect.go @@ -17,32 +17,26 @@ import ( "github.com/neilotoole/sq/libsq/source/drivertype" ) +// DriverDetectFunc interrogates a byte stream to determine +// the source driver type. A score is returned indicating +// the confidence that the driver type has been detected. +// A score <= 0 is failure, a score >= 1 is success; intermediate +// values indicate some level of confidence. +// An error is returned only if an IO problem occurred. +// The implementation gets access to the byte stream by invoking openFn, +// and is responsible for closing any reader it opens. +type DriverDetectFunc func(ctx context.Context, openFn FileOpenFunc) ( + detected drivertype.Type, score float32, err error) + +var _ DriverDetectFunc = DetectMagicNumber + // AddDriverDetectors adds driver type detectors. 
func (fs *Files) AddDriverDetectors(detectFns ...DriverDetectFunc) { + fs.mu.Lock() + defer fs.mu.Unlock() fs.detectFns = append(fs.detectFns, detectFns...) } -// DetectStdinType detects the type of stdin as previously added -// by AddStdin. An error is returned if AddStdin was not -// first invoked. If the type cannot be detected, TypeNone and -// nil are returned. -func (fs *Files) DetectStdinType(ctx context.Context) (drivertype.Type, error) { - if !fs.fcache.Exists(StdinHandle) { - return drivertype.None, errz.New("must invoke Files.AddStdin before invoking DetectStdinType") - } - - typ, ok, err := fs.detectType(ctx, StdinHandle) - if err != nil { - return drivertype.None, err - } - - if !ok { - return drivertype.None, nil - } - - return typ, nil -} - // DriverType returns the driver type of loc. func (fs *Files) DriverType(ctx context.Context, loc string) (drivertype.Type, error) { log := lg.FromContext(ctx).With(lga.Loc, loc) @@ -67,6 +61,8 @@ func (fs *Files) DriverType(ctx context.Context, loc string) (drivertype.Type, e } } + fs.mu.Lock() + defer fs.mu.Unlock() // Fall back to the byte detectors typ, ok, err := fs.detectType(ctx, loc) if err != nil { @@ -94,9 +90,6 @@ func (fs *Files) detectType(ctx context.Context, loc string) (typ drivertype.Typ resultCh := make(chan result, len(fs.detectFns)) openFn := func(ctx context.Context) (io.ReadCloser, error) { - fs.mu.Lock() - defer fs.mu.Unlock() - return fs.newReader(ctx, loc) } @@ -160,19 +153,6 @@ func (fs *Files) detectType(ctx context.Context, loc string) (typ drivertype.Typ return drivertype.None, false, nil } -// DriverDetectFunc interrogates a byte stream to determine -// the source driver type. A score is returned indicating -// the confidence that the driver type has been detected. -// A score <= 0 is failure, a score >= 1 is success; intermediate -// values indicate some level of confidence. -// An error is returned only if an IO problem occurred. 
-// The implementation gets access to the byte stream by invoking openFn, -// and is responsible for closing any reader it opens. -type DriverDetectFunc func(ctx context.Context, openFn FileOpenFunc) ( - detected drivertype.Type, score float32, err error) - -var _ DriverDetectFunc = DetectMagicNumber - // DetectMagicNumber is a DriverDetectFunc that uses an external // pkg (h2non/filetype) to detect the "magic number" from // the start of files. @@ -217,3 +197,27 @@ func DetectMagicNumber(ctx context.Context, openFn FileOpenFunc, return typeSL3, 1.0, nil } } + +// DetectStdinType detects the type of stdin as previously added +// by AddStdin. An error is returned if AddStdin was not +// first invoked. If the type cannot be detected, TypeNone and +// nil are returned. +func (fs *Files) DetectStdinType(ctx context.Context) (drivertype.Type, error) { + fs.mu.Lock() + defer fs.mu.Unlock() + + if !fs.fscache.Exists(StdinHandle) { + return drivertype.None, errz.New("must invoke Files.AddStdin before invoking DetectStdinType") + } + + typ, ok, err := fs.detectType(ctx, StdinHandle) + if err != nil { + return drivertype.None, err + } + + if !ok { + return drivertype.None, nil + } + + return typ, nil +} diff --git a/libsq/source/files.go b/libsq/source/files.go index 6c5186f50..809699aa9 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -7,6 +7,7 @@ import ( "net/url" "os" "path/filepath" + "strconv" "sync" "time" @@ -20,6 +21,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/progress" + "github.com/neilotoole/sq/libsq/core/stringz" ) // Files is the centralized API for interacting with files. @@ -35,22 +37,24 @@ import ( // if we're reading long-running pipe from stdin). This entire thing // needs to be revisited. Maybe Files even becomes a fs.FS. 
type Files struct { - cacheDir string - tempDir string - log *slog.Logger - mu sync.Mutex - clnup *cleanup.Cleanup - fcache *fscache.FSCache + mu sync.Mutex + log *slog.Logger + cacheDir string + tempDir string + clnup *cleanup.Cleanup + + // fscache is used to cache files, providing convenient access + // to multiple readers via Files.newReader. + fscache *fscache.FSCache + + // fscacheEntryMetas contains metadata about fscache entries. + // Entries are added by Files.addStdin, and consumed by + // Files.Filesize. + fscacheEntryMetas map[string]*fscacheEntryMeta + + // detectFns is the set of functions that can detect + // the type of a file. detectFns []DriverDetectFunc - - // stdinLength is a func that returns number of bytes read from stdin. - // It is nil if stdin has not been read. The func may block until reading - // of stdin has completed. - // - // FIXME: This should probably be a map of location to length func, - // because downloaded files can use this mechanism too. - // See Files.Size. - stdinLength func() int64 } // NewFiles returns a new Files instance. @@ -63,13 +67,18 @@ func NewFiles(ctx context.Context, tmpDir, cacheDir string) (*Files, error) { } fs := &Files{ - cacheDir: cacheDir, - tempDir: tmpDir, - clnup: cleanup.New(), - log: lg.FromContext(ctx), + cacheDir: cacheDir, + fscacheEntryMetas: make(map[string]*fscacheEntryMeta), + tempDir: tmpDir, + clnup: cleanup.New(), + log: lg.FromContext(ctx), } - fcacheTmpDir := filepath.Join(cacheDir, "fscache") + // We want a unique dir for each execution. Note that fcache is deleted + // on cleanup (unless something bad happens and sq doesn't + // get a chance to clean up). But, why take the chance; we'll just give + // fcache a unique dir each time. 
+ fcacheTmpDir := filepath.Join(cacheDir, "fscache", strconv.Itoa(os.Getpid()), stringz.Uniq32()) if err := ioz.RequireDir(fcacheTmpDir); err != nil { return nil, errz.Err(err) } @@ -80,13 +89,13 @@ func NewFiles(ctx context.Context, tmpDir, cacheDir string) (*Files, error) { } fs.clnup.AddE(fcache.Clean) - fs.fcache = fcache + fs.fscache = fcache return fs, nil } -// Size returns the file size of src.Location. If the source is being +// Filesize returns the file size of src.Location. If the source is being // loaded asynchronously, this function may block until loading completes. -func (fs *Files) Size(ctx context.Context, src *Source) (size int64, err error) { +func (fs *Files) Filesize(ctx context.Context, src *Source) (size int64, err error) { locTyp := getLocType(src.Location) switch locTyp { case locTypeLocalFile: @@ -96,100 +105,126 @@ func (fs *Files) Size(ctx context.Context, src *Source) (size int64, err error) return 0, errz.Err(err) } return fi.Size(), nil + case locTypeRemoteFile: // FIXME: implement remote file size. return 0, errz.Errorf("remote file size not implemented: %s", src.Location) + case locTypeSQL: return 0, errz.Errorf("cannot get size of SQL source: %s", src.Handle) + case locTypeStdin: - // Special handling for stdin. - if fs.stdinLength == nil { - return 0, errz.Errorf("stdin not yet added") + fs.mu.Lock() + entryMeta, ok := fs.fscacheEntryMetas[StdinHandle] + fs.mu.Unlock() + if !ok { + return 0, errz.Errorf("stdin not present in cache") } select { case <-ctx.Done(): return 0, ctx.Err() - default: - return fs.stdinLength(), nil + case <-entryMeta.done: + return entryMeta.written, entryMeta.err } + default: return 0, errz.Errorf("unknown source location type: %s", RedactLocation(src.Location)) } } +// fscacheEntryMeta contains metadata about a fscache entry. +// When the cache entry has been filled, the done channel +// is closed and the written and err fields are set. 
+// This mechanism allows Files.Filesize to block until +// the asynchronous filling of the cache entry has completed. +type fscacheEntryMeta struct { + key string + done chan struct{} + written int64 + err error +} + // AddStdin copies f to fs's cache: the stdin data in f // is later accessible via fs.Open(src) where src.Handle // is StdinHandle; f's type can be detected via DetectStdinType. // Note that f is closed by this method. -// -// FIXME: AddStdin is probably not necessary, we can just do it -// on the fly in newReader? Or do we provide this because "stdin" -// can be something other than os.Stdin, e.g. via a flag? func (fs *Files) AddStdin(ctx context.Context, f *os.File) error { fs.mu.Lock() defer fs.mu.Unlock() - err := fs.addStdin(ctx, f) // f is closed by addStdin - return errz.Wrap(err, "failed to read stdin") + err := fs.addStdin(ctx, StdinHandle, f) // f is closed by addStdin + return errz.Wrapf(err, "failed to add %s to fscache", StdinHandle) } -// addStdin synchronously copies f (stdin) to fs's cache. f is closed +// addStdin synchronously copies f to fs's cache. f is closed // when the async copy completes. This method should only be used -// for stdin; for regular files, use Files.addFile. -func (fs *Files) addStdin(ctx context.Context, f *os.File) error { - log := lg.FromContext(ctx).With(lga.File, f.Name()) +// for stdin; for regular files, use Files.addRegularFile. 
+func (fs *Files) addStdin(ctx context.Context, handle string, f *os.File) error { + log := lg.FromContext(ctx).With(lga.Handle, handle, lga.File, f.Name()) - if fs.stdinLength != nil { - return errz.Errorf("stdin already added") + if _, ok := fs.fscacheEntryMetas[handle]; ok { + return errz.Errorf("%s already added to fscache", handle) } - // Special handling for stdin - r, w, wErrFn, err := fs.fcache.GetWithErr(StdinHandle) + cacheRdr, cacheWrtr, cacheWrtrErrFn, err := fs.fscache.GetWithErr(handle) if err != nil { return errz.Err(err) } - defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) + defer lg.WarnIfCloseError(log, lgm.CloseFileReader, cacheRdr) - if w == nil { + if cacheWrtr == nil { // Shouldn't happen - return errz.Errorf("no cache writer for %s", StdinHandle) - } + if cacheRdr != nil { + return errz.Errorf("no fscache writer for %s (but fscache reader exists when it shouldn't)", handle) + } - lw := ioz.NewWrittenWriter(w) - fs.stdinLength = lw.Written + return errz.Errorf("no fscache writer for %s", handle) + } - cr := contextio.NewReader(ctx, f) - pw := progress.NewWriter(ctx, "Reading stdin", -1, lw) + // We create an entry meta for this handle. This entry will be + // filled asynchronously in the ioz.CopyAsync callback below. + // The entry can then be consumed by Files.Filesize. 
+ entryMeta := &fscacheEntryMeta{ + key: handle, + done: make(chan struct{}), + } + fs.fscacheEntryMetas[handle] = entryMeta start := time.Now() - ioz.CopyAsync(pw, cr, func(written int64, err error) { - defer lg.WarnIfCloseError(log, lgm.CloseFileReader, f) - elapsed := time.Since(start) - if err == nil { - log.Debug("Async stdin cache fill: completed", lga.Copied, written, lga.Elapsed, elapsed) - lg.WarnIfCloseError(log, "Close stdin cache", w) + pw := progress.NewWriter(ctx, "Reading "+handle, -1, cacheWrtr) + ioz.CopyAsync(pw, contextio.NewReader(ctx, f), + func(written int64, err error) { + defer lg.WarnIfCloseError(log, lgm.CloseFileReader, f) + entryMeta.written = written + entryMeta.err = err + close(entryMeta.done) + + elapsed := time.Since(start) + if err == nil { + log.Info("Async fscache fill: completed", lga.Copied, written, lga.Elapsed, elapsed) + lg.WarnIfCloseError(log, "Close fscache writer", cacheWrtr) + pw.Stop() + return + } + + log.Error("Async fscache fill: failure", lga.Copied, written, lga.Elapsed, elapsed, lga.Err, err) + pw.Stop() - return - } + cacheWrtrErrFn(err) + // We deliberately don't close cacheWrtr here, + // because cacheWrtrErrFn handles that work. + }, + ) - log.Error("Async stdin cache fill: failure", - lga.Err, err, - lga.Copied, written, - lga.Elapsed, elapsed, - ) - pw.Stop() - wErrFn(err) - // We deliberately don't close "w" here, because wErrFn handles that work. - }) - log.Debug("Async stdin cache fill: dispatched") + log.Debug("Async fscache fill: dispatched") return nil } -// addFile maps f to fs's cache, returning a reader which the +// addRegularFile maps f to fs's cache, returning a reader which the // caller is responsible for closing. f is closed by this method. // Do not add stdin via this function; instead use addStdin. 
-func (fs *Files) addFile(ctx context.Context, f *os.File, key string) (fscache.ReadAtCloser, error) { +func (fs *Files) addRegularFile(ctx context.Context, f *os.File, key string) (fscache.ReadAtCloser, error) { log := lg.FromContext(ctx) log.Debug("Adding file", lga.Key, key, lga.Path, f.Name()) @@ -198,42 +233,44 @@ func (fs *Files) addFile(ctx context.Context, f *os.File, key string) (fscache.R if key == StdinHandle { // This is a programming error; the caller should have // instead invoked addStdin. Probably should panic here. - return nil, errz.New("illegal to add stdin via Files.addFile") + return nil, errz.New("illegal to add stdin via Files.addRegularFile") } - if fs.fcache.Exists(key) { + if fs.fscache.Exists(key) { return nil, errz.Errorf("file already exists in cache: %s", key) } - if err := fs.fcache.MapFile(f.Name()); err != nil { + if err := fs.fscache.MapFile(f.Name()); err != nil { return nil, errz.Wrapf(err, "failed to map file into fscache: %s", f.Name()) } - r, _, err := fs.fcache.Get(key) + r, _, err := fs.fscache.Get(key) return r, errz.Err(err) } // Filepath returns the file path of src.Location. // An error is returned the source's driver type // is not a file type (i.e. it is a SQL driver). +// FIXME: Implement Files.Filepath fully. func (fs *Files) Filepath(_ context.Context, src *Source) (string, error) { - fs.mu.Lock() - defer fs.mu.Unlock() - loc := src.Location + // fs.mu.Lock() + // defer fs.mu.Unlock() - if fp, ok := isFpath(loc); ok { - return fp, nil - } - - u, ok := httpURL(loc) - if !ok { - return "", errz.Errorf("not a valid file location: %s", loc) + switch getLocType(src.Location) { + case locTypeLocalFile: + return src.Location, nil + case locTypeRemoteFile: + // FIXME: implement remote file location. + // It's a remote file. We really should download it here. + // FIXME: implement downloading. 
+ return "", errz.Errorf("not implemented for remote source: %s", src.Handle) + case locTypeSQL: + return "", errz.Errorf("cannot get filepath of SQL source: %s", src.Handle) + case locTypeStdin: + return "", errz.Errorf("cannot get filepath of stdin source: %s", src.Handle) + default: + return "", errz.Errorf("unknown source location type for %s: %s", src.Handle, RedactLocation(src.Location)) } - - _ = u - // It's a remote file. We really should download it here. - // FIXME: implement downloading. - return "", errz.Errorf("Files.Filepath not implemented for remote files: %s", loc) } // Open returns a new io.ReadCloser for src.Location. @@ -248,7 +285,7 @@ func (fs *Files) Open(ctx context.Context, src *Source) (io.ReadCloser, error) { } // OpenFunc returns a func that invokes fs.Open for src.Location. -func (fs *Files) OpenFunc(src *Source) func(ctx context.Context) (io.ReadCloser, error) { +func (fs *Files) OpenFunc(src *Source) FileOpenFunc { return func(ctx context.Context) (io.ReadCloser, error) { return fs.Open(ctx, src) } @@ -256,7 +293,9 @@ func (fs *Files) OpenFunc(src *Source) func(ctx context.Context) (io.ReadCloser, // ReadAll is a convenience method to read the bytes of a source. 
func (fs *Files) ReadAll(ctx context.Context, src *Source) ([]byte, error) { + // fs.mu.Lock() r, err := fs.newReader(ctx, src.Location) + // fs.mu.Unlock() if err != nil { return nil, err } @@ -276,7 +315,6 @@ func (fs *Files) ReadAll(ctx context.Context, src *Source) ([]byte, error) { func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, error) { log := lg.FromContext(ctx).With(lga.Loc, loc) - log.Debug("Files.newReader", lga.Loc, loc) locTyp := getLocType(loc) switch locTyp { @@ -285,7 +323,7 @@ func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, erro case locTypeSQL: return nil, errz.Errorf("cannot read SQL source: %s", loc) case locTypeStdin: - r, w, err := fs.fcache.Get(StdinHandle) + r, w, err := fs.fscache.Get(StdinHandle) if err != nil { return nil, errz.Err(err) } @@ -298,8 +336,8 @@ func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, erro // Well, it's either a local or remote file. // Let's see if it's cached. - if fs.fcache.Exists(loc) { - r, _, err := fs.fcache.Get(loc) + if fs.fscache.Exists(loc) { + r, _, err := fs.fscache.Get(loc) if err != nil { return nil, err } @@ -313,8 +351,8 @@ func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, erro if err != nil { return nil, errz.Err(err) } - // fs.addFile closes f, so we don't have to do it. - r, err := fs.addFile(ctx, f, loc) + // fs.addRegularFile closes f, so we don't have to do it. + r, err := fs.addRegularFile(ctx, f, loc) if err != nil { return nil, err } @@ -324,7 +362,7 @@ func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, erro // It's an uncached remote file. 
if loc == StdinHandle { - r, w, err := fs.fcache.Get(StdinHandle) + r, w, err := fs.fscache.Get(StdinHandle) log.Debug("Returned from fs.fcache.Get", lga.Err, err) if err != nil { return nil, errz.Err(err) @@ -336,8 +374,8 @@ func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, erro return r, nil } - if !fs.fcache.Exists(loc) { - r, _, err := fs.fcache.Get(loc) + if !fs.fscache.Exists(loc) { + r, _, err := fs.fscache.Get(loc) if err != nil { return nil, err } @@ -351,8 +389,8 @@ func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, erro return nil, err } - // Note that addFile closes f - r, err := fs.addFile(ctx, f, loc) + // Note that addRegularFile closes f + r, err := fs.addRegularFile(ctx, f, loc) if err != nil { return nil, err } diff --git a/libsq/source/files_test.go b/libsq/source/files_test.go index eb9990ef7..55f9ad90c 100644 --- a/libsq/source/files_test.go +++ b/libsq/source/files_test.go @@ -245,7 +245,7 @@ func TestFiles_Size(t *testing.T) { th := testh.New(t) fs := th.Files() - gotSize, err := fs.Size(th.Context, &source.Source{ + gotSize, err := fs.Filesize(th.Context, &source.Source{ Handle: stringz.UniqSuffix("@h"), Location: f.Name(), }) @@ -257,7 +257,7 @@ func TestFiles_Size(t *testing.T) { // Verify that this works with @stdin as well require.NoError(t, fs.AddStdin(th.Context, f2)) - gotSize2, err := fs.Size(th.Context, &source.Source{ + gotSize2, err := fs.Filesize(th.Context, &source.Source{ Handle: "@stdin", Location: "@stdin", }) From 893b7140fec9d69b0373484cef04445ef35a1a12 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 7 Dec 2023 00:57:15 -0700 Subject: [PATCH 065/195] wip: download/checksum --- libsq/core/ioz/{ => checksum}/checksum.go | 47 ++++----- libsq/core/ioz/ioz_test.go | 11 +- libsq/core/lg/lga/lga.go | 1 + libsq/driver/grips.go | 11 +- libsq/source/download.go | 121 ++++++++++++++++++++++ libsq/source/download_test.go | 64 ++++++++++++ 6 files changed, 219 insertions(+), 36 
+// However, the checksum can be any string value. Use ForFile to calculate
+func Write(w io.Writer, sum Checksum, name string) error { _, err := fmt.Fprintf(w, "%s %s\n", sum, name) return errz.Err(err) } -// WriteChecksumFile writes a single {checksum,name} to path, overwriting +// WriteFile writes a single {checksum,name} to path, overwriting // the previous contents. // -// See: WriteChecksum. -func WriteChecksumFile(path string, sum Checksum, name string) error { +// See: Write. +func WriteFile(path string, sum Checksum, name string) error { f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600) if err != nil { return errz.Wrap(err, "write checksum file") } defer func() { _ = f.Close() }() - return WriteChecksum(f, sum, name) + return Write(f, sum, name) } -// ReadChecksumsFile reads a checksum file from path. +// ReadFile reads a checksum file from path. // -// See ReadChecksums for details. -func ReadChecksumsFile(path string) (map[string]Checksum, error) { +// See Read for details. +func ReadFile(path string) (map[string]Checksum, error) { f, err := os.Open(path) if err != nil { return nil, errz.Err(err) @@ -73,14 +73,14 @@ func ReadChecksumsFile(path string) (map[string]Checksum, error) { defer func() { _ = f.Close() }() - return ReadChecksums(f) + return Read(f) } -// ReadChecksums reads checksums lines from r, returning a map +// Read reads checksums lines from r, returning a map // of checksums keyed by name. Empty lines, and lines beginning // with "#" (comments) are ignored. This function is the -// inverse of WriteChecksum. -func ReadChecksums(r io.Reader) (map[string]Checksum, error) { +// inverse of Write. 
+func Read(r io.Reader) (map[string]Checksum, error) { sums := map[string]Checksum{} sc := bufio.NewScanner(r) @@ -90,11 +90,6 @@ func ReadChecksums(r io.Reader) (map[string]Checksum, error) { continue } - if strings.Contains(line, "INTEGER") { // FIXME: delete - x := true - _ = x - } - parts := strings.SplitN(line, " ", 2) if len(parts) != 2 { return nil, errz.Errorf("invalid checksum line: %q", line) diff --git a/libsq/core/ioz/ioz_test.go b/libsq/core/ioz/ioz_test.go index 6b1f5b24f..c9acfdac3 100644 --- a/libsq/core/ioz/ioz_test.go +++ b/libsq/core/ioz/ioz_test.go @@ -2,6 +2,7 @@ package ioz_test import ( "bytes" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" "io" "os" "sync" @@ -33,12 +34,12 @@ func TestChecksums(t *testing.T) { buf := &bytes.Buffer{} - gotSum1, err := ioz.FileChecksum(f.Name()) + gotSum1, err := checksum.ForFile(f.Name()) require.NoError(t, err) t.Logf("gotSum1: %s %s", gotSum1, f.Name()) - require.NoError(t, ioz.WriteChecksum(buf, gotSum1, f.Name())) + require.NoError(t, checksum.Write(buf, gotSum1, f.Name())) - gotSums, err := ioz.ReadChecksums(bytes.NewReader(buf.Bytes())) + gotSums, err := checksum.Read(bytes.NewReader(buf.Bytes())) require.NoError(t, err) require.Len(t, gotSums, 1) require.Equal(t, gotSum1, gotSums[f.Name()]) @@ -49,10 +50,10 @@ func TestChecksums(t *testing.T) { _, err = io.WriteString(f, "more huzzah") require.NoError(t, err) assert.NoError(t, f.Close()) - gotSum2, err := ioz.FileChecksum(f.Name()) + gotSum2, err := checksum.ForFile(f.Name()) require.NoError(t, err) t.Logf("gotSum2: %s %s", gotSum2, f.Name()) - require.NoError(t, ioz.WriteChecksum(buf, gotSum1, f.Name())) + require.NoError(t, checksum.Write(buf, gotSum1, f.Name())) require.NotEqual(t, gotSum1, gotSum2) } diff --git a/libsq/core/lg/lga/lga.go b/libsq/core/lg/lga/lga.go index 27f721b02..34fa638e4 100644 --- a/libsq/core/lg/lga/lga.go +++ b/libsq/core/lg/lga/lga.go @@ -18,6 +18,7 @@ const ( DB = "db" DBType = "db_type" Dest = "dest" + Dir = "dir" 
Driver = "driver" DefaultTo = "default_to" Elapsed = "elapsed" diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 558e02c4c..37cc2da50 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -3,6 +3,7 @@ package driver import ( "context" "errors" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" "log/slog" "path/filepath" "strings" @@ -295,14 +296,14 @@ func (gs *Grips) openIngestCache(ctx context.Context, src *source.Source, log.Info("Ingest completed", lga.Src, src, lga.Dest, impl.Source(), lga.Elapsed, elapsed) // Write the checksums file. - var sum ioz.Checksum - if sum, err = ioz.FileChecksum(ingestFilePath); err != nil { + var sum checksum.Checksum + if sum, err = checksum.ForFile(ingestFilePath); err != nil { log.Warn("Failed to compute checksum for source file; caching not in effect", lga.Src, src, lga.Dest, impl.Source(), lga.Path, ingestFilePath, lga.Err, err) return impl, nil //nolint:nilerr } - if err = ioz.WriteChecksumFile(checksumsPath, sum, ingestFilePath); err != nil { + if err = checksum.WriteFile(checksumsPath, sum, ingestFilePath); err != nil { log.Warn("Failed to write checksum; file caching not in effect", lga.Src, src, lga.Dest, impl.Source(), lga.Path, ingestFilePath, lga.Err, err) } @@ -378,7 +379,7 @@ func (gs *Grips) openCachedFor(ctx context.Context, src *source.Source) (Grip, b return nil, false, nil } - mChecksums, err := ioz.ReadChecksumsFile(checksumsPath) + mChecksums, err := checksum.ReadFile(checksumsPath) if err != nil { return nil, false, err } @@ -406,7 +407,7 @@ func (gs *Grips) openCachedFor(ctx context.Context, src *source.Source) (Grip, b return nil, false, nil } - srcChecksum, err := ioz.FileChecksum(srcFilepath) + srcChecksum, err := checksum.ForFile(srcFilepath) if err != nil { return nil, false, err } diff --git a/libsq/source/download.go b/libsq/source/download.go index 764d1f127..f422dc2f2 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -2,13 +2,134 @@ package 
source import ( "context" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "golang.org/x/exp/maps" + "log/slog" + "net/http" "net/url" "os" + "path/filepath" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/source/fetcher" ) +func newDownloader(log *slog.Logger, srcCacheDir, url string) *downloader { + downloadDir := filepath.Join(srcCacheDir, "download") + return &downloader{ + log: log.With(lga.URL, url, "download_dir", downloadDir), + srcCacheDir: srcCacheDir, + downloadDir: downloadDir, + checksumFile: filepath.Join(srcCacheDir, "download.checksum.txt"), + url: url, + } +} + +type downloader struct { + log *slog.Logger + srcCacheDir string + downloadDir string + checksumFile string + url string +} + +func (d *downloader) Cached() (ok bool, sum checksum.Checksum, fp string) { + fi, err := os.Stat(d.downloadDir) + if err != nil { + d.log.Debug("not cached: can't stat download dir") + return false, "", "" + } + if !fi.IsDir() { + d.log.Error("not cached: download dir is not a dir") + return false, "", "" + } + + fi, err = os.Stat(d.checksumFile) + if err != nil { + d.log.Debug("not cached: can't stat download checksum file") + return false, "", "" + } + + checksums, err := checksum.ReadFile(d.checksumFile) + if err != nil { + d.log.Debug("not cached: can't read download checksum file") + return false, "", "" + } + + if len(checksums) != 1 { + d.log.Debug("not cached: download checksum file has unexpected number of entries") + return false, "", "" + } + + key := maps.Keys(checksums)[0] + sum = checksums[key] + if len(sum) == 0 { + d.log.Debug("not cached: checksum file has empty checksum", lga.File, key) + return false, "", "" + } + + downloadFile := filepath.Join(d.downloadDir, key) + + fi, err = os.Stat(downloadFile) + if err != nil { + d.log.Debug("not cached: can't stat download file referenced in checksum file", lga.File, key) + return false, "", "" + } + + 
d.log.Info("found cached file", lga.File, key) + return true, sum, downloadFile +} + +// fetchHTTPHeader fetches the HTTP header for u. First HEAD is used, and +// if that's not allowed (http.StatusMethodNotAllowed), then GET is used. +func fetchHTTPHeader(ctx context.Context, u string) (header http.Header, err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodHead, u, nil) + if err != nil { + return nil, errz.Err(err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, errz.Err(err) + } + if resp.Body != nil { + _ = resp.Body.Close() + } + + switch resp.StatusCode { + case http.StatusOK: + return resp.Header, nil + default: + return nil, errz.Errorf("unexpected HTTP status (%s) for HEAD: %s", resp.Status, u) + case http.StatusMethodNotAllowed: + } + + // HEAD not allowed, try GET + req, err = http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return nil, errz.Err(err) + } + + resp, err = http.DefaultClient.Do(req) + if err != nil { + return nil, errz.Err(err) + } + if resp.Body != nil { + _ = resp.Body.Close() + } + + if resp.StatusCode != http.StatusOK { + return nil, errz.Errorf("unexpected HTTP status (%s) for GET: %s", resp.Status, u) + } + + return resp.Header, nil +} + +func getRemoteChecksum(ctx context.Context, u string) (string, error) { + return "", errz.New("not implemented") +} + // fetch ensures that loc exists locally as a file. This may // entail downloading the file via HTTPS etc. 
func (fs *Files) fetch(ctx context.Context, loc string) (fpath string, err error) { diff --git a/libsq/source/download_test.go b/libsq/source/download_test.go new file mode 100644 index 000000000..6289da1cb --- /dev/null +++ b/libsq/source/download_test.go @@ -0,0 +1,64 @@ +package source + +import ( + "context" + "github.com/neilotoole/sq/testh/proj" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "net/http" + "net/http/httptest" + "strconv" + "testing" +) + +func TestGetRemoteChecksum(t *testing.T) { + // sq add https://sq.io/testdata/actor.csv + // + // content-length: 7641 + // date: Thu, 07 Dec 2023 06:31:10 GMT + // etag: "069dbf690a12d5eb853feb8e04aeb49e-ssl" + + // TODO +} + +func TestFetchHTTPHeader_sqio(t *testing.T) { + u := "https://sq.io/testdata/actor.csv" + + header, err := fetchHTTPHeader(context.Background(), u) + require.NoError(t, err) + require.NotNil(t, header) + + // TODO +} + +func TestFetchHTTPHeader_HEAD_fallback_GET(t *testing.T) { + b := proj.ReadFile("drivers/csv/testdata/sakila-csv/actor.csv") + srvr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + w.Header().Set(http.CanonicalHeaderKey("Content-Length"), strconv.Itoa(len(b))) + w.WriteHeader(http.StatusOK) + _, err := w.Write(b) + require.NoError(t, err) + + })) + t.Cleanup(srvr.Close) + + u := srvr.URL + + header, err := fetchHTTPHeader(context.Background(), u) + assert.NoError(t, err) + assert.NotNil(t, header) + + //u := "https://sq.io/testdata/actor.csv" + // + //header, allowed, err := fetchHTTPHeader(context.Background(), u) + //require.NoError(t, err) + //require.True(t, allowed) + //require.NotNil(t, header) + // + //// TODO +} From 68ca1fd0dd0e800d39f033663c82377c14454df4 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 7 Dec 2023 09:10:40 -0700 Subject: [PATCH 066/195] wip: downloader --- 
libsq/core/lg/lga/lga.go | 1 + libsq/core/lg/lgm/lgm.go | 25 ++- libsq/source/download.go | 133 ++++++++++-- libsq/source/download_test.go | 29 ++- libsq/source/testdata/downloader.go | 1 + .../downloader/cache-dir-1/download/actor.csv | 201 ++++++++++++++++++ 6 files changed, 364 insertions(+), 26 deletions(-) create mode 100644 libsq/source/testdata/downloader.go create mode 100755 libsq/source/testdata/downloader/cache-dir-1/download/actor.csv diff --git a/libsq/core/lg/lga/lga.go b/libsq/core/lg/lga/lga.go index 34fa638e4..4ca3a3d74 100644 --- a/libsq/core/lg/lga/lga.go +++ b/libsq/core/lg/lga/lga.go @@ -58,5 +58,6 @@ const ( Via = "via" Version = "version" Timestamp = "timestamp" + Written = "written" Text = "text" ) diff --git a/libsq/core/lg/lgm/lgm.go b/libsq/core/lg/lgm/lgm.go index 9f14a2d92..968dd8c49 100644 --- a/libsq/core/lg/lgm/lgm.go +++ b/libsq/core/lg/lgm/lgm.go @@ -4,15 +4,18 @@ package lgm const ( - CloseDB = "Close DB" - CloseConn = "Close SQL connection" - CloseDBRows = "Close DB rows" - CloseDBStmt = "Close DB stmt" - CloseFileReader = "Close file reader" - CtxDone = "Context unexpectedly done" - OpenSrc = "Open source" - ReadDBRows = "Read DB rows" - RowsAffected = "Rows affected" - TxRollback = "Rollback DB tx" - Unexpected = "Unexpected" + CloseDB = "Close DB" + CloseConn = "Close SQL connection" + CloseDBRows = "Close DB rows" + CloseDBStmt = "Close DB stmt" + CloseHTTPResponseBody = "Close HTTP response body" + CloseFileReader = "Close file reader" + CloseFileWriter = "Close file writer" + CtxDone = "Context unexpectedly done" + OpenSrc = "Open source" + ReadDBRows = "Read DB rows" + RemoveFile = "Remove file" + RowsAffected = "Rows affected" + TxRollback = "Rollback DB tx" + Unexpected = "Unexpected" ) diff --git a/libsq/source/download.go b/libsq/source/download.go index f422dc2f2..ab94f3051 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -2,23 +2,30 @@ package source import ( "context" + 
"github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgm" "golang.org/x/exp/maps" + "io" "log/slog" + "mime" "net/http" "net/url" "os" + "path" "path/filepath" + "sync" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/source/fetcher" ) -func newDownloader(log *slog.Logger, srcCacheDir, url string) *downloader { +func newDownloader(srcCacheDir, url string) *downloader { downloadDir := filepath.Join(srcCacheDir, "download") return &downloader{ - log: log.With(lga.URL, url, "download_dir", downloadDir), srcCacheDir: srcCacheDir, downloadDir: downloadDir, checksumFile: filepath.Join(srcCacheDir, "download.checksum.txt"), @@ -27,45 +34,124 @@ func newDownloader(log *slog.Logger, srcCacheDir, url string) *downloader { } type downloader struct { - log *slog.Logger + mu sync.Mutex srcCacheDir string downloadDir string checksumFile string url string } -func (d *downloader) Cached() (ok bool, sum checksum.Checksum, fp string) { +func (d *downloader) log(log *slog.Logger) *slog.Logger { + return log.With(lga.URL, d.url, "download_dir", d.downloadDir) +} + +// Download downloads the file at the URL to the download dir, and also writes +// the file to dest, and returns the file path of the downloaded file. +// It is the caller's responsibility to close dest. If an appropriate file name +// cannot be determined from the HTTP response, the file is named "download". +// If the download fails at any stage, the download file is removed, but written +// always returns the number of bytes written to dest. +// Note that the download process is context-aware. 
+func (d *downloader) Download(ctx context.Context, dest io.Writer) (written int64, fp string, err error) { + d.mu.Lock() + defer d.mu.Unlock() + + log := d.log(lg.FromContext(ctx)) + + if err = ioz.RequireDir(d.downloadDir); err != nil { + return written, "", errz.Wrapf(err, "could not create download dir for: %s", d.url) + } + + var cancelFn context.CancelFunc + ctx, cancelFn = context.WithCancel(ctx) + defer cancelFn() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, nil) + if err != nil { + return written, "", errz.Wrapf(err, "download new request failed for: %s", d.url) + } + + // FIXME: Use a client that doesn't require SSL (see fetcher) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return written, "", errz.Wrapf(err, "download failed for: %s", d.url) + } + defer func() { + if resp != nil && resp.Body != nil { + lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) + } + }() + + if resp.StatusCode != http.StatusOK { + return written, "", errz.Errorf("download failed with %s for %s", resp.Status, d.url) + } + + filename := getDownloadFilename(resp) + if filename == "" { + filename = "download" + } + + fp = filepath.Join(d.downloadDir, filename) + f, err := os.OpenFile(fp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) + if err != nil { + return written, "", errz.Wrapf(err, "could not create download file for: %s", d.url) + } + + written, err = io.Copy( + contextio.NewWriter(ctx, io.MultiWriter(f, dest)), + contextio.NewReader(ctx, resp.Body), + ) + if err != nil { + log.Error("failed to write download file", lga.File, fp, lga.URL, d.url, lga.Err, err) + lg.WarnIfCloseError(log, lgm.CloseFileWriter, f) + lg.WarnIfFuncError(log, lgm.RemoveFile, func() error { return errz.Err(os.Remove(fp)) }) + return written, "", err + } + + if err = f.Close(); err != nil { + lg.WarnIfFuncError(log, lgm.RemoveFile, func() error { return errz.Err(os.Remove(fp)) }) + return written, "", errz.Wrapf(err, "failed to close download file: 
%s", fp) + } + + log.Info("Wrote download file", lga.Written, written, lga.File, fp) + return written, fp, nil +} +func (d *downloader) Cached(ctx context.Context) (ok bool, sum checksum.Checksum, fp string) { + d.mu.Lock() + defer d.mu.Unlock() + + log := d.log(lg.FromContext(ctx)) + fi, err := os.Stat(d.downloadDir) if err != nil { - d.log.Debug("not cached: can't stat download dir") + log.Debug("not cached: can't stat download dir") return false, "", "" } if !fi.IsDir() { - d.log.Error("not cached: download dir is not a dir") + log.Error("not cached: download dir is not a dir") return false, "", "" } fi, err = os.Stat(d.checksumFile) if err != nil { - d.log.Debug("not cached: can't stat download checksum file") + log.Debug("not cached: can't stat download checksum file") return false, "", "" } checksums, err := checksum.ReadFile(d.checksumFile) if err != nil { - d.log.Debug("not cached: can't read download checksum file") + log.Debug("not cached: can't read download checksum file") return false, "", "" } if len(checksums) != 1 { - d.log.Debug("not cached: download checksum file has unexpected number of entries") + log.Debug("not cached: download checksum file has unexpected number of entries") return false, "", "" } key := maps.Keys(checksums)[0] sum = checksums[key] if len(sum) == 0 { - d.log.Debug("not cached: checksum file has empty checksum", lga.File, key) + log.Debug("not cached: checksum file has empty checksum", lga.File, key) return false, "", "" } @@ -73,11 +159,11 @@ func (d *downloader) Cached() (ok bool, sum checksum.Checksum, fp string) { fi, err = os.Stat(downloadFile) if err != nil { - d.log.Debug("not cached: can't stat download file referenced in checksum file", lga.File, key) + log.Debug("not cached: can't stat download file referenced in checksum file", lga.File, key) return false, "", "" } - d.log.Info("found cached file", lga.File, key) + log.Info("found cached file", lga.File, key) return true, sum, downloadFile } @@ -106,6 +192,9 @@ func 
fetchHTTPHeader(ctx context.Context, u string) (header http.Header, err err } // HEAD not allowed, try GET + var cancelFn context.CancelFunc + ctx, cancelFn = context.WithCancel(ctx) + defer cancelFn() req, err = http.NewRequestWithContext(ctx, http.MethodGet, u, nil) if err != nil { return nil, errz.Err(err) @@ -165,3 +254,23 @@ func (fs *Files) fetch(ctx context.Context, loc string) (fpath string, err error return dlFile.Name(), nil } + +// getDownloadFilename returns the filename to use for a download. +// It first checks the Content-Disposition header, and if that's +// not present, it uses the last path segment of the URL. +// It's possible that the returned value will be empty string; the +// caller should handle that situation themselves. +func getDownloadFilename(resp *http.Response) string { + var filename string + dispHeader := resp.Header.Get("Content-Disposition") + if dispHeader != "" { + if _, params, err := mime.ParseMediaType(dispHeader); err == nil { + filename = params["filename"] + } + } + if filename == "" { + filename = path.Base(resp.Request.URL.Path) + } + + return filename +} diff --git a/libsq/source/download_test.go b/libsq/source/download_test.go index 6289da1cb..845f41310 100644 --- a/libsq/source/download_test.go +++ b/libsq/source/download_test.go @@ -1,12 +1,16 @@ package source import ( + "bytes" "context" + "github.com/neilotoole/slogt" + "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/testh/proj" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "net/http" "net/http/httptest" + "path/filepath" "strconv" "testing" ) @@ -21,16 +25,35 @@ func TestGetRemoteChecksum(t *testing.T) { // TODO } -func TestFetchHTTPHeader_sqio(t *testing.T) { - u := "https://sq.io/testdata/actor.csv" +const ( + urlActorCSV = "https://sq.io/testdata/actor.csv" + sizeActorCSV = int64(7641) +) - header, err := fetchHTTPHeader(context.Background(), u) +func TestFetchHTTPHeader_sqio(t *testing.T) { + 
header, err := fetchHTTPHeader(context.Background(), urlActorCSV) require.NoError(t, err) require.NotNil(t, header) // TODO } +func TestDownloader(t *testing.T) { + const u = "https://sq.io/testdata/actor.csv" + ctx := lg.NewContext(context.Background(), slogt.New(t)) + + cacheDir, err := filepath.Abs(filepath.Join("testdata", "downloader", "cache-dir-1")) + require.NoError(t, err) + dl := newDownloader(cacheDir, urlActorCSV) + buf := &bytes.Buffer{} + written, fp, err := dl.Download(ctx, buf) + require.NoError(t, err) + require.FileExists(t, fp) + require.Equal(t, sizeActorCSV, written) + + // TODO +} + func TestFetchHTTPHeader_HEAD_fallback_GET(t *testing.T) { b := proj.ReadFile("drivers/csv/testdata/sakila-csv/actor.csv") srvr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/libsq/source/testdata/downloader.go b/libsq/source/testdata/downloader.go new file mode 100644 index 000000000..69d29d3c6 --- /dev/null +++ b/libsq/source/testdata/downloader.go @@ -0,0 +1 @@ +package testdata diff --git a/libsq/source/testdata/downloader/cache-dir-1/download/actor.csv b/libsq/source/testdata/downloader/cache-dir-1/download/actor.csv new file mode 100755 index 000000000..66e2629a2 --- /dev/null +++ b/libsq/source/testdata/downloader/cache-dir-1/download/actor.csv @@ -0,0 +1,201 @@ +actor_id,first_name,last_name,last_update +1,PENELOPE,GUINESS,2020-02-15T06:59:28Z +2,NICK,WAHLBERG,2020-02-15T06:59:28Z +3,ED,CHASE,2020-02-15T06:59:28Z +4,JENNIFER,DAVIS,2020-02-15T06:59:28Z +5,JOHNNY,LOLLOBRIGIDA,2020-02-15T06:59:28Z +6,BETTE,NICHOLSON,2020-02-15T06:59:28Z +7,GRACE,MOSTEL,2020-02-15T06:59:28Z +8,MATTHEW,JOHANSSON,2020-02-15T06:59:28Z +9,JOE,SWANK,2020-02-15T06:59:28Z +10,CHRISTIAN,GABLE,2020-02-15T06:59:28Z +11,ZERO,CAGE,2020-02-15T06:59:28Z +12,KARL,BERRY,2020-02-15T06:59:28Z +13,UMA,WOOD,2020-02-15T06:59:28Z +14,VIVIEN,BERGEN,2020-02-15T06:59:28Z +15,CUBA,OLIVIER,2020-02-15T06:59:28Z +16,FRED,COSTNER,2020-02-15T06:59:28Z 
+17,HELEN,VOIGHT,2020-02-15T06:59:28Z +18,DAN,TORN,2020-02-15T06:59:28Z +19,BOB,FAWCETT,2020-02-15T06:59:28Z +20,LUCILLE,TRACY,2020-02-15T06:59:28Z +21,KIRSTEN,PALTROW,2020-02-15T06:59:28Z +22,ELVIS,MARX,2020-02-15T06:59:28Z +23,SANDRA,KILMER,2020-02-15T06:59:28Z +24,CAMERON,STREEP,2020-02-15T06:59:28Z +25,KEVIN,BLOOM,2020-02-15T06:59:28Z +26,RIP,CRAWFORD,2020-02-15T06:59:28Z +27,JULIA,MCQUEEN,2020-02-15T06:59:28Z +28,WOODY,HOFFMAN,2020-02-15T06:59:28Z +29,ALEC,WAYNE,2020-02-15T06:59:28Z +30,SANDRA,PECK,2020-02-15T06:59:28Z +31,SISSY,SOBIESKI,2020-02-15T06:59:28Z +32,TIM,HACKMAN,2020-02-15T06:59:28Z +33,MILLA,PECK,2020-02-15T06:59:28Z +34,AUDREY,OLIVIER,2020-02-15T06:59:28Z +35,JUDY,DEAN,2020-02-15T06:59:28Z +36,BURT,DUKAKIS,2020-02-15T06:59:28Z +37,VAL,BOLGER,2020-02-15T06:59:28Z +38,TOM,MCKELLEN,2020-02-15T06:59:28Z +39,GOLDIE,BRODY,2020-02-15T06:59:28Z +40,JOHNNY,CAGE,2020-02-15T06:59:28Z +41,JODIE,DEGENERES,2020-02-15T06:59:28Z +42,TOM,MIRANDA,2020-02-15T06:59:28Z +43,KIRK,JOVOVICH,2020-02-15T06:59:28Z +44,NICK,STALLONE,2020-02-15T06:59:28Z +45,REESE,KILMER,2020-02-15T06:59:28Z +46,PARKER,GOLDBERG,2020-02-15T06:59:28Z +47,JULIA,BARRYMORE,2020-02-15T06:59:28Z +48,FRANCES,DAY-LEWIS,2020-02-15T06:59:28Z +49,ANNE,CRONYN,2020-02-15T06:59:28Z +50,NATALIE,HOPKINS,2020-02-15T06:59:28Z +51,GARY,PHOENIX,2020-02-15T06:59:28Z +52,CARMEN,HUNT,2020-02-15T06:59:28Z +53,MENA,TEMPLE,2020-02-15T06:59:28Z +54,PENELOPE,PINKETT,2020-02-15T06:59:28Z +55,FAY,KILMER,2020-02-15T06:59:28Z +56,DAN,HARRIS,2020-02-15T06:59:28Z +57,JUDE,CRUISE,2020-02-15T06:59:28Z +58,CHRISTIAN,AKROYD,2020-02-15T06:59:28Z +59,DUSTIN,TAUTOU,2020-02-15T06:59:28Z +60,HENRY,BERRY,2020-02-15T06:59:28Z +61,CHRISTIAN,NEESON,2020-02-15T06:59:28Z +62,JAYNE,NEESON,2020-02-15T06:59:28Z +63,CAMERON,WRAY,2020-02-15T06:59:28Z +64,RAY,JOHANSSON,2020-02-15T06:59:28Z +65,ANGELA,HUDSON,2020-02-15T06:59:28Z +66,MARY,TANDY,2020-02-15T06:59:28Z +67,JESSICA,BAILEY,2020-02-15T06:59:28Z +68,RIP,WINSLET,2020-02-15T06:59:28Z 
+69,KENNETH,PALTROW,2020-02-15T06:59:28Z +70,MICHELLE,MCCONAUGHEY,2020-02-15T06:59:28Z +71,ADAM,GRANT,2020-02-15T06:59:28Z +72,SEAN,WILLIAMS,2020-02-15T06:59:28Z +73,GARY,PENN,2020-02-15T06:59:28Z +74,MILLA,KEITEL,2020-02-15T06:59:28Z +75,BURT,POSEY,2020-02-15T06:59:28Z +76,ANGELINA,ASTAIRE,2020-02-15T06:59:28Z +77,CARY,MCCONAUGHEY,2020-02-15T06:59:28Z +78,GROUCHO,SINATRA,2020-02-15T06:59:28Z +79,MAE,HOFFMAN,2020-02-15T06:59:28Z +80,RALPH,CRUZ,2020-02-15T06:59:28Z +81,SCARLETT,DAMON,2020-02-15T06:59:28Z +82,WOODY,JOLIE,2020-02-15T06:59:28Z +83,BEN,WILLIS,2020-02-15T06:59:28Z +84,JAMES,PITT,2020-02-15T06:59:28Z +85,MINNIE,ZELLWEGER,2020-02-15T06:59:28Z +86,GREG,CHAPLIN,2020-02-15T06:59:28Z +87,SPENCER,PECK,2020-02-15T06:59:28Z +88,KENNETH,PESCI,2020-02-15T06:59:28Z +89,CHARLIZE,DENCH,2020-02-15T06:59:28Z +90,SEAN,GUINESS,2020-02-15T06:59:28Z +91,CHRISTOPHER,BERRY,2020-02-15T06:59:28Z +92,KIRSTEN,AKROYD,2020-02-15T06:59:28Z +93,ELLEN,PRESLEY,2020-02-15T06:59:28Z +94,KENNETH,TORN,2020-02-15T06:59:28Z +95,DARYL,WAHLBERG,2020-02-15T06:59:28Z +96,GENE,WILLIS,2020-02-15T06:59:28Z +97,MEG,HAWKE,2020-02-15T06:59:28Z +98,CHRIS,BRIDGES,2020-02-15T06:59:28Z +99,JIM,MOSTEL,2020-02-15T06:59:28Z +100,SPENCER,DEPP,2020-02-15T06:59:28Z +101,SUSAN,DAVIS,2020-02-15T06:59:28Z +102,WALTER,TORN,2020-02-15T06:59:28Z +103,MATTHEW,LEIGH,2020-02-15T06:59:28Z +104,PENELOPE,CRONYN,2020-02-15T06:59:28Z +105,SIDNEY,CROWE,2020-02-15T06:59:28Z +106,GROUCHO,DUNST,2020-02-15T06:59:28Z +107,GINA,DEGENERES,2020-02-15T06:59:28Z +108,WARREN,NOLTE,2020-02-15T06:59:28Z +109,SYLVESTER,DERN,2020-02-15T06:59:28Z +110,SUSAN,DAVIS,2020-02-15T06:59:28Z +111,CAMERON,ZELLWEGER,2020-02-15T06:59:28Z +112,RUSSELL,BACALL,2020-02-15T06:59:28Z +113,MORGAN,HOPKINS,2020-02-15T06:59:28Z +114,MORGAN,MCDORMAND,2020-02-15T06:59:28Z +115,HARRISON,BALE,2020-02-15T06:59:28Z +116,DAN,STREEP,2020-02-15T06:59:28Z +117,RENEE,TRACY,2020-02-15T06:59:28Z +118,CUBA,ALLEN,2020-02-15T06:59:28Z +119,WARREN,JACKMAN,2020-02-15T06:59:28Z 
+120,PENELOPE,MONROE,2020-02-15T06:59:28Z +121,LIZA,BERGMAN,2020-02-15T06:59:28Z +122,SALMA,NOLTE,2020-02-15T06:59:28Z +123,JULIANNE,DENCH,2020-02-15T06:59:28Z +124,SCARLETT,BENING,2020-02-15T06:59:28Z +125,ALBERT,NOLTE,2020-02-15T06:59:28Z +126,FRANCES,TOMEI,2020-02-15T06:59:28Z +127,KEVIN,GARLAND,2020-02-15T06:59:28Z +128,CATE,MCQUEEN,2020-02-15T06:59:28Z +129,DARYL,CRAWFORD,2020-02-15T06:59:28Z +130,GRETA,KEITEL,2020-02-15T06:59:28Z +131,JANE,JACKMAN,2020-02-15T06:59:28Z +132,ADAM,HOPPER,2020-02-15T06:59:28Z +133,RICHARD,PENN,2020-02-15T06:59:28Z +134,GENE,HOPKINS,2020-02-15T06:59:28Z +135,RITA,REYNOLDS,2020-02-15T06:59:28Z +136,ED,MANSFIELD,2020-02-15T06:59:28Z +137,MORGAN,WILLIAMS,2020-02-15T06:59:28Z +138,LUCILLE,DEE,2020-02-15T06:59:28Z +139,EWAN,GOODING,2020-02-15T06:59:28Z +140,WHOOPI,HURT,2020-02-15T06:59:28Z +141,CATE,HARRIS,2020-02-15T06:59:28Z +142,JADA,RYDER,2020-02-15T06:59:28Z +143,RIVER,DEAN,2020-02-15T06:59:28Z +144,ANGELA,WITHERSPOON,2020-02-15T06:59:28Z +145,KIM,ALLEN,2020-02-15T06:59:28Z +146,ALBERT,JOHANSSON,2020-02-15T06:59:28Z +147,FAY,WINSLET,2020-02-15T06:59:28Z +148,EMILY,DEE,2020-02-15T06:59:28Z +149,RUSSELL,TEMPLE,2020-02-15T06:59:28Z +150,JAYNE,NOLTE,2020-02-15T06:59:28Z +151,GEOFFREY,HESTON,2020-02-15T06:59:28Z +152,BEN,HARRIS,2020-02-15T06:59:28Z +153,MINNIE,KILMER,2020-02-15T06:59:28Z +154,MERYL,GIBSON,2020-02-15T06:59:28Z +155,IAN,TANDY,2020-02-15T06:59:28Z +156,FAY,WOOD,2020-02-15T06:59:28Z +157,GRETA,MALDEN,2020-02-15T06:59:28Z +158,VIVIEN,BASINGER,2020-02-15T06:59:28Z +159,LAURA,BRODY,2020-02-15T06:59:28Z +160,CHRIS,DEPP,2020-02-15T06:59:28Z +161,HARVEY,HOPE,2020-02-15T06:59:28Z +162,OPRAH,KILMER,2020-02-15T06:59:28Z +163,CHRISTOPHER,WEST,2020-02-15T06:59:28Z +164,HUMPHREY,WILLIS,2020-02-15T06:59:28Z +165,AL,GARLAND,2020-02-15T06:59:28Z +166,NICK,DEGENERES,2020-02-15T06:59:28Z +167,LAURENCE,BULLOCK,2020-02-15T06:59:28Z +168,WILL,WILSON,2020-02-15T06:59:28Z +169,KENNETH,HOFFMAN,2020-02-15T06:59:28Z 
+170,MENA,HOPPER,2020-02-15T06:59:28Z +171,OLYMPIA,PFEIFFER,2020-02-15T06:59:28Z +172,GROUCHO,WILLIAMS,2020-02-15T06:59:28Z +173,ALAN,DREYFUSS,2020-02-15T06:59:28Z +174,MICHAEL,BENING,2020-02-15T06:59:28Z +175,WILLIAM,HACKMAN,2020-02-15T06:59:28Z +176,JON,CHASE,2020-02-15T06:59:28Z +177,GENE,MCKELLEN,2020-02-15T06:59:28Z +178,LISA,MONROE,2020-02-15T06:59:28Z +179,ED,GUINESS,2020-02-15T06:59:28Z +180,JEFF,SILVERSTONE,2020-02-15T06:59:28Z +181,MATTHEW,CARREY,2020-02-15T06:59:28Z +182,DEBBIE,AKROYD,2020-02-15T06:59:28Z +183,RUSSELL,CLOSE,2020-02-15T06:59:28Z +184,HUMPHREY,GARLAND,2020-02-15T06:59:28Z +185,MICHAEL,BOLGER,2020-02-15T06:59:28Z +186,JULIA,ZELLWEGER,2020-02-15T06:59:28Z +187,RENEE,BALL,2020-02-15T06:59:28Z +188,ROCK,DUKAKIS,2020-02-15T06:59:28Z +189,CUBA,BIRCH,2020-02-15T06:59:28Z +190,AUDREY,BAILEY,2020-02-15T06:59:28Z +191,GREGORY,GOODING,2020-02-15T06:59:28Z +192,JOHN,SUVARI,2020-02-15T06:59:28Z +193,BURT,TEMPLE,2020-02-15T06:59:28Z +194,MERYL,ALLEN,2020-02-15T06:59:28Z +195,JAYNE,SILVERSTONE,2020-02-15T06:59:28Z +196,BELA,WALKEN,2020-02-15T06:59:28Z +197,REESE,WEST,2020-02-15T06:59:28Z +198,MARY,KEITEL,2020-02-15T06:59:28Z +199,JULIA,FAWCETT,2020-02-15T06:59:28Z +200,THORA,TEMPLE,2020-02-15T06:59:28Z From 83a1ca77956c1314e21259be88d4eec7c10112eb Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 7 Dec 2023 09:12:15 -0700 Subject: [PATCH 067/195] cleanup --- libsq/source/testdata/downloader.go | 1 - libsq/source/testdata/downloader/.gitignore | 1 + .../downloader/cache-dir-1/download/actor.csv | 201 ------------------ 3 files changed, 1 insertion(+), 202 deletions(-) delete mode 100644 libsq/source/testdata/downloader.go create mode 100644 libsq/source/testdata/downloader/.gitignore delete mode 100755 libsq/source/testdata/downloader/cache-dir-1/download/actor.csv diff --git a/libsq/source/testdata/downloader.go b/libsq/source/testdata/downloader.go deleted file mode 100644 index 69d29d3c6..000000000 --- a/libsq/source/testdata/downloader.go +++ 
/dev/null @@ -1 +0,0 @@ -package testdata diff --git a/libsq/source/testdata/downloader/.gitignore b/libsq/source/testdata/downloader/.gitignore new file mode 100644 index 000000000..6e1cc9c27 --- /dev/null +++ b/libsq/source/testdata/downloader/.gitignore @@ -0,0 +1 @@ +cache-dir-* diff --git a/libsq/source/testdata/downloader/cache-dir-1/download/actor.csv b/libsq/source/testdata/downloader/cache-dir-1/download/actor.csv deleted file mode 100755 index 66e2629a2..000000000 --- a/libsq/source/testdata/downloader/cache-dir-1/download/actor.csv +++ /dev/null @@ -1,201 +0,0 @@ -actor_id,first_name,last_name,last_update -1,PENELOPE,GUINESS,2020-02-15T06:59:28Z -2,NICK,WAHLBERG,2020-02-15T06:59:28Z -3,ED,CHASE,2020-02-15T06:59:28Z -4,JENNIFER,DAVIS,2020-02-15T06:59:28Z -5,JOHNNY,LOLLOBRIGIDA,2020-02-15T06:59:28Z -6,BETTE,NICHOLSON,2020-02-15T06:59:28Z -7,GRACE,MOSTEL,2020-02-15T06:59:28Z -8,MATTHEW,JOHANSSON,2020-02-15T06:59:28Z -9,JOE,SWANK,2020-02-15T06:59:28Z -10,CHRISTIAN,GABLE,2020-02-15T06:59:28Z -11,ZERO,CAGE,2020-02-15T06:59:28Z -12,KARL,BERRY,2020-02-15T06:59:28Z -13,UMA,WOOD,2020-02-15T06:59:28Z -14,VIVIEN,BERGEN,2020-02-15T06:59:28Z -15,CUBA,OLIVIER,2020-02-15T06:59:28Z -16,FRED,COSTNER,2020-02-15T06:59:28Z -17,HELEN,VOIGHT,2020-02-15T06:59:28Z -18,DAN,TORN,2020-02-15T06:59:28Z -19,BOB,FAWCETT,2020-02-15T06:59:28Z -20,LUCILLE,TRACY,2020-02-15T06:59:28Z -21,KIRSTEN,PALTROW,2020-02-15T06:59:28Z -22,ELVIS,MARX,2020-02-15T06:59:28Z -23,SANDRA,KILMER,2020-02-15T06:59:28Z -24,CAMERON,STREEP,2020-02-15T06:59:28Z -25,KEVIN,BLOOM,2020-02-15T06:59:28Z -26,RIP,CRAWFORD,2020-02-15T06:59:28Z -27,JULIA,MCQUEEN,2020-02-15T06:59:28Z -28,WOODY,HOFFMAN,2020-02-15T06:59:28Z -29,ALEC,WAYNE,2020-02-15T06:59:28Z -30,SANDRA,PECK,2020-02-15T06:59:28Z -31,SISSY,SOBIESKI,2020-02-15T06:59:28Z -32,TIM,HACKMAN,2020-02-15T06:59:28Z -33,MILLA,PECK,2020-02-15T06:59:28Z -34,AUDREY,OLIVIER,2020-02-15T06:59:28Z -35,JUDY,DEAN,2020-02-15T06:59:28Z -36,BURT,DUKAKIS,2020-02-15T06:59:28Z 
-37,VAL,BOLGER,2020-02-15T06:59:28Z -38,TOM,MCKELLEN,2020-02-15T06:59:28Z -39,GOLDIE,BRODY,2020-02-15T06:59:28Z -40,JOHNNY,CAGE,2020-02-15T06:59:28Z -41,JODIE,DEGENERES,2020-02-15T06:59:28Z -42,TOM,MIRANDA,2020-02-15T06:59:28Z -43,KIRK,JOVOVICH,2020-02-15T06:59:28Z -44,NICK,STALLONE,2020-02-15T06:59:28Z -45,REESE,KILMER,2020-02-15T06:59:28Z -46,PARKER,GOLDBERG,2020-02-15T06:59:28Z -47,JULIA,BARRYMORE,2020-02-15T06:59:28Z -48,FRANCES,DAY-LEWIS,2020-02-15T06:59:28Z -49,ANNE,CRONYN,2020-02-15T06:59:28Z -50,NATALIE,HOPKINS,2020-02-15T06:59:28Z -51,GARY,PHOENIX,2020-02-15T06:59:28Z -52,CARMEN,HUNT,2020-02-15T06:59:28Z -53,MENA,TEMPLE,2020-02-15T06:59:28Z -54,PENELOPE,PINKETT,2020-02-15T06:59:28Z -55,FAY,KILMER,2020-02-15T06:59:28Z -56,DAN,HARRIS,2020-02-15T06:59:28Z -57,JUDE,CRUISE,2020-02-15T06:59:28Z -58,CHRISTIAN,AKROYD,2020-02-15T06:59:28Z -59,DUSTIN,TAUTOU,2020-02-15T06:59:28Z -60,HENRY,BERRY,2020-02-15T06:59:28Z -61,CHRISTIAN,NEESON,2020-02-15T06:59:28Z -62,JAYNE,NEESON,2020-02-15T06:59:28Z -63,CAMERON,WRAY,2020-02-15T06:59:28Z -64,RAY,JOHANSSON,2020-02-15T06:59:28Z -65,ANGELA,HUDSON,2020-02-15T06:59:28Z -66,MARY,TANDY,2020-02-15T06:59:28Z -67,JESSICA,BAILEY,2020-02-15T06:59:28Z -68,RIP,WINSLET,2020-02-15T06:59:28Z -69,KENNETH,PALTROW,2020-02-15T06:59:28Z -70,MICHELLE,MCCONAUGHEY,2020-02-15T06:59:28Z -71,ADAM,GRANT,2020-02-15T06:59:28Z -72,SEAN,WILLIAMS,2020-02-15T06:59:28Z -73,GARY,PENN,2020-02-15T06:59:28Z -74,MILLA,KEITEL,2020-02-15T06:59:28Z -75,BURT,POSEY,2020-02-15T06:59:28Z -76,ANGELINA,ASTAIRE,2020-02-15T06:59:28Z -77,CARY,MCCONAUGHEY,2020-02-15T06:59:28Z -78,GROUCHO,SINATRA,2020-02-15T06:59:28Z -79,MAE,HOFFMAN,2020-02-15T06:59:28Z -80,RALPH,CRUZ,2020-02-15T06:59:28Z -81,SCARLETT,DAMON,2020-02-15T06:59:28Z -82,WOODY,JOLIE,2020-02-15T06:59:28Z -83,BEN,WILLIS,2020-02-15T06:59:28Z -84,JAMES,PITT,2020-02-15T06:59:28Z -85,MINNIE,ZELLWEGER,2020-02-15T06:59:28Z -86,GREG,CHAPLIN,2020-02-15T06:59:28Z -87,SPENCER,PECK,2020-02-15T06:59:28Z 
-88,KENNETH,PESCI,2020-02-15T06:59:28Z -89,CHARLIZE,DENCH,2020-02-15T06:59:28Z -90,SEAN,GUINESS,2020-02-15T06:59:28Z -91,CHRISTOPHER,BERRY,2020-02-15T06:59:28Z -92,KIRSTEN,AKROYD,2020-02-15T06:59:28Z -93,ELLEN,PRESLEY,2020-02-15T06:59:28Z -94,KENNETH,TORN,2020-02-15T06:59:28Z -95,DARYL,WAHLBERG,2020-02-15T06:59:28Z -96,GENE,WILLIS,2020-02-15T06:59:28Z -97,MEG,HAWKE,2020-02-15T06:59:28Z -98,CHRIS,BRIDGES,2020-02-15T06:59:28Z -99,JIM,MOSTEL,2020-02-15T06:59:28Z -100,SPENCER,DEPP,2020-02-15T06:59:28Z -101,SUSAN,DAVIS,2020-02-15T06:59:28Z -102,WALTER,TORN,2020-02-15T06:59:28Z -103,MATTHEW,LEIGH,2020-02-15T06:59:28Z -104,PENELOPE,CRONYN,2020-02-15T06:59:28Z -105,SIDNEY,CROWE,2020-02-15T06:59:28Z -106,GROUCHO,DUNST,2020-02-15T06:59:28Z -107,GINA,DEGENERES,2020-02-15T06:59:28Z -108,WARREN,NOLTE,2020-02-15T06:59:28Z -109,SYLVESTER,DERN,2020-02-15T06:59:28Z -110,SUSAN,DAVIS,2020-02-15T06:59:28Z -111,CAMERON,ZELLWEGER,2020-02-15T06:59:28Z -112,RUSSELL,BACALL,2020-02-15T06:59:28Z -113,MORGAN,HOPKINS,2020-02-15T06:59:28Z -114,MORGAN,MCDORMAND,2020-02-15T06:59:28Z -115,HARRISON,BALE,2020-02-15T06:59:28Z -116,DAN,STREEP,2020-02-15T06:59:28Z -117,RENEE,TRACY,2020-02-15T06:59:28Z -118,CUBA,ALLEN,2020-02-15T06:59:28Z -119,WARREN,JACKMAN,2020-02-15T06:59:28Z -120,PENELOPE,MONROE,2020-02-15T06:59:28Z -121,LIZA,BERGMAN,2020-02-15T06:59:28Z -122,SALMA,NOLTE,2020-02-15T06:59:28Z -123,JULIANNE,DENCH,2020-02-15T06:59:28Z -124,SCARLETT,BENING,2020-02-15T06:59:28Z -125,ALBERT,NOLTE,2020-02-15T06:59:28Z -126,FRANCES,TOMEI,2020-02-15T06:59:28Z -127,KEVIN,GARLAND,2020-02-15T06:59:28Z -128,CATE,MCQUEEN,2020-02-15T06:59:28Z -129,DARYL,CRAWFORD,2020-02-15T06:59:28Z -130,GRETA,KEITEL,2020-02-15T06:59:28Z -131,JANE,JACKMAN,2020-02-15T06:59:28Z -132,ADAM,HOPPER,2020-02-15T06:59:28Z -133,RICHARD,PENN,2020-02-15T06:59:28Z -134,GENE,HOPKINS,2020-02-15T06:59:28Z -135,RITA,REYNOLDS,2020-02-15T06:59:28Z -136,ED,MANSFIELD,2020-02-15T06:59:28Z -137,MORGAN,WILLIAMS,2020-02-15T06:59:28Z 
-138,LUCILLE,DEE,2020-02-15T06:59:28Z -139,EWAN,GOODING,2020-02-15T06:59:28Z -140,WHOOPI,HURT,2020-02-15T06:59:28Z -141,CATE,HARRIS,2020-02-15T06:59:28Z -142,JADA,RYDER,2020-02-15T06:59:28Z -143,RIVER,DEAN,2020-02-15T06:59:28Z -144,ANGELA,WITHERSPOON,2020-02-15T06:59:28Z -145,KIM,ALLEN,2020-02-15T06:59:28Z -146,ALBERT,JOHANSSON,2020-02-15T06:59:28Z -147,FAY,WINSLET,2020-02-15T06:59:28Z -148,EMILY,DEE,2020-02-15T06:59:28Z -149,RUSSELL,TEMPLE,2020-02-15T06:59:28Z -150,JAYNE,NOLTE,2020-02-15T06:59:28Z -151,GEOFFREY,HESTON,2020-02-15T06:59:28Z -152,BEN,HARRIS,2020-02-15T06:59:28Z -153,MINNIE,KILMER,2020-02-15T06:59:28Z -154,MERYL,GIBSON,2020-02-15T06:59:28Z -155,IAN,TANDY,2020-02-15T06:59:28Z -156,FAY,WOOD,2020-02-15T06:59:28Z -157,GRETA,MALDEN,2020-02-15T06:59:28Z -158,VIVIEN,BASINGER,2020-02-15T06:59:28Z -159,LAURA,BRODY,2020-02-15T06:59:28Z -160,CHRIS,DEPP,2020-02-15T06:59:28Z -161,HARVEY,HOPE,2020-02-15T06:59:28Z -162,OPRAH,KILMER,2020-02-15T06:59:28Z -163,CHRISTOPHER,WEST,2020-02-15T06:59:28Z -164,HUMPHREY,WILLIS,2020-02-15T06:59:28Z -165,AL,GARLAND,2020-02-15T06:59:28Z -166,NICK,DEGENERES,2020-02-15T06:59:28Z -167,LAURENCE,BULLOCK,2020-02-15T06:59:28Z -168,WILL,WILSON,2020-02-15T06:59:28Z -169,KENNETH,HOFFMAN,2020-02-15T06:59:28Z -170,MENA,HOPPER,2020-02-15T06:59:28Z -171,OLYMPIA,PFEIFFER,2020-02-15T06:59:28Z -172,GROUCHO,WILLIAMS,2020-02-15T06:59:28Z -173,ALAN,DREYFUSS,2020-02-15T06:59:28Z -174,MICHAEL,BENING,2020-02-15T06:59:28Z -175,WILLIAM,HACKMAN,2020-02-15T06:59:28Z -176,JON,CHASE,2020-02-15T06:59:28Z -177,GENE,MCKELLEN,2020-02-15T06:59:28Z -178,LISA,MONROE,2020-02-15T06:59:28Z -179,ED,GUINESS,2020-02-15T06:59:28Z -180,JEFF,SILVERSTONE,2020-02-15T06:59:28Z -181,MATTHEW,CARREY,2020-02-15T06:59:28Z -182,DEBBIE,AKROYD,2020-02-15T06:59:28Z -183,RUSSELL,CLOSE,2020-02-15T06:59:28Z -184,HUMPHREY,GARLAND,2020-02-15T06:59:28Z -185,MICHAEL,BOLGER,2020-02-15T06:59:28Z -186,JULIA,ZELLWEGER,2020-02-15T06:59:28Z -187,RENEE,BALL,2020-02-15T06:59:28Z 
-188,ROCK,DUKAKIS,2020-02-15T06:59:28Z -189,CUBA,BIRCH,2020-02-15T06:59:28Z -190,AUDREY,BAILEY,2020-02-15T06:59:28Z -191,GREGORY,GOODING,2020-02-15T06:59:28Z -192,JOHN,SUVARI,2020-02-15T06:59:28Z -193,BURT,TEMPLE,2020-02-15T06:59:28Z -194,MERYL,ALLEN,2020-02-15T06:59:28Z -195,JAYNE,SILVERSTONE,2020-02-15T06:59:28Z -196,BELA,WALKEN,2020-02-15T06:59:28Z -197,REESE,WEST,2020-02-15T06:59:28Z -198,MARY,KEITEL,2020-02-15T06:59:28Z -199,JULIA,FAWCETT,2020-02-15T06:59:28Z -200,THORA,TEMPLE,2020-02-15T06:59:28Z From 25946886aadb2358d8a2291032f032f03e348381 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 7 Dec 2023 13:36:49 -0700 Subject: [PATCH 068/195] wip: downloader --- cli/run.go | 6 +- go.mod | 5 +- go.sum | 2 + libsq/core/ioz/checksum/checksum.go | 106 +++++++++++++++----- libsq/core/ioz/ioz.go | 42 ++++++++ libsq/core/stringz/stringz.go | 30 ++++++ libsq/core/stringz/stringz_test.go | 27 ++++++ libsq/source/download.go | 145 +++++++++++++++++++++++----- libsq/source/download_test.go | 18 +++- libsq/source/files.go | 23 ++++- libsq/source/files_test.go | 6 +- libsq/source/internal_test.go | 2 +- testh/testh.go | 2 +- testh/tu/tutil.go | 9 ++ 14 files changed, 362 insertions(+), 61 deletions(-) diff --git a/cli/run.go b/cli/run.go index d77a02446..9d3d3dee7 100644 --- a/cli/run.go +++ b/cli/run.go @@ -2,6 +2,7 @@ package cli import ( "context" + "github.com/neilotoole/sq/libsq/core/ioz" "io" "log/slog" "os" @@ -140,7 +141,10 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { var err error if ru.Files == nil { - ru.Files, err = source.NewFiles(ctx, source.DefaultTempDir(), source.DefaultCacheDir()) + // TODO: The timeout/ssl vals should really come from options. 
+ c := ioz.NewHTTPClient(0, true) + + ru.Files, err = source.NewFiles(ctx, c, source.DefaultTempDir(), source.DefaultCacheDir(), true) if err != nil { lg.WarnIfFuncError(log, lga.Cleanup, ru.Cleanup.Run) return err diff --git a/go.mod b/go.mod index c05d62f48..504998b45 100644 --- a/go.mod +++ b/go.mod @@ -95,7 +95,10 @@ require github.com/vbauerster/mpb/v8 v8.7.1-0.20231205062852-da3162c67234 // See: https://github.com/djherbis/fscache/pull/21 require github.com/neilotoole/fscache v0.0.0-20231203162946-c9808f16552e -require github.com/djherbis/stream v1.4.0 // indirect +require ( + github.com/djherbis/stream v1.4.0 // indirect + github.com/mrz1836/go-sanitize v1.3.1 // indirect +) // See: https://github.com/djherbis/stream/pull/11 replace github.com/djherbis/stream v1.4.0 => github.com/neilotoole/djherbis-stream v0.0.0-20231203160853-609f47afedda diff --git a/go.sum b/go.sum index 190ecd368..195e2f39d 100644 --- a/go.sum +++ b/go.sum @@ -113,6 +113,8 @@ github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= +github.com/mrz1836/go-sanitize v1.3.1 h1:bTxpzDXzGh9cp3XLTeVKgL2iLqEwCaLqqe+3BmpnCbo= +github.com/mrz1836/go-sanitize v1.3.1/go.mod h1:Js6Gq1uiarNReoOeOKxPXxNpKy1FRlbgDDZnJG4THdM= github.com/muesli/mango v0.1.0 h1:DZQK45d2gGbql1arsYA4vfg4d7I9Hfx5rX/GCmzsAvI= github.com/muesli/mango v0.1.0/go.mod h1:5XFpbC8jY5UUv89YQciiXNlbi+iJgt29VDC5xbzrLL4= github.com/muesli/mango-cobra v1.2.0 h1:DQvjzAM0PMZr85Iv9LIMaYISpTOliMEg+uMFtNbYvWg= diff --git a/libsq/core/ioz/checksum/checksum.go b/libsq/core/ioz/checksum/checksum.go index 6c9910fb1..40c7c28e1 100644 --- 
a/libsq/core/ioz/checksum/checksum.go +++ b/libsq/core/ioz/checksum/checksum.go @@ -6,6 +6,7 @@ import ( "crypto/sha256" "fmt" "io" + "net/http" "os" "strconv" "strings" @@ -16,26 +17,6 @@ import ( // Checksum is a checksum of a file. type Checksum string -// ForFile returns a checksum of the file at path. -// The checksum is based on the file's name, size, mode, and -// modification time. File contents are not read. -func ForFile(path string) (Checksum, error) { - fi, err := os.Stat(path) - if err != nil { - return "", errz.Wrap(err, "calculate file checksum") - } - - buf := bytes.Buffer{} - buf.WriteString(fi.Name()) - buf.WriteString(strconv.FormatInt(fi.ModTime().UnixNano(), 10)) - buf.WriteString(strconv.FormatInt(fi.Size(), 10)) - buf.WriteString(strconv.FormatUint(uint64(fi.Mode()), 10)) - buf.WriteString(strconv.FormatBool(fi.IsDir())) - - sum := sha256.Sum256(buf.Bytes()) - return Checksum(fmt.Sprintf("%x", sum)), nil -} - // Write appends a checksum line to w, including // a newline. The typical format is: // @@ -58,8 +39,13 @@ func WriteFile(path string, sum Checksum, name string) error { if err != nil { return errz.Wrap(err, "write checksum file") } - defer func() { _ = f.Close() }() - return Write(f, sum, name) + err = Write(f, sum, name) + if err == nil { + return errz.Err(f.Close()) + } + + _ = f.Close() + return err } // ReadFile reads a checksum file from path. @@ -100,3 +86,79 @@ func Read(r io.Reader) (map[string]Checksum, error) { return sums, errz.Wrap(sc.Err(), "read checksums") } + +// ForFile returns a checksum of the file at path. +// The checksum is based on the file's name, size, mode, and +// modification time. File contents are not read. 
+func ForFile(path string) (Checksum, error) { + fi, err := os.Stat(path) + if err != nil { + return "", errz.Wrap(err, "calculate file checksum") + } + + buf := bytes.Buffer{} + buf.WriteString(fi.Name()) + buf.WriteString(strconv.FormatInt(fi.ModTime().UnixNano(), 10)) + buf.WriteString(strconv.FormatInt(fi.Size(), 10)) + buf.WriteString(strconv.FormatUint(uint64(fi.Mode()), 10)) + buf.WriteString(strconv.FormatBool(fi.IsDir())) + + sum := sha256.Sum256(buf.Bytes()) + return Checksum(fmt.Sprintf("%x", sum)), nil +} + +// ForHTTPHeader returns a checksum generated from URL u and +// the contents of header. If the header contains an Etag, +// that is used as the primary element. Otherwise, other +// values such as Content-Length and Last-Modified are +// considered. +// +// Deprecated: use ForHTTPResponse instead. +func ForHTTPHeader(u string, header http.Header) Checksum { + buf := bytes.Buffer{} + buf.WriteString(u) + if header != nil { + etag := header.Get("Etag") + if etag != "" { + buf.WriteString(etag) + } else { + buf.WriteString(header.Get("Content-Type")) + buf.WriteString(header.Get("Content-Disposition")) + buf.WriteString(header.Get("Content-Length")) + buf.WriteString(header.Get("Last-Modified")) + } + } + + sum := sha256.Sum256(buf.Bytes()) + return Checksum(fmt.Sprintf("%x", sum)) +} + +// ForHTTPResponse returns a checksum generated from the response's +// request URL and the contents of the response's header. If the header +// contains an Etag, that is used as the primary element. Otherwise, +// other values such as Content-Length and Last-Modified are considered. 
+func ForHTTPResponse(resp *http.Response) Checksum { + if resp == nil { + return "" + } + + buf := bytes.Buffer{} + if resp.Request != nil && resp.Request.URL != nil { + buf.WriteString(resp.Request.URL.String()) + } + header := resp.Header + if header != nil { + etag := header.Get("Etag") + if etag != "" { + buf.WriteString(etag) + } else { + buf.WriteString(header.Get("Content-Type")) + buf.WriteString(header.Get("Content-Disposition")) + buf.WriteString(header.Get("Content-Length")) + buf.WriteString(header.Get("Last-Modified")) + } + } + + sum := sha256.Sum256(buf.Bytes()) + return Checksum(fmt.Sprintf("%x", sum)) +} diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 43eed24b2..a120f7fcd 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -5,8 +5,10 @@ import ( "bytes" "context" crand "crypto/rand" + "crypto/tls" "io" mrand "math/rand" + "net/http" "os" "path/filepath" "strings" @@ -350,6 +352,16 @@ func RequireDir(dir string) error { return errz.Err(os.MkdirAll(dir, 0o750)) } +// ReadFileToString reads the file at name and returns its contents +// as a string. +func ReadFileToString(name string) (string, error) { + b, err := os.ReadFile(name) + if err != nil { + return "", errz.Err(err) + } + return string(b), nil +} + // DirExists returns true if dir exists and is a directory. func DirExists(dir string) bool { fi, err := os.Stat(dir) @@ -376,3 +388,33 @@ func PrintTree(w io.Writer, loc string, showSize, colorize bool) error { inf.Print(opts) return nil } + +// NewHTTPClient returns a new HTTP client with the specified timeout. +// A timeout of zero means no timeout. If insecureSkipVerify is true, the +// client will skip TLS verification. +// +// REVISIT: Would it be better to just not set a timeout, and instead +// use context.WithTimeout for each request? 
+func NewHTTPClient(timeout time.Duration, insecureSkipVerify bool) *http.Client { + client := *http.DefaultClient + + var tr *http.Transport + if client.Transport == nil { + tr = (http.DefaultTransport.(*http.Transport)).Clone() + } else { + tr = (client.Transport.(*http.Transport)).Clone() + } + + if tr.TLSClientConfig == nil { + tr.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS10} + } else { + tr.TLSClientConfig = tr.TLSClientConfig.Clone() + } + + tr.TLSClientConfig.InsecureSkipVerify = insecureSkipVerify + + client.Timeout = timeout + client.Transport = tr + + return &client +} diff --git a/libsq/core/stringz/stringz.go b/libsq/core/stringz/stringz.go index 2e2d48641..424fb8615 100644 --- a/libsq/core/stringz/stringz.go +++ b/libsq/core/stringz/stringz.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "math/rand" + "path/filepath" "reflect" "regexp" "strconv" @@ -738,3 +739,32 @@ func ExecuteTemplate(name, tpl string, data any) (string, error) { func ShellEscape(s string) string { return shellescape.Quote(s) } + +var filenameRegex = regexp.MustCompile(`[^a-zA-Z0-9-_ .(),+]`) + +// SanitizeFilename returns a sanitized version of filename. +// The supplied value should be the base file name, not a path. +func SanitizeFilename(name string) string { + const repl = "_" + + if name == "" { + return "" + } + name = filenameRegex.ReplaceAllString(name, repl) + if name == "" { + return "" + } + + name = filepath.Clean(name) + // Some extra paranoid handling below. + // Note that we know that filename is at least one char long. 
+ trimmed := strings.TrimSpace(name) + switch { + case trimmed == ".": + return strings.Replace(name, ".", repl, 1) + case trimmed == "..": + return strings.Replace(name, "..", repl+repl, 1) + default: + return name + } +} diff --git a/libsq/core/stringz/stringz_test.go b/libsq/core/stringz/stringz_test.go index 8e9aabf5d..6319cd0ed 100644 --- a/libsq/core/stringz/stringz_test.go +++ b/libsq/core/stringz/stringz_test.go @@ -621,3 +621,30 @@ func TestDecimal(t *testing.T) { }) } } + +func TestSanitizeFilename(t *testing.T) { + testCases := []struct { + in string + want string + }{ + {in: "", want: ""}, + {in: " ", want: " "}, + {in: "a", want: "a"}, + {in: "a b", want: "a b"}, + {in: "a b c", want: "a b c"}, + {in: "a b c.txt", want: "a b c.txt"}, + {in: "conin$", want: "conin_"}, + {in: "a+b", want: "a+b"}, + {in: "some (file).txt", want: "some (file).txt"}, + {in: ".", want: "_"}, + {in: "..", want: "__"}, + } + + for i, tc := range testCases { + tc := tc + t.Run(tu.Name(i, tc.in), func(t *testing.T) { + got := stringz.SanitizeFilename(tc.in) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/libsq/source/download.go b/libsq/source/download.go index ab94f3051..2a7796e0e 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -1,6 +1,7 @@ package source import ( + "bytes" "context" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" @@ -8,11 +9,13 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/stringz" "golang.org/x/exp/maps" "io" "log/slog" "mime" "net/http" + "net/http/httputil" "net/url" "os" "path" @@ -23,26 +26,75 @@ import ( "github.com/neilotoole/sq/libsq/source/fetcher" ) -func newDownloader(srcCacheDir, url string) *downloader { - downloadDir := filepath.Join(srcCacheDir, "download") +func newDownloader(c *http.Client, cacheDir, url string) 
*downloader { return &downloader{ - srcCacheDir: srcCacheDir, - downloadDir: downloadDir, - checksumFile: filepath.Join(srcCacheDir, "download.checksum.txt"), - url: url, + c: c, + cacheDir: cacheDir, + url: url, } } +// download is a helper for getting file contents from a URL, +// and caching the file locally. The structure of cacheDir +// is as follows: +// +// cacheDir/ +// pid.lock +// checksum.txt +// header.txt +// dl/ +// +// +// Let's take a closer look. +// +// - pid.lock is a lock file used to ensure that only one +// process is downloading the file at a time. +// +// - header.txt is a dump of the HTTP response header, included for +// debugging convenience. +// +// - checksum.txt contains a checksum:key pair, where the checksum is +// calculated using checksum.ForHTTPHeader, and the key is the path +// to the downloaded file, e.g. "dl/data.csv". +// +// 67a47a0...a53e3e28154 dl/actor.csv +// +// - The file is downloaded to dl/ instead of into the root +// of cache dir, just to avoid the (remote) possibility of a name +// collision with the other files in cacheDir. The filename is based +// on the HTTP response, incorporating the Content-Disposition header +// if present, or the last path segment of the URL. The filename is +// sanitized. +// +// When downloader.Download is invoked, it appropriately clears the existing +// stored files before proceeding. Likewise, if the download fails, the stored +// files are wiped, to prevent a partial download from being used. type downloader struct { - mu sync.Mutex - srcCacheDir string - downloadDir string - checksumFile string - url string + // FIXME: Use a client that doesn't require SSL? 
(see fetcher) + c *http.Client + mu sync.Mutex + cacheDir string + url string } func (d *downloader) log(log *slog.Logger) *slog.Logger { - return log.With(lga.URL, d.url, "download_dir", d.downloadDir) + return log.With(lga.URL, d.url, lga.Dir, d.cacheDir) +} + +func (d *downloader) dlDir() string { + return filepath.Join(d.cacheDir, "dl") +} + +func (d *downloader) checksumFile() string { + return filepath.Join(d.cacheDir, "checksum.txt") +} + +func (d *downloader) headerFile() string { + return filepath.Join(d.cacheDir, "header.txt") +} + +func (d *downloader) lockFile() string { + return filepath.Join(d.cacheDir, "pid.lock") } // Download downloads the file at the URL to the download dir, and also writes @@ -58,10 +110,21 @@ func (d *downloader) Download(ctx context.Context, dest io.Writer) (written int6 log := d.log(lg.FromContext(ctx)) - if err = ioz.RequireDir(d.downloadDir); err != nil { + dlDir := d.dlDir() + // Clear the download dir. + if err = os.RemoveAll(dlDir); err != nil { + return written, "", errz.Wrapf(err, "could not clear download dir for: %s", d.url) + } + + if err = ioz.RequireDir(dlDir); err != nil { return written, "", errz.Wrapf(err, "could not create download dir for: %s", d.url) } + // Make sure the header file is cleared. 
+ if err = os.RemoveAll(d.headerFile()); err != nil { + return written, "", errz.Wrapf(err, "could not clear header file for: %s", d.url) + } + var cancelFn context.CancelFunc ctx, cancelFn = context.WithCancel(ctx) defer cancelFn() @@ -70,8 +133,7 @@ func (d *downloader) Download(ctx context.Context, dest io.Writer) (written int6 return written, "", errz.Wrapf(err, "download new request failed for: %s", d.url) } - // FIXME: Use a client that doesn't require SSL (see fetcher) - resp, err := http.DefaultClient.Do(req) + resp, err := d.c.Do(req) if err != nil { return written, "", errz.Wrapf(err, "download failed for: %s", d.url) } @@ -81,6 +143,10 @@ func (d *downloader) Download(ctx context.Context, dest io.Writer) (written int6 } }() + if err = d.writeHeaderFile(resp); err != nil { + return written, "", err + } + if resp.StatusCode != http.StatusOK { return written, "", errz.Errorf("download failed with %s for %s", resp.Status, d.url) } @@ -90,7 +156,7 @@ func (d *downloader) Download(ctx context.Context, dest io.Writer) (written int6 filename = "download" } - fp = filepath.Join(d.downloadDir, filename) + fp = filepath.Join(dlDir, filename) f, err := os.OpenFile(fp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) if err != nil { return written, "", errz.Wrapf(err, "could not create download file for: %s", d.url) @@ -112,16 +178,42 @@ func (d *downloader) Download(ctx context.Context, dest io.Writer) (written int6 return written, "", errz.Wrapf(err, "failed to close download file: %s", fp) } + sum := checksum.ForHTTPResponse(resp) + if err = checksum.WriteFile(d.checksumFile(), sum, filepath.Join("dl", filename)); err != nil { + lg.WarnIfFuncError(log, lgm.RemoveFile, func() error { return errz.Err(os.Remove(fp)) }) + } + log.Info("Wrote download file", lga.Written, written, lga.File, fp) return written, fp, nil } + +func (d *downloader) writeHeaderFile(resp *http.Response) error { + b, err := httputil.DumpResponse(resp, false) + if err != nil { + return 
errz.Wrapf(err, "failed to dump HTTP response for: %s", d.url) + } + + if len(b) == 0 { + return errz.Errorf("empty HTTP response for: %s", d.url) + } + + // Add a custom field just for human consumption convenience. + b = bytes.TrimSuffix(b, []byte("\r\n")) + b = append(b, "X-Sq-Downloaded-From: "+d.url+"\r\n"...) + + if err = os.WriteFile(d.headerFile(), b, os.ModePerm); err != nil { + return errz.Wrapf(err, "failed to store HTTP response header for: %s", d.url) + } + return nil +} + func (d *downloader) Cached(ctx context.Context) (ok bool, sum checksum.Checksum, fp string) { d.mu.Lock() defer d.mu.Unlock() log := d.log(lg.FromContext(ctx)) - - fi, err := os.Stat(d.downloadDir) + dlDir := d.dlDir() + fi, err := os.Stat(dlDir) if err != nil { log.Debug("not cached: can't stat download dir") return false, "", "" @@ -131,13 +223,13 @@ func (d *downloader) Cached(ctx context.Context) (ok bool, sum checksum.Checksum return false, "", "" } - fi, err = os.Stat(d.checksumFile) + fi, err = os.Stat(d.checksumFile()) if err != nil { log.Debug("not cached: can't stat download checksum file") return false, "", "" } - checksums, err := checksum.ReadFile(d.checksumFile) + checksums, err := checksum.ReadFile(d.checksumFile()) if err != nil { log.Debug("not cached: can't read download checksum file") return false, "", "" @@ -155,7 +247,7 @@ func (d *downloader) Cached(ctx context.Context) (ok bool, sum checksum.Checksum return false, "", "" } - downloadFile := filepath.Join(d.downloadDir, key) + downloadFile := filepath.Join(dlDir, key) fi, err = os.Stat(downloadFile) if err != nil { @@ -257,20 +349,27 @@ func (fs *Files) fetch(ctx context.Context, loc string) (fpath string, err error // getDownloadFilename returns the filename to use for a download. // It first checks the Content-Disposition header, and if that's -// not present, it uses the last path segment of the URL. +// not present, it uses the last path segment of the URL. The +// filename is sanitized. 
// It's possible that the returned value will be empty string; the // caller should handle that situation themselves. func getDownloadFilename(resp *http.Response) string { var filename string + if resp == nil || resp.Header == nil { + return "" + } dispHeader := resp.Header.Get("Content-Disposition") if dispHeader != "" { if _, params, err := mime.ParseMediaType(dispHeader); err == nil { filename = params["filename"] } } + if filename == "" { filename = path.Base(resp.Request.URL.Path) + } else { + filename = filepath.Base(filename) } - return filename + return stringz.SanitizeFilename(filename) } diff --git a/libsq/source/download_test.go b/libsq/source/download_test.go index 845f41310..57ec84044 100644 --- a/libsq/source/download_test.go +++ b/libsq/source/download_test.go @@ -6,6 +6,7 @@ import ( "github.com/neilotoole/slogt" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/testh/proj" + "github.com/neilotoole/sq/testh/tu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "net/http" @@ -26,8 +27,9 @@ func TestGetRemoteChecksum(t *testing.T) { } const ( - urlActorCSV = "https://sq.io/testdata/actor.csv" - sizeActorCSV = int64(7641) + urlPaymentLargeCSV = "https://sqio-public.s3.amazonaws.com/testdata/payment-large.gen.csv" + urlActorCSV = "https://sq.io/testdata/actor.csv" + sizeActorCSV = int64(7641) ) func TestFetchHTTPHeader_sqio(t *testing.T) { @@ -38,18 +40,26 @@ func TestFetchHTTPHeader_sqio(t *testing.T) { // TODO } -func TestDownloader(t *testing.T) { +func TestDownloader_Download(t *testing.T) { const u = "https://sq.io/testdata/actor.csv" ctx := lg.NewContext(context.Background(), slogt.New(t)) cacheDir, err := filepath.Abs(filepath.Join("testdata", "downloader", "cache-dir-1")) require.NoError(t, err) - dl := newDownloader(cacheDir, urlActorCSV) + t.Logf("cacheDir: %s", cacheDir) + dl := newDownloader(http.DefaultClient, cacheDir, urlActorCSV) buf := &bytes.Buffer{} written, fp, err := 
dl.Download(ctx, buf) require.NoError(t, err) require.FileExists(t, fp) require.Equal(t, sizeActorCSV, written) + require.Equal(t, sizeActorCSV, int64(buf.Len())) + + s := tu.ReadFileToString(t, dl.headerFile()) + t.Logf("header.txt\n\n" + s) + + s = tu.ReadFileToString(t, dl.checksumFile()) + t.Logf("checksum.txt\n\n" + s) // TODO } diff --git a/libsq/source/files.go b/libsq/source/files.go index 809699aa9..3ca6d9896 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -4,6 +4,7 @@ import ( "context" "io" "log/slog" + "net/http" "net/url" "os" "path/filepath" @@ -57,8 +58,9 @@ type Files struct { detectFns []DriverDetectFunc } -// NewFiles returns a new Files instance. -func NewFiles(ctx context.Context, tmpDir, cacheDir string) (*Files, error) { +// NewFiles returns a new Files instance. If c is nil, http.DefaultClient is +// used. If cleanFscache is true, the fscache is cleaned on Files.Close. +func NewFiles(ctx context.Context, c *http.Client, tmpDir, cacheDir string, cleanFscache bool) (*Files, error) { if tmpDir == "" { return nil, errz.Errorf("tmpDir is empty") } @@ -66,6 +68,10 @@ func NewFiles(ctx context.Context, tmpDir, cacheDir string) (*Files, error) { return nil, errz.Errorf("cacheDir is empty") } + if c == nil { + c = http.DefaultClient + } + fs := &Files{ cacheDir: cacheDir, fscacheEntryMetas: make(map[string]*fscacheEntryMeta), @@ -78,12 +84,18 @@ func NewFiles(ctx context.Context, tmpDir, cacheDir string) (*Files, error) { // on cleanup (unless something bad happens and sq doesn't // get a chance to clean up). But, why take the chance; we'll just give // fcache a unique dir each time. 
- fcacheTmpDir := filepath.Join(cacheDir, "fscache", strconv.Itoa(os.Getpid()), stringz.Uniq32()) - if err := ioz.RequireDir(fcacheTmpDir); err != nil { + fscacheTmpDir := filepath.Join(cacheDir, "fscache", strconv.Itoa(os.Getpid())+"_"+stringz.Uniq32()) + if err := ioz.RequireDir(fscacheTmpDir); err != nil { return nil, errz.Err(err) } - fcache, err := fscache.New(fcacheTmpDir, os.ModePerm, time.Hour) + if cleanFscache { + fs.clnup.AddE(func() error { + return errz.Wrap(os.RemoveAll(fscacheTmpDir), "remove fscache dir") + }) + } + + fcache, err := fscache.New(fscacheTmpDir, os.ModePerm, time.Hour) if err != nil { return nil, errz.Err(err) } @@ -95,6 +107,7 @@ func NewFiles(ctx context.Context, tmpDir, cacheDir string) (*Files, error) { // Filesize returns the file size of src.Location. If the source is being // loaded asynchronously, this function may block until loading completes. +// An error is returned if src is not a document/file source. func (fs *Files) Filesize(ctx context.Context, src *Source) (size int64, err error) { locTyp := getLocType(src.Location) switch locTyp { diff --git a/libsq/source/files_test.go b/libsq/source/files_test.go index 55f9ad90c..f104581c0 100644 --- a/libsq/source/files_test.go +++ b/libsq/source/files_test.go @@ -57,7 +57,7 @@ func TestFiles_Type(t *testing.T) { t.Run(tu.Name(source.RedactLocation(tc.loc)), func(t *testing.T) { ctx := lg.NewContext(context.Background(), slogt.New(t)) - fs, err := source.NewFiles(ctx, tu.TempDir(t), tu.CacheDir(t)) + fs, err := source.NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) require.NoError(t, err) fs.AddDriverDetectors(testh.DriverDetectors()...) 
@@ -99,7 +99,7 @@ func TestFiles_DetectType(t *testing.T) { t.Run(filepath.Base(tc.loc), func(t *testing.T) { ctx := lg.NewContext(context.Background(), slogt.New(t)) - fs, err := source.NewFiles(ctx, tu.TempDir(t), tu.CacheDir(t)) + fs, err := source.NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) require.NoError(t, err) fs.AddDriverDetectors(testh.DriverDetectors()...) @@ -159,7 +159,7 @@ func TestFiles_NewReader(t *testing.T) { Location: proj.Abs(fpath), } - fs, err := source.NewFiles(ctx, tu.TempDir(t), tu.CacheDir(t)) + fs, err := source.NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) require.NoError(t, err) g := &errgroup.Group{} diff --git a/libsq/source/internal_test.go b/libsq/source/internal_test.go index b3a408865..2632391e0 100644 --- a/libsq/source/internal_test.go +++ b/libsq/source/internal_test.go @@ -30,7 +30,7 @@ var ( func TestFiles_Open(t *testing.T) { ctx := lg.NewContext(context.Background(), slogt.New(t)) - fs, err := NewFiles(ctx, tu.TempDir(t), tu.CacheDir(t)) + fs, err := NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, fs.Close()) }) diff --git a/testh/testh.go b/testh/testh.go index b55bca402..0e48f2949 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -161,7 +161,7 @@ func (h *Helper) init() { h.registry = driver.NewRegistry(log) var err error - h.files, err = source.NewFiles(h.Context, tu.TempDir(h.T), tu.CacheDir(h.T)) + h.files, err = source.NewFiles(h.Context, nil, tu.TempDir(h.T), tu.CacheDir(h.T), true) require.NoError(h.T, err) h.Cleanup.Add(func() { diff --git a/testh/tu/tutil.go b/testh/tu/tutil.go index 46bcf55e5..8bb99cb18 100644 --- a/testh/tu/tutil.go +++ b/testh/tu/tutil.go @@ -3,6 +3,7 @@ package tu import ( "fmt" + "github.com/neilotoole/sq/libsq/core/ioz" "io" "os" "path/filepath" @@ -389,3 +390,11 @@ func TempDir(t testing.TB) string { func CacheDir(t testing.TB) string { return filepath.Join(t.TempDir(), "sq", "cache") } + +// 
ReadFileToString invokes ioz.ReadFileToString, failing t if +// an error occurs. +func ReadFileToString(t testing.TB, name string) string { + s, err := ioz.ReadFileToString(name) + require.NoError(t, err) + return s +} From 890c9576696e4e32761ec3792edbd22339ccdc56 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 7 Dec 2023 18:37:24 -0700 Subject: [PATCH 069/195] lock testing --- cli/cli.go | 10 ++- cli/cmd_x.go | 70 ++++++++++++++++++ cli/cmd_xtest.go | 81 -------------------- cli/run.go | 2 +- libsq/core/errz/errz.go | 27 +++++++ libsq/core/errz/errz_test.go | 30 +++++++- libsq/core/ioz/ioz.go | 2 +- libsq/core/ioz/ioz_test.go | 2 +- libsq/core/ioz/lockfile/lockfile.go | 94 ++++++++++++++++++++++++ libsq/core/ioz/lockfile/lockfile_test.go | 34 +++++++++ libsq/core/lg/lga/lga.go | 2 + libsq/driver/grips.go | 2 +- libsq/source/download.go | 31 ++++---- libsq/source/download_test.go | 28 +++---- libsq/source/files.go | 64 ++++++++++++++-- testh/tu/tutil.go | 2 +- 16 files changed, 346 insertions(+), 135 deletions(-) create mode 100644 cli/cmd_x.go delete mode 100644 cli/cmd_xtest.go create mode 100644 libsq/core/ioz/lockfile/lockfile.go create mode 100644 libsq/core/ioz/lockfile/lockfile_test.go diff --git a/cli/cli.go b/cli/cli.go index 369f4086d..4fc8082fa 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -78,9 +78,8 @@ func ExecuteWith(ctx context.Context, ru *run.Run, args []string) error { ctx = options.NewContext(ctx, ru.Config.Options) log := lg.FromContext(ctx) log.Info("EXECUTE", "args", strings.Join(args, " ")) - log.Debug("Build info", "build", buildinfo.Get()) - log.Debug("Config", - "config.version", ru.Config.Version, + log.Info("Build info", "build", buildinfo.Get()) + log.Info("Config", "config_version", ru.Config.Version, lga.Path, ru.ConfigStore.Location(), ) @@ -237,7 +236,10 @@ func newCommandTree(ru *run.Run) (rootCmd *cobra.Command) { addCmd(ru, rootCmd, newCompletionCmd()) addCmd(ru, rootCmd, newVersionCmd()) addCmd(ru, rootCmd, newManCmd()) - 
addCmd(ru, rootCmd, newXTestCmd()) + + xCmd := addCmd(ru, rootCmd, newXCmd()) + addCmd(ru, xCmd, newXLockSrcCmd()) + addCmd(ru, xCmd, newXTestCmd()) return rootCmd } diff --git a/cli/cmd_x.go b/cli/cmd_x.go new file mode 100644 index 000000000..4a69da5cf --- /dev/null +++ b/cli/cmd_x.go @@ -0,0 +1,70 @@ +package cli + +import ( + "bufio" + "fmt" + "os" + "time" + + "github.com/spf13/cobra" + + "github.com/neilotoole/sq/cli/run" +) + +func newXCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "x", + Short: "Run hidden/test commands", + Hidden: true, + } + + return cmd +} + +func newXLockSrcCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "lock-src-cache @src", + Short: "Test source cache locking", + Hidden: true, + Args: cobra.ExactArgs(1), + ValidArgsFunction: completeHandle(1), + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + ru := run.FromContext(ctx) + src, err := ru.Config.Collection.Get(args[0]) + if err != nil { + return err + } + + timeout := time.Minute * 20 + lock, err := ru.Files.CacheLockFor(src) + if err != nil { + return err + } + fmt.Fprintf(ru.Out, "Locking cache for source %s with timeout %s for %q [%d]\n\n %s\n\n", + src.Handle, timeout, os.Args[0], os.Getpid(), lock) + + err = lock.Lock(ctx, timeout) + if err != nil { + return err + } + + fmt.Fprintf(ru.Out, "Cache lock acquired for %s\n", src.Handle) + fmt.Fprintln(ru.Out, "Press ENTER to release lock and exit.") + + // Wait for ENTER on stdin + buf := bufio.NewReader(os.Stdin) + fmt.Print(" > ") + _, _ = buf.ReadBytes('\n') + fmt.Fprintf(ru.Out, "Releasing cache lock for %s\n", src.Handle) + if err = lock.Unlock(); err != nil { + return err + } + + fmt.Fprintf(ru.Out, "Cache lock released for %s\n", src.Handle) + return nil + }, + } + + return cmd +} diff --git a/cli/cmd_xtest.go b/cli/cmd_xtest.go deleted file mode 100644 index 3f12cfab1..000000000 --- a/cli/cmd_xtest.go +++ /dev/null @@ -1,81 +0,0 @@ -package cli - -import ( - "context" - "fmt" 
- "math/rand" - "time" - - "github.com/spf13/cobra" - - "github.com/neilotoole/sq/cli/buildinfo" - "github.com/neilotoole/sq/cli/flag" - "github.com/neilotoole/sq/cli/hostinfo" - "github.com/neilotoole/sq/cli/run" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/neilotoole/sq/libsq/core/progress" -) - -func newXTestCmd() *cobra.Command { - cmd := &cobra.Command{ - Use: "xtest", - Short: "Execute some internal tests", - Hidden: true, - RunE: execXTestMbp, - } - - cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) - cmd.Flags().BoolP(flag.Compact, flag.CompactShort, false, flag.CompactUsage) - cmd.Flags().BoolP(flag.YAML, flag.YAMLShort, false, flag.YAMLUsage) - - return cmd -} - -func execXTestMbp(cmd *cobra.Command, _ []string) error { - ctx := cmd.Context() - ru := run.FromContext(ctx) - - fmt.Fprintln(ru.Out, "Hello, world!") - - pb := progress.New(ctx, ru.ErrOut, 1*time.Millisecond, progress.DefaultColors()) - ctx = progress.NewContext(ctx, pb) - - if err := doProgressByteCounterRead(ctx); err != nil { - return err - } - - return ru.Writers.Version.Version(buildinfo.Get(), buildinfo.Get().Version, hostinfo.Get()) -} - -func doProgressByteCounterRead(ctx context.Context) error { - pb := progress.FromContext(ctx) - - bar := pb.NewByteCounter("Ingest data test", -1) - defer bar.Stop() - maxSleep := 100 * time.Millisecond - - lr := ioz.LimitRandReader(100000) - b := make([]byte, 1024) - -LOOP: - for { - select { - case <-ctx.Done(): - bar.Stop() - break LOOP - default: - } - - n, err := lr.Read(b) - if err != nil { - bar.Stop() - break - } - - bar.IncrBy(n) - time.Sleep(time.Duration(rand.Intn(10)+1) * maxSleep / 10) //nolint:gosec - } - - pb.Stop() - return nil -} diff --git a/cli/run.go b/cli/run.go index 9d3d3dee7..a2425ea2f 100644 --- a/cli/run.go +++ b/cli/run.go @@ -2,7 +2,6 @@ package cli import ( "context" - "github.com/neilotoole/sq/libsq/core/ioz" "io" "log/slog" "os" @@ -26,6 +25,7 @@ import ( 
	"github.com/neilotoole/sq/drivers/xlsx"
 	"github.com/neilotoole/sq/libsq/core/cleanup"
 	"github.com/neilotoole/sq/libsq/core/errz"
+	"github.com/neilotoole/sq/libsq/core/ioz"
 	"github.com/neilotoole/sq/libsq/core/lg"
 	"github.com/neilotoole/sq/libsq/core/lg/lga"
 	"github.com/neilotoole/sq/libsq/core/lg/slogbuf"
diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go
index b26f1088f..98813d95a 100644
--- a/libsq/core/errz/errz.go
+++ b/libsq/core/errz/errz.go
@@ -100,3 +100,30 @@ func IsErrContextDeadlineExceeded(err error) bool {
 func Tuple[T any](t T, err error) (T, error) {
 	return t, Err(err)
 }
+
+// As is a convenience wrapper around errors.As.
+//
+//	_, err := os.Open("non-existing")
+//	ok, pathErr := errz.As[*fs.PathError](err)
+//	require.True(t, ok)
+//	require.Equal(t, "non-existing", pathErr.Path)
+//
+// Under the covers, As delegates to errors.As.
+func As[E error](err error) (bool, E) {
+	var target E
+	if errors.As(err, &target) {
+		return true, target
+	}
+	return false, target
+}
+
+// IsType returns true if err, or an error in its tree, is of type E.
+//
+//	_, err := os.Open("non-existing")
+//	isPathErr := errz.IsType[*fs.PathError](err)
+//
+// Under the covers, IsType uses errors.As.
+func IsType[E error](err error) bool { + var target E + return errors.As(err, &target) +} diff --git a/libsq/core/errz/errz_test.go b/libsq/core/errz/errz_test.go index 988e109db..a5e03e915 100644 --- a/libsq/core/errz/errz_test.go +++ b/libsq/core/errz/errz_test.go @@ -4,6 +4,10 @@ import ( "database/sql" "errors" "fmt" + "github.com/neilotoole/sq/libsq/core/stringz" + "io/fs" + "net/url" + "os" "testing" "github.com/stretchr/testify/require" @@ -21,7 +25,7 @@ func TestIs(t *testing.T) { require.True(t, errors.Is(err, sql.ErrNoRows)) } -func TestAs(t *testing.T) { +func TestWrapCauseAs(t *testing.T) { var originalErr error //nolint:gosimple originalErr = &CustomError{msg: "huzzah"} @@ -94,3 +98,27 @@ func TestIsErrNoData(t *testing.T) { require.True(t, errz.IsErrNoData(err)) require.Equal(t, "me doesn't exist", err.Error()) } + +func TestIsType(t *testing.T) { + _, err := os.Open(stringz.Uniq32() + "-non-existing") + require.Error(t, err) + t.Logf("err: %T %v", err, err) + + got := errz.IsType[*fs.PathError](err) + require.True(t, got) + + got = errz.IsType[*url.Error](err) + require.False(t, got) +} + +func TestAs(t *testing.T) { + fp := stringz.Uniq32() + "-non-existing" + _, err := os.Open(fp) + require.Error(t, err) + t.Logf("err: %T %v", err, err) + + ok, pathErr := errz.As[*fs.PathError](err) + require.True(t, ok) + require.NotNil(t, pathErr) + require.Equal(t, fp, pathErr.Path) +} diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index a120f7fcd..b073949d2 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -406,7 +406,7 @@ func NewHTTPClient(timeout time.Duration, insecureSkipVerify bool) *http.Client } if tr.TLSClientConfig == nil { - tr.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS10} + tr.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS10} //nolint:gosec } else { tr.TLSClientConfig = tr.TLSClientConfig.Clone() } diff --git a/libsq/core/ioz/ioz_test.go b/libsq/core/ioz/ioz_test.go index 
c9acfdac3..4f5ecc63e 100644 --- a/libsq/core/ioz/ioz_test.go +++ b/libsq/core/ioz/ioz_test.go @@ -2,7 +2,6 @@ package ioz_test import ( "bytes" - "github.com/neilotoole/sq/libsq/core/ioz/checksum" "io" "os" "sync" @@ -13,6 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" ) func TestMarshalYAML(t *testing.T) { diff --git a/libsq/core/ioz/lockfile/lockfile.go b/libsq/core/ioz/lockfile/lockfile.go new file mode 100644 index 000000000..aa123603f --- /dev/null +++ b/libsq/core/ioz/lockfile/lockfile.go @@ -0,0 +1,94 @@ +// Package lockfile implements a pid lock file mechanism. +package lockfile + +import ( + "context" + "github.com/neilotoole/sq/libsq/core/ioz" + "path/filepath" + "time" + + "github.com/nightlyone/lockfile" + + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/retry" +) + +// Lockfile is a pid file which can be locked. +type Lockfile string + +// New returns a new Lockfile instance. Arg fp must be +// an absolute path (but the path may not exist). +func New(fp string) (Lockfile, error) { + lf, err := lockfile.New(fp) + if err != nil { + return "", errz.Err(err) + } + return Lockfile(lf), nil +} + +// Lock attempts to acquire the lockfile, retrying if necessary, +// until the timeout expires. If timeout is zero, retry will not occur. +// On success, nil is returned. An error is returned if the lock cannot +// be acquired for any reason. 
+func (l Lockfile) Lock(ctx context.Context, timeout time.Duration) error { + log := lg.FromContext(ctx).With(lga.Lock, l, lga.Timeout, timeout) + + dir := filepath.Dir(string(l)) + if err := ioz.RequireDir(dir); err != nil { + return errz.Wrapf(err, "failed create parent dir of cache lock: %s", string(l)) + } + + if timeout == 0 { + if err := lockfile.Lockfile(l).TryLock(); err != nil { + log.Warn("Failed to acquire pid lock", lga.Err, err) + return errz.Wrapf(err, "failed to acquire pid lock: %s", l) + } + log.Debug("Acquired pid lock") + return nil + } + + start, attempts := time.Now(), 0 + + err := retry.Do(ctx, timeout, + func() error { + log.Debug("Attempting to acquire pid lock", lga.Attempts, attempts) + + err := lockfile.Lockfile(l).TryLock() + attempts++ + if err == nil { + log.Debug("Acquired pid lock", lga.Attempts, attempts) + return nil + } + + log.Debug("Failed to acquire pid lock", lga.Attempts, lga.Err, err) + return err + }, + errz.IsType[lockfile.TemporaryError], + ) + + elapsed := time.Since(start) + if err != nil { + log.Warn("Failed to acquire pid lock", + lga.Attempts, attempts, + lga.Elapsed, elapsed, + lga.Err, err, + ) + return errz.Wrapf(err, "failed to acquire pid lock: %d attempts in %s: %s", + attempts, time.Since(start), l) + } + + return nil +} + +// Unlock a lock, if we owned it. Returns any error that +// happened during release of lock. +func (l Lockfile) Unlock() error { + return errz.Err(lockfile.Lockfile(l).Unlock()) +} + +// String returns the Lockfile's absolute path. 
+func (l Lockfile) String() string { + return string(l) +} diff --git a/libsq/core/ioz/lockfile/lockfile_test.go b/libsq/core/ioz/lockfile/lockfile_test.go new file mode 100644 index 000000000..086011e21 --- /dev/null +++ b/libsq/core/ioz/lockfile/lockfile_test.go @@ -0,0 +1,34 @@ +package lockfile_test + +import ( + "context" + "github.com/neilotoole/slogt" + "github.com/neilotoole/sq/libsq/core/ioz/lockfile" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/stretchr/testify/require" + "path/filepath" + "testing" + "time" +) + +// FIXME: Duh, this can't work, because we're in the same pid. +func TestLockfile(t *testing.T) { + ctx := lg.NewContext(context.Background(), slogt.New(t)) + + pidfile := filepath.Join(t.TempDir(), "lock.pid") + lock, err := lockfile.New(pidfile) + require.NoError(t, err) + require.Equal(t, pidfile, string(lock)) + + require.NoError(t, lock.Lock(ctx, 0), + "should be able to acquire lock immediately") + time.AfterFunc(time.Second*100, func() { + require.NoError(t, lock.Unlock()) + }) + + err = lock.Lock(ctx, time.Second) + require.Error(t, err, "not enough time to acquire the lock") + + err = lock.Lock(ctx, time.Second*10) + require.NoError(t, err, "should be able to acquire the lock") +} diff --git a/libsq/core/lg/lga/lga.go b/libsq/core/lg/lga/lga.go index 4ca3a3d74..72eb96d81 100644 --- a/libsq/core/lg/lga/lga.go +++ b/libsq/core/lg/lga/lga.go @@ -6,6 +6,7 @@ const ( Action = "action" After = "after" Alt = "alt" + Attempts = "attempts" Before = "before" Catalog = "catalog" Cmd = "cmd" @@ -50,6 +51,7 @@ const ( Schema = "schema" Table = "table" Target = "target" + Timeout = "timeout" To = "to" Type = "type" Line = "line" diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 37cc2da50..319232cbd 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -3,7 +3,6 @@ package driver import ( "context" "errors" - "github.com/neilotoole/sq/libsq/core/ioz/checksum" "log/slog" "path/filepath" "strings" @@ 
-15,6 +14,7 @@ import ( "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" diff --git a/libsq/source/download.go b/libsq/source/download.go index 2a7796e0e..fec7ea36f 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -3,14 +3,6 @@ package source import ( "bytes" "context" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/neilotoole/sq/libsq/core/ioz/checksum" - "github.com/neilotoole/sq/libsq/core/ioz/contextio" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lga" - "github.com/neilotoole/sq/libsq/core/lg/lgm" - "github.com/neilotoole/sq/libsq/core/stringz" - "golang.org/x/exp/maps" "io" "log/slog" "mime" @@ -22,7 +14,16 @@ import ( "path/filepath" "sync" + "golang.org/x/exp/maps" + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/source/fetcher" ) @@ -93,10 +94,6 @@ func (d *downloader) headerFile() string { return filepath.Join(d.cacheDir, "header.txt") } -func (d *downloader) lockFile() string { - return filepath.Join(d.cacheDir, "pid.lock") -} - // Download downloads the file at the URL to the download dir, and also writes // the file to dest, and returns the file path of the downloaded file. // It is the caller's responsibility to close dest. 
If an appropriate file name @@ -223,8 +220,7 @@ func (d *downloader) Cached(ctx context.Context) (ok bool, sum checksum.Checksum return false, "", "" } - fi, err = os.Stat(d.checksumFile()) - if err != nil { + if _, err = os.Stat(d.checksumFile()); err != nil { log.Debug("not cached: can't stat download checksum file") return false, "", "" } @@ -249,8 +245,7 @@ func (d *downloader) Cached(ctx context.Context) (ok bool, sum checksum.Checksum downloadFile := filepath.Join(dlDir, key) - fi, err = os.Stat(downloadFile) - if err != nil { + if _, err = os.Stat(downloadFile); err != nil { log.Debug("not cached: can't stat download file referenced in checksum file", lga.File, key) return false, "", "" } @@ -276,10 +271,10 @@ func fetchHTTPHeader(ctx context.Context, u string) (header http.Header, err err } switch resp.StatusCode { - case http.StatusOK: - return resp.Header, nil default: return nil, errz.Errorf("unexpected HTTP status (%s) for HEAD: %s", resp.Status, u) + case http.StatusOK: + return resp.Header, nil case http.StatusMethodNotAllowed: } diff --git a/libsq/source/download_test.go b/libsq/source/download_test.go index 57ec84044..93b3504d4 100644 --- a/libsq/source/download_test.go +++ b/libsq/source/download_test.go @@ -3,17 +3,20 @@ package source import ( "bytes" "context" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/testh/proj" - "github.com/neilotoole/sq/testh/tu" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "net/http" "net/http/httptest" "path/filepath" "strconv" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/neilotoole/slogt" + + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/testh/proj" + "github.com/neilotoole/sq/testh/tu" ) func TestGetRemoteChecksum(t *testing.T) { @@ -41,7 +44,6 @@ func TestFetchHTTPHeader_sqio(t *testing.T) { } func 
TestDownloader_Download(t *testing.T) { - const u = "https://sq.io/testdata/actor.csv" ctx := lg.NewContext(context.Background(), slogt.New(t)) cacheDir, err := filepath.Abs(filepath.Join("testdata", "downloader", "cache-dir-1")) @@ -72,11 +74,10 @@ func TestFetchHTTPHeader_HEAD_fallback_GET(t *testing.T) { return } - w.Header().Set(http.CanonicalHeaderKey("Content-Length"), strconv.Itoa(len(b))) + w.Header().Set("Content-Length", strconv.Itoa(len(b))) w.WriteHeader(http.StatusOK) _, err := w.Write(b) require.NoError(t, err) - })) t.Cleanup(srvr.Close) @@ -85,13 +86,4 @@ func TestFetchHTTPHeader_HEAD_fallback_GET(t *testing.T) { header, err := fetchHTTPHeader(context.Background(), u) assert.NoError(t, err) assert.NotNil(t, header) - - //u := "https://sq.io/testdata/actor.csv" - // - //header, allowed, err := fetchHTTPHeader(context.Background(), u) - //require.NoError(t, err) - //require.True(t, allowed) - //require.NotNil(t, header) - // - //// TODO } diff --git a/libsq/source/files.go b/libsq/source/files.go index 3ca6d9896..43199d0d0 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -18,6 +18,7 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/neilotoole/sq/libsq/core/ioz/lockfile" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" @@ -38,11 +39,12 @@ import ( // if we're reading long-running pipe from stdin). This entire thing // needs to be revisited. Maybe Files even becomes a fs.FS. type Files struct { - mu sync.Mutex - log *slog.Logger - cacheDir string - tempDir string - clnup *cleanup.Cleanup + mu sync.Mutex + log *slog.Logger + cacheDir string + tempDir string + clnup *cleanup.Cleanup + httpClient *http.Client // fscache is used to cache files, providing convenient access // to multiple readers via Files.newReader. 
@@ -73,6 +75,7 @@ func NewFiles(ctx context.Context, c *http.Client, tmpDir, cacheDir string, clea } fs := &Files{ + httpClient: c, cacheDir: cacheDir, fscacheEntryMetas: make(map[string]*fscacheEntryMeta), tempDir: tmpDir, @@ -95,13 +98,13 @@ func NewFiles(ctx context.Context, c *http.Client, tmpDir, cacheDir string, clea }) } - fcache, err := fscache.New(fscacheTmpDir, os.ModePerm, time.Hour) + fsc, err := fscache.New(fscacheTmpDir, os.ModePerm, time.Hour) if err != nil { return nil, errz.Err(err) } - fs.clnup.AddE(fcache.Clean) - fs.fscache = fcache + fs.clnup.AddE(fsc.Clean) + fs.fscache = fsc return fs, nil } @@ -297,6 +300,47 @@ func (fs *Files) Open(ctx context.Context, src *Source) (io.ReadCloser, error) { return fs.newReader(ctx, src.Location) } +// NewLock returns a new source.Lock instance. +func NewLock(src *Source, pidfile string) (Lock, error) { + lf, err := lockfile.New(pidfile) + if err != nil { + return Lock{}, errz.Err(err) + } + + return Lock{ + Lockfile: lf, + src: src, + }, nil +} + +type Lock struct { + lockfile.Lockfile + src *Source +} + +func (l Lock) Source() *Source { + return l.src +} + +func (l Lock) String() string { + return l.src.Handle + ": " + string(l.Lockfile) +} + +// CacheLockFor returns the lock file for src's cache. +func (fs *Files) CacheLockFor(src *Source) (lockfile.Lockfile, error) { + cacheDir, err := fs.CacheDirFor(src) + if err != nil { + return "", errz.Wrapf(err, "cache lock for %s", src.Handle) + } + + lf, err := lockfile.New(filepath.Join(cacheDir, "pid.lock")) + if err != nil { + return "", errz.Wrapf(err, "cache lock for %s", src.Handle) + } + + return lf, nil +} + // OpenFunc returns a func that invokes fs.Open for src.Location. func (fs *Files) OpenFunc(src *Source) FileOpenFunc { return func(ctx context.Context) (io.ReadCloser, error) { @@ -305,6 +349,10 @@ func (fs *Files) OpenFunc(src *Source) FileOpenFunc { } // ReadAll is a convenience method to read the bytes of a source. 
+// +// FIXME: Delete Files.ReadAll? +// +// Deprecated: Files.ReadAll is not in use. We can probably delete it. func (fs *Files) ReadAll(ctx context.Context, src *Source) ([]byte, error) { // fs.mu.Lock() r, err := fs.newReader(ctx, src.Location) diff --git a/testh/tu/tutil.go b/testh/tu/tutil.go index 8bb99cb18..543284df0 100644 --- a/testh/tu/tutil.go +++ b/testh/tu/tutil.go @@ -3,7 +3,6 @@ package tu import ( "fmt" - "github.com/neilotoole/sq/libsq/core/ioz" "io" "os" "path/filepath" @@ -18,6 +17,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/stringz" ) From d9f309ed415cc1fcb7a27d0e73fa79438e96f3c9 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 7 Dec 2023 19:04:21 -0700 Subject: [PATCH 070/195] wip --- cli/cli.go | 1 - cli/cmd_x.go | 20 ++++++++++++---- libsq/source/files.go | 54 +++++++++++++++++++++++++++++++++++++++---- 3 files changed, 65 insertions(+), 10 deletions(-) diff --git a/cli/cli.go b/cli/cli.go index 4fc8082fa..d8bd77714 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -239,7 +239,6 @@ func newCommandTree(ru *run.Run) (rootCmd *cobra.Command) { xCmd := addCmd(ru, rootCmd, newXCmd()) addCmd(ru, xCmd, newXLockSrcCmd()) - addCmd(ru, xCmd, newXTestCmd()) return rootCmd } diff --git a/cli/cmd_x.go b/cli/cmd_x.go index 4a69da5cf..ec9638f9f 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -52,10 +52,22 @@ func newXLockSrcCmd() *cobra.Command { fmt.Fprintf(ru.Out, "Cache lock acquired for %s\n", src.Handle) fmt.Fprintln(ru.Out, "Press ENTER to release lock and exit.") - // Wait for ENTER on stdin - buf := bufio.NewReader(os.Stdin) - fmt.Print(" > ") - _, _ = buf.ReadBytes('\n') + done := make(chan struct{}) + go func() { + // Wait for ENTER on stdin + buf := bufio.NewReader(os.Stdin) + fmt.Print(" > ") + _, _ = buf.ReadBytes('\n') + close(done) + }() + + select { + case <-done: + fmt.Fprintln(ru.Out, "ENTER received, 
releasing lock") + case <-ctx.Done(): + fmt.Fprintln(ru.Out, "\nContext done, releasing lock") + } + fmt.Fprintf(ru.Out, "Releasing cache lock for %s\n", src.Handle) if err = lock.Unlock(); err != nil { return err diff --git a/libsq/source/files.go b/libsq/source/files.go index 43199d0d0..0db0b1999 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -98,13 +98,14 @@ func NewFiles(ctx context.Context, c *http.Client, tmpDir, cacheDir string, clea }) } - fsc, err := fscache.New(fscacheTmpDir, os.ModePerm, time.Hour) - if err != nil { + var err error + if fs.fscache, err = fscache.New(fscacheTmpDir, os.ModePerm, time.Hour); err != nil { return nil, errz.Err(err) } + fs.clnup.AddE(fs.fscache.Clean) + + fs.clnup.Add(func() { fs.sweepCacheDir(ctx) }) - fs.clnup.AddE(fsc.Clean) - fs.fscache = fsc return fs, nil } @@ -491,7 +492,6 @@ func (fs *Files) openLocation(ctx context.Context, loc string) (*os.File, error) // Close closes any open resources. func (fs *Files) Close() error { fs.log.Debug("Files.Close invoked: executing clean funcs", lga.Count, fs.clnup.Len()) - return fs.clnup.Run() } @@ -500,6 +500,50 @@ func (fs *Files) CleanupE(fn func() error) { fs.clnup.AddE(fn) } +func (fs *Files) sweepCacheDir(ctx context.Context) { + dir := fs.cacheDir + log := lg.FromContext(ctx).With(lga.Dir, dir) + log.Debug("Sweeping cache dir") + var count int + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + if err != nil { + log.Warn("Problem sweeping cache dir", lga.Path, path, lga.Err, err) + return nil + } + + if !info.IsDir() { + return nil + } + + files, err := os.ReadDir(path) + if err != nil { + log.Warn("Problem reading dir", lga.Dir, path, lga.Err, err) + return nil + } + + if len(files) != 0 { + return nil + } + + err = os.Remove(path) + if err != nil { + log.Warn("Problem removing empty dir", lga.Dir, path, lga.Err, err) + } + count++ + + return nil + }) + if 
err != nil { + log.Warn("Problem sweeping cache dir", lga.Dir, dir, lga.Err, err) + } + log.Info("Swept cache dir", lga.Dir, dir, lga.Count, count) +} + // FileOpenFunc returns a func that opens a ReadCloser. The caller // is responsible for closing the returned ReadCloser. type FileOpenFunc func(ctx context.Context) (io.ReadCloser, error) From c1cf55d185ddf403e58460650804fd155c4c6bd9 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 7 Dec 2023 19:08:48 -0700 Subject: [PATCH 071/195] wip --- libsq/source/files.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/libsq/source/files.go b/libsq/source/files.go index 0db0b1999..40328c98f 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -104,7 +104,8 @@ func NewFiles(ctx context.Context, c *http.Client, tmpDir, cacheDir string, clea } fs.clnup.AddE(fs.fscache.Clean) - fs.clnup.Add(func() { fs.sweepCacheDir(ctx) }) + // REVISIT: We could automatically sweep the cache dir on Close? + // fs.clnup.Add(func() { fs.sweepCacheDir(ctx) }) return fs, nil } From b9a78afdeb81722a307d811bfc85e3d7cb7b59b2 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 7 Dec 2023 21:17:33 -0700 Subject: [PATCH 072/195] fiddling with detectors --- README.md | 1 + cli/cmd_inspect.go | 4 -- cli/cmd_x.go | 4 +- cli/run.go | 1 - drivers/csv/csv.go | 5 -- drivers/xlsx/ingest.go | 1 - libsq/core/errz/errz_test.go | 3 +- libsq/core/ioz/checksum/checksum.go | 18 +++-- libsq/core/ioz/ioz_test.go | 7 ++ libsq/core/ioz/lockfile/lockfile.go | 18 +++-- libsq/core/ioz/lockfile/lockfile_test.go | 7 +- libsq/core/options/options.go | 4 +- libsq/driver/grips.go | 84 ++++-------------------- libsq/source/detect.go | 16 ++++- libsq/source/files.go | 9 ++- libsq/source/source.go | 5 +- testh/testh.go | 2 - 17 files changed, 77 insertions(+), 112 deletions(-) diff --git a/README.md b/README.md index 86e8ee3df..6a7add9e8 100644 --- a/README.md +++ b/README.md @@ -324,6 +324,7 @@ See [CHANGELOG.md](./CHANGELOG.md). 
- The [`log.devmode`](https://sq.io/docs/config#logdevmode) log format is derived from [`lmittmann/tint`](https://github.com/lmittmann/tint). - [`djherbis/fscache`](https://github.com/djherbis/fscache) is used for caching. +- A forked version of lockfile ## Similar, related, or noteworthy projects diff --git a/cli/cmd_inspect.go b/cli/cmd_inspect.go index 26aae5deb..3a6b4789d 100644 --- a/cli/cmd_inspect.go +++ b/cli/cmd_inspect.go @@ -125,9 +125,7 @@ func execInspect(cmd *cobra.Command, args []string) error { return err } - log.Error("before open") grip, err := ru.Grips.Open(ctx, src) - log.Error("after open") if err != nil { return errz.Wrapf(err, "failed to inspect %s", src.Handle) } @@ -205,8 +203,6 @@ func execInspect(cmd *cobra.Command, args []string) error { overviewOnly := cmdFlagIsSetTrue(cmd, flag.InspectOverview) - log.Debug("get source metadata") - srcMeta, err := grip.SourceMetadata(ctx, overviewOnly) if err != nil { return errz.Wrapf(err, "failed to read %s source metadata", src.Handle) diff --git a/cli/cmd_x.go b/cli/cmd_x.go index ec9638f9f..39053e03a 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -55,8 +55,8 @@ func newXLockSrcCmd() *cobra.Command { done := make(chan struct{}) go func() { // Wait for ENTER on stdin - buf := bufio.NewReader(os.Stdin) - fmt.Print(" > ") + buf := bufio.NewReader(ru.Stdin) + fmt.Fprint(ru.Out, " > ") _, _ = buf.ReadBytes('\n') close(done) }() diff --git a/cli/run.go b/cli/run.go index a2425ea2f..61d5b4a6f 100644 --- a/cli/run.go +++ b/cli/run.go @@ -156,7 +156,6 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { // because databases could depend upon the existence of // files (such as a sqlite db file). 
ru.Cleanup.AddE(ru.Files.Close) - ru.Files.AddDriverDetectors(source.DetectMagicNumber) ru.DriverRegistry = driver.NewRegistry(log) dr := ru.DriverRegistry diff --git a/drivers/csv/csv.go b/drivers/csv/csv.go index 756caf267..994ed31b4 100644 --- a/drivers/csv/csv.go +++ b/drivers/csv/csv.go @@ -89,7 +89,6 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, er return nil, err } - log.Error("open ingest done", lga.Err, err) return g, nil } @@ -157,11 +156,7 @@ func (p *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Tab // SourceMetadata implements driver.Grip. func (p *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { - log := lg.FromContext(ctx) - - log.Debug("before impl.SourceMetadata") md, err := p.impl.SourceMetadata(ctx, noSchema) - log.Debug("after impl.SourceMetadata", lga.Err, err) if err != nil { return nil, err } diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index b1c5c741b..83a4045a3 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -141,7 +141,6 @@ func ingestXLSX(ctx context.Context, src *source.Source, destGrip driver.Grip, x continue } - time.Sleep(time.Millisecond * 500) // FIXME: delete if err = ingestSheetToTable(ctx, destGrip, sheetTbls[i]); err != nil { return err } diff --git a/libsq/core/errz/errz_test.go b/libsq/core/errz/errz_test.go index a5e03e915..0298ae927 100644 --- a/libsq/core/errz/errz_test.go +++ b/libsq/core/errz/errz_test.go @@ -4,12 +4,13 @@ import ( "database/sql" "errors" "fmt" - "github.com/neilotoole/sq/libsq/core/stringz" "io/fs" "net/url" "os" "testing" + "github.com/neilotoole/sq/libsq/core/stringz" + "github.com/stretchr/testify/require" "github.com/neilotoole/slogt" diff --git a/libsq/core/ioz/checksum/checksum.go b/libsq/core/ioz/checksum/checksum.go index 40c7c28e1..d90ac9c2c 100644 --- a/libsq/core/ioz/checksum/checksum.go +++ b/libsq/core/ioz/checksum/checksum.go @@ -3,8 +3,8 @@ package 
checksum import ( "bufio" "bytes" - "crypto/sha256" "fmt" + "hash/crc32" "io" "net/http" "os" @@ -14,14 +14,20 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" ) +// Hash returns the hash of b as a hex string. +func Hash(b []byte) string { + sum := crc32.ChecksumIEEE(b) + return fmt.Sprintf("%x", sum) +} + // Checksum is a checksum of a file. type Checksum string // Write appends a checksum line to w, including // a newline. The typical format is: // -// -// da1f14c16c09bebbc452108d9ab193541f2e96515aefcb7745fee5197c343106 file.txt +// +// 3610a686 file.txt // // However, the checksum be any string value. Use ForFile to calculate // a checksum, and Read to read this format. @@ -103,7 +109,7 @@ func ForFile(path string) (Checksum, error) { buf.WriteString(strconv.FormatUint(uint64(fi.Mode()), 10)) buf.WriteString(strconv.FormatBool(fi.IsDir())) - sum := sha256.Sum256(buf.Bytes()) + sum := Hash(buf.Bytes()) return Checksum(fmt.Sprintf("%x", sum)), nil } @@ -129,7 +135,7 @@ func ForHTTPHeader(u string, header http.Header) Checksum { } } - sum := sha256.Sum256(buf.Bytes()) + sum := Hash(buf.Bytes()) return Checksum(fmt.Sprintf("%x", sum)) } @@ -159,6 +165,6 @@ func ForHTTPResponse(resp *http.Response) Checksum { } } - sum := sha256.Sum256(buf.Bytes()) + sum := Hash(buf.Bytes()) return Checksum(fmt.Sprintf("%x", sum)) } diff --git a/libsq/core/ioz/ioz_test.go b/libsq/core/ioz/ioz_test.go index 4f5ecc63e..a11a0ac10 100644 --- a/libsq/core/ioz/ioz_test.go +++ b/libsq/core/ioz/ioz_test.go @@ -15,6 +15,13 @@ import ( "github.com/neilotoole/sq/libsq/core/ioz/checksum" ) +func TestHash(t *testing.T) { + + got := checksum.Hash([]byte("hello")) + t.Log(got) + assert.Equal(t, "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d", got) +} + func TestMarshalYAML(t *testing.T) { m := map[string]any{ "hello": `sqlserver://sakila:p_ss"**W0rd@222.75.174.219?database=sakila`, diff --git a/libsq/core/ioz/lockfile/lockfile.go b/libsq/core/ioz/lockfile/lockfile.go index 
aa123603f..d245a7a88 100644 --- a/libsq/core/ioz/lockfile/lockfile.go +++ b/libsq/core/ioz/lockfile/lockfile.go @@ -3,10 +3,12 @@ package lockfile import ( "context" - "github.com/neilotoole/sq/libsq/core/ioz" + "errors" "path/filepath" "time" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/nightlyone/lockfile" "github.com/neilotoole/sq/libsq/core/errz" @@ -53,8 +55,6 @@ func (l Lockfile) Lock(ctx context.Context, timeout time.Duration) error { err := retry.Do(ctx, timeout, func() error { - log.Debug("Attempting to acquire pid lock", lga.Attempts, attempts) - err := lockfile.Lockfile(l).TryLock() attempts++ if err == nil { @@ -62,7 +62,7 @@ func (l Lockfile) Lock(ctx context.Context, timeout time.Duration) error { return nil } - log.Debug("Failed to acquire pid lock", lga.Attempts, lga.Err, err) + // log.Debug("Failed to acquire pid lock, may retry", lga.Attempts, attempts, lga.Err, err) return err }, errz.IsType[lockfile.TemporaryError], @@ -70,13 +70,17 @@ func (l Lockfile) Lock(ctx context.Context, timeout time.Duration) error { elapsed := time.Since(start) if err != nil { - log.Warn("Failed to acquire pid lock", + log.Error("Failed to acquire pid lock", lga.Attempts, attempts, lga.Elapsed, elapsed, lga.Err, err, ) - return errz.Wrapf(err, "failed to acquire pid lock: %d attempts in %s: %s", - attempts, time.Since(start), l) + + if errors.Is(err, lockfile.ErrBusy) { + return errz.Errorf("locked by other process") + } + + return errz.Wrapf(err, "acquire lock") } return nil diff --git a/libsq/core/ioz/lockfile/lockfile_test.go b/libsq/core/ioz/lockfile/lockfile_test.go index 086011e21..bf554b5a9 100644 --- a/libsq/core/ioz/lockfile/lockfile_test.go +++ b/libsq/core/ioz/lockfile/lockfile_test.go @@ -2,13 +2,14 @@ package lockfile_test import ( "context" + "path/filepath" + "testing" + "time" + "github.com/neilotoole/slogt" "github.com/neilotoole/sq/libsq/core/ioz/lockfile" "github.com/neilotoole/sq/libsq/core/lg" 
"github.com/stretchr/testify/require" - "path/filepath" - "testing" - "time" ) // FIXME: Duh, this can't work, because we're in the same pid. diff --git a/libsq/core/options/options.go b/libsq/core/options/options.go index 3a8d41877..2bd5dedd5 100644 --- a/libsq/core/options/options.go +++ b/libsq/core/options/options.go @@ -15,8 +15,8 @@ package options import ( "bytes" "context" - "crypto/sha256" "fmt" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" "log/slog" "slices" "sync" @@ -193,7 +193,7 @@ func (o Options) Hash() string { v := o[k] buf.WriteString(fmt.Sprintf("%v", v)) } - sum := sha256.Sum256(buf.Bytes()) + sum := checksum.Hash(buf.Bytes()) return fmt.Sprintf("%x", sum) } diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 319232cbd..32f01cba7 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -2,15 +2,12 @@ package driver import ( "context" - "errors" "log/slog" "path/filepath" "strings" "sync" "time" - "github.com/nightlyone/lockfile" - "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" @@ -19,7 +16,6 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" - "github.com/neilotoole/sq/libsq/core/retry" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" ) @@ -99,7 +95,7 @@ func (gs *Grips) IsSQLSource(src *source.Source) bool { } func (gs *Grips) getKey(src *source.Source) string { - return src.Handle + "_" + src.Hash() + return src.Handle } func (gs *Grips) doOpen(ctx context.Context, src *source.Source) (Grip, error) { @@ -229,17 +225,21 @@ func (gs *Grips) openIngestCache(ctx context.Context, src *source.Source, ingestFn func(ctx context.Context, destGrip Grip) error, ) (Grip, error) { log := lg.FromContext(ctx) + log = log.With(lga.Handle, src.Handle) + ctx = 
lg.NewContext(ctx, log) - lock, err := gs.acquireLock(ctx, src) + lock, err := gs.files.CacheLockFor(src) if err != nil { return nil, err } + + if err = lock.Lock(ctx, time.Second*5); err != nil { + return nil, errz.Wrap(err, "acquire cache lock") + } + defer func() { - log.Debug("About to release cache lock...", "lock", lock) if err = lock.Unlock(); err != nil { - log.Warn("Failed to release cache lock", "lock", lock, lga.Err, err) - } else { - log.Debug("Released cache lock", "lock", lock) + log.Warn("Failed to release cache lock", lga.Lock, lock, lga.Err, err) } }() @@ -248,10 +248,6 @@ func (gs *Grips) openIngestCache(ctx context.Context, src *source.Source, return nil, err } - if err = ioz.RequireDir(cacheDir); err != nil { - return nil, err - } - log.Debug("Using cache dir", lga.Path, cacheDir) ingestFilePath, err := gs.files.Filepath(ctx, src) @@ -259,10 +255,8 @@ func (gs *Grips) openIngestCache(ctx context.Context, src *source.Source, return nil, err } - var ( - impl Grip - foundCached bool - ) + var impl Grip + var foundCached bool if impl, foundCached, err = gs.openCachedFor(ctx, src); err != nil { return nil, err } @@ -324,51 +318,6 @@ func (gs *Grips) getCachePaths(src *source.Source) (srcCacheDir, cacheDB, checks return srcCacheDir, cacheDB, checksums, nil } -// acquireLock acquires a lock for src. The caller -// is responsible for unlocking the lock, e.g.: -// -// defer lg.WarnIfFuncError(d.log, "failed to unlock cache lock", lock.Unlock) -// -// The lock acquisition process is retried with backoff. 
-func (gs *Grips) acquireLock(ctx context.Context, src *source.Source) (lockfile.Lockfile, error) { - lock, err := gs.getLockfileFor(src) - if err != nil { - return "", err - } - - err = retry.Do(ctx, time.Second*5, - func() error { - lg.FromContext(ctx).Debug("Attempting to acquire cache lock", lga.Lock, lock) - return lock.TryLock() - }, - func(err error) bool { - var temporaryError lockfile.TemporaryError - return errors.As(err, &temporaryError) - }, - ) - if err != nil { - return "", errz.Wrap(err, "failed to get lock") - } - - lg.FromContext(ctx).Debug("Acquired cache lock", lga.Lock, lock) - return lock, nil -} - -// getLockfileFor returns a lockfile for src. It doesn't -// actually acquire the lock. -func (gs *Grips) getLockfileFor(src *source.Source) (lockfile.Lockfile, error) { - srcCacheDir, _, _, err := gs.getCachePaths(src) - if err != nil { - return "", err - } - - if err = ioz.RequireDir(srcCacheDir); err != nil { - return "", err - } - lockPath := filepath.Join(srcCacheDir, "pid.lock") - return lockfile.New(lockPath) -} - func (gs *Grips) openCachedFor(ctx context.Context, src *source.Source) (Grip, bool, error) { _, cacheDBPath, checksumsPath, err := gs.getCachePaths(src) if err != nil { @@ -399,8 +348,6 @@ func (gs *Grips) openCachedFor(ctx context.Context, src *source.Source) (Grip, b if err != nil { return nil, false, err } - gs.log.Debug("Got srcFilepath for src", - lga.Src, src, lga.Path, srcFilepath) cachedChecksum, ok := mChecksums[srcFilepath] if !ok { @@ -422,15 +369,10 @@ func (gs *Grips) openCachedFor(ctx context.Context, src *source.Source) (Grip, b return nil, false, nil } - backingType, err := gs.files.DriverType(ctx, cacheDBPath) - if err != nil { - return nil, false, err - } - backingSrc := &source.Source{ Handle: src.Handle + "_cached", Location: "sqlite3://" + cacheDBPath, - Type: backingType, + Type: drivertype.Type("sqlite3"), } backingGrip, err := gs.doOpen(ctx, backingSrc) diff --git a/libsq/source/detect.go 
b/libsq/source/detect.go index 84b86d100..b2fd53f9a 100644 --- a/libsq/source/detect.go +++ b/libsq/source/detect.go @@ -38,6 +38,7 @@ func (fs *Files) AddDriverDetectors(detectFns ...DriverDetectFunc) { } // DriverType returns the driver type of loc. +// This may result in loading files into the cache. func (fs *Files) DriverType(ctx context.Context, loc string) (drivertype.Type, error) { log := lg.FromContext(ctx).With(lga.Loc, loc) ploc, err := parseLoc(loc) @@ -61,6 +62,8 @@ func (fs *Files) DriverType(ctx context.Context, loc string) (drivertype.Type, e } } + // FIXME: We really should try to be smarter here, esp with sqlite files. + fs.mu.Lock() defer fs.mu.Unlock() // Fall back to the byte detectors @@ -83,15 +86,22 @@ func (fs *Files) detectType(ctx context.Context, loc string) (typ drivertype.Typ log := lg.FromContext(ctx).With(lga.Loc, loc) start := time.Now() + openFn := func(ctx context.Context) (io.ReadCloser, error) { + return fs.newReader(ctx, loc) + } + + // We do the magic number first, because it's so fast. + detected, score, err := DetectMagicNumber(ctx, openFn) + if err == nil && score >= 1.0 { + return detected, true, nil + } + type result struct { typ drivertype.Type score float32 } resultCh := make(chan result, len(fs.detectFns)) - openFn := func(ctx context.Context) (io.ReadCloser, error) { - return fs.newReader(ctx, loc) - } select { case <-ctx.Done(): diff --git a/libsq/source/files.go b/libsq/source/files.go index 40328c98f..ef7e2e250 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" "strconv" + "strings" "sync" "time" @@ -63,6 +64,8 @@ type Files struct { // NewFiles returns a new Files instance. If c is nil, http.DefaultClient is // used. If cleanFscache is true, the fscache is cleaned on Files.Close. 
func NewFiles(ctx context.Context, c *http.Client, tmpDir, cacheDir string, cleanFscache bool) (*Files, error) { + log := lg.FromContext(ctx) + log.Debug("Creating new Files instance", "tmp_dir", tmpDir, "cache_dir", cacheDir) if tmpDir == "" { return nil, errz.Errorf("tmpDir is empty") } @@ -244,7 +247,11 @@ func (fs *Files) addStdin(ctx context.Context, handle string, f *os.File) error // Do not add stdin via this function; instead use addStdin. func (fs *Files) addRegularFile(ctx context.Context, f *os.File, key string) (fscache.ReadAtCloser, error) { log := lg.FromContext(ctx) - log.Debug("Adding file", lga.Key, key, lga.Path, f.Name()) + log.Debug("Adding regular file", lga.Key, key, lga.Path, f.Name()) + + if strings.Contains(f.Name(), "cached.db") { + log.Error("oh no, shouldn't be happening") + } defer lg.WarnIfCloseError(log, lgm.CloseFileReader, f) diff --git a/libsq/source/source.go b/libsq/source/source.go index 39a73085a..5b8c28464 100644 --- a/libsq/source/source.go +++ b/libsq/source/source.go @@ -3,8 +3,8 @@ package source import ( "bytes" - "crypto/sha256" "fmt" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" "log/slog" "net/url" "strings" @@ -106,9 +106,8 @@ func (s *Source) Hash() string { buf.WriteString(s.Location) buf.WriteString(s.Catalog) buf.WriteString(s.Schema) - buf.WriteString(s.Options.Hash()) - sum := sha256.Sum256(buf.Bytes()) + sum := checksum.Hash(buf.Bytes()) return fmt.Sprintf("%x", sum) } diff --git a/testh/testh.go b/testh/testh.go index 0e48f2949..b7f57c07c 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -170,8 +170,6 @@ func (h *Helper) init() { assert.NoError(h.T, err) }) - h.files.AddDriverDetectors(source.DetectMagicNumber) - h.grips = driver.NewGrips(log, h.registry, h.files, sqlite3.NewScratchSource) h.Cleanup.AddC(h.grips) From 7d18e562e4e68a0be31c98cb0bf2130a0416ea68 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 8 Dec 2023 09:21:19 -0700 Subject: [PATCH 073/195] Cleanup - about to add progress 
NewTimeoutWaiter --- cli/cmd_cache.go | 37 +--------- cli/options.go | 3 + cli/run.go | 7 +- drivers/csv/ingest.go | 2 + libsq/core/ioz/checksum/checksum.go | 17 +++-- libsq/core/ioz/checksum/checksum_test.go | 14 ++++ libsq/core/ioz/ioz.go | 16 ++-- libsq/core/ioz/ioz_test.go | 7 -- libsq/core/options/opt.go | 5 ++ libsq/core/options/options.go | 20 ----- libsq/driver/ingest.go | 3 + libsq/source/cache.go | 50 ++++++++++++- libsq/source/download.go | 26 ++++++- libsq/source/files.go | 93 ++++++++++++++++++++---- libsq/source/source.go | 20 ----- testh/testh.go | 20 +++-- 16 files changed, 213 insertions(+), 127 deletions(-) create mode 100644 libsq/core/ioz/checksum/checksum_test.go diff --git a/cli/cmd_cache.go b/cli/cmd_cache.go index 76a22b6cb..285380a53 100644 --- a/cli/cmd_cache.go +++ b/cli/cmd_cache.go @@ -1,18 +1,13 @@ package cli import ( - "os" - "path/filepath" - "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/flag" "github.com/neilotoole/sq/cli/run" - "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" - "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" ) @@ -114,36 +109,8 @@ func newCacheClearCmd() *cobra.Command { } func execCacheClear(cmd *cobra.Command, _ []string) error { - log := lg.FromContext(cmd.Context()) - cacheDir := source.DefaultCacheDir() - if !ioz.DirExists(cacheDir) { - return nil - } - - // Instead of directly deleting the existing cache dir, we first - // move it to /tmp, and then try to delete it. This should probably - // help with the situation where another sq instance has an open pid - // lock in the cache dir. 
- - tmpDir := source.DefaultTempDir() - if err := ioz.RequireDir(tmpDir); err != nil { - return errz.Wrap(err, "cache clear") - } - relocateDir := filepath.Join(tmpDir, "dead_cache_"+stringz.Uniq8()) - if err := os.Rename(cacheDir, relocateDir); err != nil { - return errz.Wrap(err, "cache clear: relocate") - } - - if err := os.RemoveAll(relocateDir); err != nil { - log.Warn("Could not delete relocated cache dir", lga.Path, relocateDir, lga.Err, err) - } - - // Recreate the cache dir. - if err := ioz.RequireDir(cacheDir); err != nil { - return errz.Wrap(err, "cache clear") - } - - return nil + ru := run.FromContext(cmd.Context()) + return ru.Files.CacheClear(cmd.Context()) } func newCacheTreeCmd() *cobra.Command { diff --git a/cli/options.go b/cli/options.go index 1a5365432..37b9c3b8c 100644 --- a/cli/options.go +++ b/cli/options.go @@ -118,6 +118,7 @@ func applySourceOptions(cmd *cobra.Command, src *source.Source) error { defaultOpts = options.Options{} } + // FIXME: This should only apply source options? 
flagOpts, err := getOptionsFromFlags(cmd.Flags(), ru.OptionsRegistry) if err != nil { return err @@ -174,6 +175,8 @@ func RegisterDefaultOpts(reg *options.Registry) { OptLogDevMode, OptDiffNumLines, OptDiffDataFormat, + source.OptHTTPPingTimeout, + source.OptHTTPSkipVerify, driver.OptConnMaxOpen, driver.OptConnMaxIdle, driver.OptConnMaxIdleTime, diff --git a/cli/run.go b/cli/run.go index 61d5b4a6f..515af4e5b 100644 --- a/cli/run.go +++ b/cli/run.go @@ -25,7 +25,6 @@ import ( "github.com/neilotoole/sq/drivers/xlsx" "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/slogbuf" @@ -141,10 +140,8 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { var err error if ru.Files == nil { - // TODO: The timeout/ssl vals should really come from options. - c := ioz.NewHTTPClient(0, true) - - ru.Files, err = source.NewFiles(ctx, c, source.DefaultTempDir(), source.DefaultCacheDir(), true) + ru.Files, err = source.NewFiles(ctx, ru.OptionsRegistry, + source.DefaultTempDir(), source.DefaultCacheDir(), true) if err != nil { lg.WarnIfFuncError(log, lga.Cleanup, ru.Cleanup.Run) return err diff --git a/drivers/csv/ingest.go b/drivers/csv/ingest.go index b4cde9009..3b8441603 100644 --- a/drivers/csv/ingest.go +++ b/drivers/csv/ingest.go @@ -33,6 +33,7 @@ var OptEmptyAsNull = options.NewBool( `When true, empty CSV fields are treated as NULL. When false, the zero value for that type is used, e.g. empty string or 0.`, options.TagSource, + options.TagIngestMutate, "csv", ) @@ -47,6 +48,7 @@ var OptDelim = options.NewString( `Delimiter to use for CSV files. Default is "comma". 
Possible values are: comma, space, pipe, tab, colon, semi, period.`, options.TagSource, + options.TagIngestMutate, "csv", ) diff --git a/libsq/core/ioz/checksum/checksum.go b/libsq/core/ioz/checksum/checksum.go index d90ac9c2c..a206599a6 100644 --- a/libsq/core/ioz/checksum/checksum.go +++ b/libsq/core/ioz/checksum/checksum.go @@ -3,6 +3,7 @@ package checksum import ( "bufio" "bytes" + "crypto/rand" "fmt" "hash/crc32" "io" @@ -20,6 +21,13 @@ func Hash(b []byte) string { return fmt.Sprintf("%x", sum) } +// Rand returns a random checksum. +func Rand() string { + b := make([]byte, 128) + _, _ = rand.Read(b) + return Hash(b) +} + // Checksum is a checksum of a file. type Checksum string @@ -109,8 +117,7 @@ func ForFile(path string) (Checksum, error) { buf.WriteString(strconv.FormatUint(uint64(fi.Mode()), 10)) buf.WriteString(strconv.FormatBool(fi.IsDir())) - sum := Hash(buf.Bytes()) - return Checksum(fmt.Sprintf("%x", sum)), nil + return Checksum(Hash(buf.Bytes())), nil } // ForHTTPHeader returns a checksum generated from URL u and @@ -135,8 +142,7 @@ func ForHTTPHeader(u string, header http.Header) Checksum { } } - sum := Hash(buf.Bytes()) - return Checksum(fmt.Sprintf("%x", sum)) + return Checksum(Hash(buf.Bytes())) } // ForHTTPResponse returns a checksum generated from the response's @@ -165,6 +171,5 @@ func ForHTTPResponse(resp *http.Response) Checksum { } } - sum := Hash(buf.Bytes()) - return Checksum(fmt.Sprintf("%x", sum)) + return Checksum(Hash(buf.Bytes())) } diff --git a/libsq/core/ioz/checksum/checksum_test.go b/libsq/core/ioz/checksum/checksum_test.go new file mode 100644 index 000000000..9d9dbec59 --- /dev/null +++ b/libsq/core/ioz/checksum/checksum_test.go @@ -0,0 +1,14 @@ +package checksum_test + +import ( + "testing" + + "github.com/neilotoole/sq/libsq/core/ioz/checksum" + "github.com/stretchr/testify/assert" +) + +func TestHash(t *testing.T) { + got := checksum.Hash([]byte("hello world")) + t.Log(got) + assert.Equal(t, "d4a1185", got) +} diff 
--git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index b073949d2..61aafa5a1 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -389,13 +389,10 @@ func PrintTree(w io.Writer, loc string, showSize, colorize bool) error { return nil } -// NewHTTPClient returns a new HTTP client with the specified timeout. -// A timeout of zero means no timeout. If insecureSkipVerify is true, the -// client will skip TLS verification. -// -// REVISIT: Would it be better to just not set a timeout, and instead -// use context.WithTimeout for each request? -func NewHTTPClient(timeout time.Duration, insecureSkipVerify bool) *http.Client { +// NewHTTPClient returns a new HTTP client with no client-wide timeout. +// If a timeout is needed, use a [context.WithTimeout] for each request. +// If insecureSkipVerify is true, the client will skip TLS verification. +func NewHTTPClient(insecureSkipVerify bool) *http.Client { client := *http.DefaultClient var tr *http.Transport @@ -406,6 +403,9 @@ func NewHTTPClient(timeout time.Duration, insecureSkipVerify bool) *http.Client } if tr.TLSClientConfig == nil { + // We allow tls.VersionTLS10, even though it's not considered + // secure these days. Ultimately this could become a config + // option. 
tr.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS10} //nolint:gosec } else { tr.TLSClientConfig = tr.TLSClientConfig.Clone() @@ -413,7 +413,7 @@ func NewHTTPClient(timeout time.Duration, insecureSkipVerify bool) *http.Client tr.TLSClientConfig.InsecureSkipVerify = insecureSkipVerify - client.Timeout = timeout + client.Timeout = 0 client.Transport = tr return &client diff --git a/libsq/core/ioz/ioz_test.go b/libsq/core/ioz/ioz_test.go index a11a0ac10..4f5ecc63e 100644 --- a/libsq/core/ioz/ioz_test.go +++ b/libsq/core/ioz/ioz_test.go @@ -15,13 +15,6 @@ import ( "github.com/neilotoole/sq/libsq/core/ioz/checksum" ) -func TestHash(t *testing.T) { - - got := checksum.Hash([]byte("hello")) - t.Log(got) - assert.Equal(t, "aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d", got) -} - func TestMarshalYAML(t *testing.T) { m := map[string]any{ "hello": `sqlserver://sakila:p_ss"**W0rd@222.75.174.219?database=sakila`, diff --git a/libsq/core/options/opt.go b/libsq/core/options/opt.go index b4a9770fd..0b13ed410 100644 --- a/libsq/core/options/opt.go +++ b/libsq/core/options/opt.go @@ -24,6 +24,11 @@ const ( // TagOutput indicates the Opt is related to output/formatting. TagOutput = "output" + + // TagIngestMutate indicates the Opt may result in mutated data, particularly + // during ingestion. This tag is significant in that its value may affect + // data realization, and thus affect program aspects such as caching behavior. + TagIngestMutate = "mutate" ) // Opt is an option type. Concrete impls exist for various types, diff --git a/libsq/core/options/options.go b/libsq/core/options/options.go index 2bd5dedd5..282933742 100644 --- a/libsq/core/options/options.go +++ b/libsq/core/options/options.go @@ -13,10 +13,8 @@ package options import ( - "bytes" "context" "fmt" - "github.com/neilotoole/sq/libsq/core/ioz/checksum" "log/slog" "slices" "sync" @@ -179,24 +177,6 @@ func (o Options) Clone() Options { return o2 } -// Hash returns a SHA256 hash of o. 
If o is nil or empty, -// an empty string is returned. -func (o Options) Hash() string { - if len(o) == 0 { - return "" - } - - keys := o.Keys() - buf := bytes.Buffer{} - for _, k := range keys { - buf.WriteString(k) - v := o[k] - buf.WriteString(fmt.Sprintf("%v", v)) - } - sum := checksum.Hash(buf.Bytes()) - return fmt.Sprintf("%x", sum) -} - // Keys returns the sorted set of keys in o. func (o Options) Keys() []string { keys := lo.Keys(o) diff --git a/libsq/driver/ingest.go b/libsq/driver/ingest.go index ce6249785..5b415a0b1 100644 --- a/libsq/driver/ingest.go +++ b/libsq/driver/ingest.go @@ -22,6 +22,7 @@ If not set, the ingester *may* try to detect if the input has a header. Generally it is best to leave this option unset and allow the ingester to detect the header.`, options.TagSource, + options.TagIngestMutate, ) // OptIngestCache specifies whether ingested data is cached or not. @@ -46,6 +47,7 @@ var OptIngestSampleSize = options.NewInt( "Ingest data sample size for type detection", `Specify the number of samples that a detector should take to determine type.`, options.TagSource, + options.TagIngestMutate, ) // OptIngestColRename transforms a column name in ingested data. @@ -79,6 +81,7 @@ For a unique column name, e.g. "first_name" above, ".Recurrence" will be 0. 
For duplicate column names, ".Recurrence" will be 0 for the first instance, then 1 for the next instance, and so on.`, options.TagSource, + options.TagIngestMutate, ) // MungeIngestColNames transforms ingest data column names, per the template diff --git a/libsq/source/cache.go b/libsq/source/cache.go index 36a0f2fbe..e726a1ed1 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -1,10 +1,16 @@ package source import ( + "bytes" + "fmt" "os" "path/filepath" "strings" + "github.com/neilotoole/sq/libsq/core/options" + + "github.com/neilotoole/sq/libsq/core/ioz/checksum" + "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/stringz" ) @@ -27,12 +33,54 @@ func (fs *Files) CacheDirFor(src *Source) (dir string, err error) { fs.cacheDir, "sources", filepath.Join(strings.Split(strings.TrimPrefix(handle, "@"), "/")...), - src.Hash(), + fs.sourceHash(src), ) return dir, nil } +// sourceHash generates a hash for src. The hash is based on the +// member fields of src, with special handling for src.Options. +// Only the opts that affect data ingestion (options.TagIngestMutate) +// are incorporated in the hash. +func (fs *Files) sourceHash(src *Source) string { + if src == nil { + return "" + } + + buf := bytes.Buffer{} + buf.WriteString(src.Handle) + buf.WriteString(string(src.Type)) + buf.WriteString(src.Location) + buf.WriteString(src.Catalog) + buf.WriteString(src.Schema) + + // FIXME: Revisit this + mUsedKeys := make(map[string]any) + + if src.Options != nil { + keys := src.Options.Keys() + for _, k := range keys { + opt := fs.optRegistry.Get(k) + switch { + case opt == nil, + !opt.IsSet(src.Options), + !opt.HasTag(options.TagIngestMutate): + continue + default: + } + + buf.WriteString(k) + v := src.Options[k] + buf.WriteString(fmt.Sprintf("%v", v)) + mUsedKeys[k] = v + } + } + + sum := checksum.Hash(buf.Bytes()) + return sum +} + // DefaultCacheDir returns the sq cache dir. 
This is generally // in USER_CACHE_DIR/*/sq, but could also be in TEMP_DIR/*/sq/cache // or similar. It is not guaranteed that the returned dir exists diff --git a/libsq/source/download.go b/libsq/source/download.go index fec7ea36f..51338c6f5 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -13,6 +13,9 @@ import ( "path" "path/filepath" "sync" + "time" + + "github.com/neilotoole/sq/libsq/core/options" "golang.org/x/exp/maps" @@ -27,6 +30,28 @@ import ( "github.com/neilotoole/sq/libsq/source/fetcher" ) +var OptHTTPPingTimeout = options.NewDuration( + "http.ping.timeout", + "", + 0, + time.Second*10, + "HTTP ping timeout duration", + `How long to wait for initial response from HTTP endpoint before +timeout occurs. Long-running operations, such as HTTP file downloads, are +not affected by this option. Example: 500ms or 3s.`, + options.TagSource, +) + +var OptHTTPSkipVerify = options.NewBool( + "http.skip-verify", + "", + false, + 0, + false, + "Skip HTTPS TLS verification", + "Skip HTTPS TLS verification. Useful when downloading against self-signed certs.", +) + func newDownloader(c *http.Client, cacheDir, url string) *downloader { return &downloader{ c: c, @@ -71,7 +96,6 @@ func newDownloader(c *http.Client, cacheDir, url string) *downloader { // stored files before proceeding. Likewise, if the download fails, the stored // files are wiped, to prevent a partial download from being used. type downloader struct { - // FIXME: Use a client that doesn't require SSL? 
(see fetcher) c *http.Client mu sync.Mutex cacheDir string diff --git a/libsq/source/files.go b/libsq/source/files.go index ef7e2e250..50993f637 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -4,7 +4,6 @@ import ( "context" "io" "log/slog" - "net/http" "net/url" "os" "path/filepath" @@ -13,6 +12,10 @@ import ( "sync" "time" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" + + "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/fscache" "github.com/neilotoole/sq/libsq/core/cleanup" @@ -40,12 +43,12 @@ import ( // if we're reading long-running pipe from stdin). This entire thing // needs to be revisited. Maybe Files even becomes a fs.FS. type Files struct { - mu sync.Mutex - log *slog.Logger - cacheDir string - tempDir string - clnup *cleanup.Cleanup - httpClient *http.Client + mu sync.Mutex + log *slog.Logger + cacheDir string + tempDir string + clnup *cleanup.Cleanup + optRegistry *options.Registry // fscache is used to cache files, providing convenient access // to multiple readers via Files.newReader. @@ -61,9 +64,11 @@ type Files struct { detectFns []DriverDetectFunc } -// NewFiles returns a new Files instance. If c is nil, http.DefaultClient is -// used. If cleanFscache is true, the fscache is cleaned on Files.Close. -func NewFiles(ctx context.Context, c *http.Client, tmpDir, cacheDir string, cleanFscache bool) (*Files, error) { +// NewFiles returns a new Files instance. If cleanFscache is true, the fscache +// is cleaned on Files.Close. 
+func NewFiles(ctx context.Context, optReg *options.Registry, + tmpDir, cacheDir string, cleanFscache bool, +) (*Files, error) { log := lg.FromContext(ctx) log.Debug("Creating new Files instance", "tmp_dir", tmpDir, "cache_dir", cacheDir) if tmpDir == "" { @@ -73,12 +78,12 @@ func NewFiles(ctx context.Context, c *http.Client, tmpDir, cacheDir string, clea return nil, errz.Errorf("cacheDir is empty") } - if c == nil { - c = http.DefaultClient + if optReg == nil { + optReg = &options.Registry{} } fs := &Files{ - httpClient: c, + optRegistry: optReg, cacheDir: cacheDir, fscacheEntryMetas: make(map[string]*fscacheEntryMeta), tempDir: tmpDir, @@ -90,7 +95,12 @@ func NewFiles(ctx context.Context, c *http.Client, tmpDir, cacheDir string, clea // on cleanup (unless something bad happens and sq doesn't // get a chance to clean up). But, why take the chance; we'll just give // fcache a unique dir each time. - fscacheTmpDir := filepath.Join(cacheDir, "fscache", strconv.Itoa(os.Getpid())+"_"+stringz.Uniq32()) + fscacheTmpDir := filepath.Join( + cacheDir, + "fscache", + strconv.Itoa(os.Getpid())+"_"+checksum.Rand(), + ) + if err := ioz.RequireDir(fscacheTmpDir); err != nil { return nil, errz.Err(err) } @@ -108,7 +118,7 @@ func NewFiles(ctx context.Context, c *http.Client, tmpDir, cacheDir string, clea fs.clnup.AddE(fs.fscache.Clean) // REVISIT: We could automatically sweep the cache dir on Close? - // fs.clnup.Add(func() { fs.sweepCacheDir(ctx) }) + // fs.clnup.Add(func() { fs.CacheSweep(ctx) }) return fs, nil } @@ -173,6 +183,11 @@ func (fs *Files) AddStdin(ctx context.Context, f *os.File) error { fs.mu.Lock() defer fs.mu.Unlock() + // FIXME: This might be the spot where we can add the cleanup + // for the stdin cache dir, because it should always be deleted + // when sq exits. But, first we probably need to refactor the + // interaction with driver.Grips. 
+ err := fs.addStdin(ctx, StdinHandle, f) // f is closed by addStdin return errz.Wrapf(err, "failed to add %s to fscache", StdinHandle) } @@ -508,7 +523,53 @@ func (fs *Files) CleanupE(fn func() error) { fs.clnup.AddE(fn) } -func (fs *Files) sweepCacheDir(ctx context.Context) { +// CacheClear clears the cache dir. This wipes the entire contents +// of the cache dir, so it should be used with caution. Note that +// this operation is distinct from [Files.CacheSweep]. +func (fs *Files) CacheClear(ctx context.Context) error { + fs.mu.Lock() + defer fs.mu.Unlock() + + log := lg.FromContext(ctx).With(lga.Dir, fs.cacheDir) + log.Debug("Clearing cache dir") + if !ioz.DirExists(fs.cacheDir) { + log.Debug("Cache dir does not exist") + return nil + } + + // Instead of directly deleting the existing cache dir, we first + // move it to /tmp, and then try to delete it. This should probably + // help with the situation where another sq instance has an open pid + // lock in the cache dir. + + tmpDir := DefaultTempDir() + if err := ioz.RequireDir(tmpDir); err != nil { + return errz.Wrap(err, "cache clear") + } + relocateDir := filepath.Join(tmpDir, "dead_cache_"+stringz.Uniq8()) + if err := os.Rename(fs.cacheDir, relocateDir); err != nil { + return errz.Wrap(err, "cache clear: relocate") + } + + if err := os.RemoveAll(relocateDir); err != nil { + log.Warn("Could not delete relocated cache dir", lga.Path, relocateDir, lga.Err, err) + } + + // Recreate the cache dir. + if err := ioz.RequireDir(fs.cacheDir); err != nil { + return errz.Wrap(err, "cache clear") + } + + return nil +} + +// CacheSweep sweeps the cache dir, making a best-effort attempt +// to remove any empty directories. Note that this operation is +// distinct from [Files.CacheClear]. 
+func (fs *Files) CacheSweep(ctx context.Context) { + fs.mu.Lock() + defer fs.mu.Unlock() + dir := fs.cacheDir log := lg.FromContext(ctx).With(lga.Dir, dir) log.Debug("Sweeping cache dir") diff --git a/libsq/source/source.go b/libsq/source/source.go index 5b8c28464..34eef895f 100644 --- a/libsq/source/source.go +++ b/libsq/source/source.go @@ -2,9 +2,7 @@ package source import ( - "bytes" "fmt" - "github.com/neilotoole/sq/libsq/core/ioz/checksum" "log/slog" "net/url" "strings" @@ -93,24 +91,6 @@ type Source struct { Options options.Options `yaml:"options,omitempty" json:"options,omitempty"` } -// Hash returns an SHA256 hash of all fields of s. The Source.Options -// field is ignored. If s is nil, the empty string is returned. -func (s *Source) Hash() string { - if s == nil { - return "" - } - - buf := bytes.Buffer{} - buf.WriteString(s.Handle) - buf.WriteString(string(s.Type)) - buf.WriteString(s.Location) - buf.WriteString(s.Catalog) - buf.WriteString(s.Schema) - - sum := checksum.Hash(buf.Bytes()) - return fmt.Sprintf("%x", sum) -} - // LogValue implements slog.LogValuer. 
func (s *Source) LogValue() slog.Value { if s == nil { diff --git a/testh/testh.go b/testh/testh.go index b7f57c07c..c342547a2 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -158,10 +158,14 @@ func NewWith(t testing.TB, handle string) (*Helper, *source.Source, driver.SQLDr func (h *Helper) init() { h.initOnce.Do(func() { log := h.Log + + optRegistry := &options.Registry{} + cli.RegisterDefaultOpts(optRegistry) h.registry = driver.NewRegistry(log) var err error - h.files, err = source.NewFiles(h.Context, nil, tu.TempDir(h.T), tu.CacheDir(h.T), true) + h.files, err = source.NewFiles(h.Context, optRegistry, + tu.TempDir(h.T), tu.CacheDir(h.T), true) require.NoError(h.T, err) h.Cleanup.Add(func() { @@ -199,13 +203,13 @@ func (h *Helper) init() { h.addUserDrivers() h.run = &run.Run{ - Stdin: os.Stdin, - Out: os.Stdout, - ErrOut: os.Stdin, - Config: config.New(), - ConfigStore: config.DiscardStore{}, - OptionsRegistry: &options.Registry{}, - DriverRegistry: h.registry, + Stdin: os.Stdin, + Out: os.Stdout, + ErrOut: os.Stdin, + Config: config.New(), + ConfigStore: config.DiscardStore{}, + + DriverRegistry: h.registry, } }) } From b39ef90435070dac5e678e26eda64476270c86ea Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 8 Dec 2023 10:02:41 -0700 Subject: [PATCH 074/195] Cleaned up progress package --- cli/cli.go | 1 + cli/cmd_x.go | 123 +++++++----- libsq/core/errz/errz_test.go | 3 +- libsq/core/ioz/checksum/checksum_test.go | 3 +- libsq/core/ioz/lockfile/lockfile.go | 3 +- libsq/core/ioz/lockfile/lockfile_test.go | 4 +- libsq/core/progress/bars.go | 117 ++++++++++++ libsq/core/progress/progress.go | 228 +++-------------------- libsq/core/progress/style.go | 85 +++++++++ libsq/source/cache.go | 6 +- libsq/source/download.go | 3 +- libsq/source/files.go | 6 +- 12 files changed, 312 insertions(+), 270 deletions(-) create mode 100644 libsq/core/progress/bars.go create mode 100644 libsq/core/progress/style.go diff --git a/cli/cli.go b/cli/cli.go index 
d8bd77714..fccca7a25 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -239,6 +239,7 @@ func newCommandTree(ru *run.Run) (rootCmd *cobra.Command) { xCmd := addCmd(ru, rootCmd, newXCmd()) addCmd(ru, xCmd, newXLockSrcCmd()) + addCmd(ru, xCmd, newXDevTestCmd()) return rootCmd } diff --git a/cli/cmd_x.go b/cli/cmd_x.go index 39053e03a..ecb4133b4 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -11,10 +11,13 @@ import ( "github.com/neilotoole/sq/cli/run" ) +// newXCmd returns the root "x" command, which is the container +// for a set of hidden commands that are useful for development. +// The x commands are not intended for end users. func newXCmd() *cobra.Command { cmd := &cobra.Command{ Use: "x", - Short: "Run hidden/test commands", + Short: "Run hidden dev/test commands", Hidden: true, } @@ -28,55 +31,77 @@ func newXLockSrcCmd() *cobra.Command { Hidden: true, Args: cobra.ExactArgs(1), ValidArgsFunction: completeHandle(1), - RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() - ru := run.FromContext(ctx) - src, err := ru.Config.Collection.Get(args[0]) - if err != nil { - return err - } - - timeout := time.Minute * 20 - lock, err := ru.Files.CacheLockFor(src) - if err != nil { - return err - } - fmt.Fprintf(ru.Out, "Locking cache for source %s with timeout %s for %q [%d]\n\n %s\n\n", - src.Handle, timeout, os.Args[0], os.Getpid(), lock) - - err = lock.Lock(ctx, timeout) - if err != nil { - return err - } - - fmt.Fprintf(ru.Out, "Cache lock acquired for %s\n", src.Handle) - fmt.Fprintln(ru.Out, "Press ENTER to release lock and exit.") - - done := make(chan struct{}) - go func() { - // Wait for ENTER on stdin - buf := bufio.NewReader(ru.Stdin) - fmt.Fprint(ru.Out, " > ") - _, _ = buf.ReadBytes('\n') - close(done) - }() - - select { - case <-done: - fmt.Fprintln(ru.Out, "ENTER received, releasing lock") - case <-ctx.Done(): - fmt.Fprintln(ru.Out, "\nContext done, releasing lock") - } - - fmt.Fprintf(ru.Out, "Releasing cache lock for %s\n", src.Handle) 
- if err = lock.Unlock(); err != nil { - return err - } - - fmt.Fprintf(ru.Out, "Cache lock released for %s\n", src.Handle) - return nil - }, + RunE: execXLockSrcCmd, + Example: ` $ sq x lock-src-cache @sakila`, } return cmd } + +func execXLockSrcCmd(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + ru := run.FromContext(ctx) + src, err := ru.Config.Collection.Get(args[0]) + if err != nil { + return err + } + + timeout := time.Minute * 20 + lock, err := ru.Files.CacheLockFor(src) + if err != nil { + return err + } + fmt.Fprintf(ru.Out, "Locking cache for source %s with timeout %s for %q [%d]\n\n %s\n\n", + src.Handle, timeout, os.Args[0], os.Getpid(), lock) + + err = lock.Lock(ctx, timeout) + if err != nil { + return err + } + + fmt.Fprintf(ru.Out, "Cache lock acquired for %s\n", src.Handle) + fmt.Fprintln(ru.Out, "Press ENTER to release lock and exit.") + + done := make(chan struct{}) + go func() { + // Wait for ENTER on stdin + buf := bufio.NewReader(ru.Stdin) + fmt.Fprint(ru.Out, " > ") + _, _ = buf.ReadBytes('\n') + close(done) + }() + + select { + case <-done: + fmt.Fprintln(ru.Out, "ENTER received, releasing lock") + case <-ctx.Done(): + fmt.Fprintln(ru.Out, "\nContext done, releasing lock") + } + + fmt.Fprintf(ru.Out, "Releasing cache lock for %s\n", src.Handle) + if err = lock.Unlock(); err != nil { + return err + } + + fmt.Fprintf(ru.Out, "Cache lock released for %s\n", src.Handle) + return nil +} + +func newXDevTestCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "dev-test", + Short: "Execute some dev test code", + Hidden: true, + RunE: execXDevTestCmd, + Example: ` $ sq x dev-test`, + } + + return cmd +} + +func execXDevTestCmd(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + ru := run.FromContext(ctx) + _ = ru + return nil +} diff --git a/libsq/core/errz/errz_test.go b/libsq/core/errz/errz_test.go index 0298ae927..838b106fb 100644 --- a/libsq/core/errz/errz_test.go +++ b/libsq/core/errz/errz_test.go @@ -9,14 +9,13 
@@ import ( "os" "testing" - "github.com/neilotoole/sq/libsq/core/stringz" - "github.com/stretchr/testify/require" "github.com/neilotoole/slogt" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/stringz" ) func TestIs(t *testing.T) { diff --git a/libsq/core/ioz/checksum/checksum_test.go b/libsq/core/ioz/checksum/checksum_test.go index 9d9dbec59..09bf59b65 100644 --- a/libsq/core/ioz/checksum/checksum_test.go +++ b/libsq/core/ioz/checksum/checksum_test.go @@ -3,8 +3,9 @@ package checksum_test import ( "testing" - "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/stretchr/testify/assert" + + "github.com/neilotoole/sq/libsq/core/ioz/checksum" ) func TestHash(t *testing.T) { diff --git a/libsq/core/ioz/lockfile/lockfile.go b/libsq/core/ioz/lockfile/lockfile.go index d245a7a88..af267a555 100644 --- a/libsq/core/ioz/lockfile/lockfile.go +++ b/libsq/core/ioz/lockfile/lockfile.go @@ -7,11 +7,10 @@ import ( "path/filepath" "time" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/nightlyone/lockfile" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/retry" diff --git a/libsq/core/ioz/lockfile/lockfile_test.go b/libsq/core/ioz/lockfile/lockfile_test.go index bf554b5a9..46225d2e6 100644 --- a/libsq/core/ioz/lockfile/lockfile_test.go +++ b/libsq/core/ioz/lockfile/lockfile_test.go @@ -6,10 +6,12 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/neilotoole/slogt" + "github.com/neilotoole/sq/libsq/core/ioz/lockfile" "github.com/neilotoole/sq/libsq/core/lg" - "github.com/stretchr/testify/require" ) // FIXME: Duh, this can't work, because we're in the same pid. 
diff --git a/libsq/core/progress/bars.go b/libsq/core/progress/bars.go new file mode 100644 index 000000000..f78ed2679 --- /dev/null +++ b/libsq/core/progress/bars.go @@ -0,0 +1,117 @@ +package progress + +import ( + humanize "github.com/dustin/go-humanize" + "github.com/dustin/go-humanize/english" + mpb "github.com/vbauerster/mpb/v8" + "github.com/vbauerster/mpb/v8/decor" +) + +// NewByteCounter returns a new progress bar whose metric is the count +// of bytes processed. If the size is unknown, set arg size to -1. The caller +// is ultimately responsible for calling [Bar.Stop] on the returned Bar. +// However, the returned Bar is also added to the Progress's cleanup list, +// so it will be called automatically when the Progress is shut down, but that +// may be later than the actual conclusion of the Bar's work. +func (p *Progress) NewByteCounter(msg string, size int64) *Bar { + if p == nil { + return nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + var style mpb.BarFillerBuilder + var counter decor.Decorator + var percent decor.Decorator + if size < 0 { + style = spinnerStyle(p.colors.Filler) + counter = decor.Current(decor.SizeB1024(0), "% .1f") + } else { + style = barStyle(p.colors.Filler) + counter = decor.Counters(decor.SizeB1024(0), "% .1f / % .1f") + percent = decor.NewPercentage(" %.1f", decor.WCSyncSpace) + percent = colorize(percent, p.colors.Percent) + } + counter = colorize(counter, p.colors.Size) + + return p.newBar(msg, size, style, counter, percent) +} + +// NewUnitCounter returns a new indeterminate bar whose label +// metric is the plural of the provided unit. The caller is ultimately +// responsible for calling [Bar.Stop] on the returned Bar. However, +// the returned Bar is also added to the Progress's cleanup list, so +// it will be called automatically when the Progress is shut down, but that +// may be later than the actual conclusion of the spinner's work. 
+// +// bar := p.NewUnitCounter("Ingest records", "rec") +// defer bar.Stop() +// +// for i := 0; i < 100; i++ { +// bar.IncrBy(1) +// time.Sleep(100 * time.Millisecond) +// } +// +// This produces output similar to: +// +// Ingesting records ∙∙● 87 recs +// +// Note that the unit arg is automatically pluralized. +func (p *Progress) NewUnitCounter(msg, unit string) *Bar { + if p == nil { + return nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + decorator := decor.Any(func(statistics decor.Statistics) string { + s := humanize.Comma(statistics.Current) + if unit != "" { + s += " " + english.PluralWord(int(statistics.Current), unit, "") + } + return s + }) + decorator = colorize(decorator, p.colors.Size) + + style := spinnerStyle(p.colors.Filler) + + return p.newBar(msg, -1, style, decorator) +} + +// NewUnitTotalCounter returns a new determinate bar whose label +// metric is the plural of the provided unit. The caller is ultimately +// responsible for calling [Bar.Stop] on the returned Bar. However, +// the returned Bar is also added to the Progress's cleanup list, so +// it will be called automatically when the Progress is shut down, but that +// may be later than the actual conclusion of the Bar's work. +// +// This produces output similar to: +// +// Ingesting sheets ∙∙∙∙∙● 4 / 16 sheets +// +// Note that the unit arg is automatically pluralized. 
+func (p *Progress) NewUnitTotalCounter(msg, unit string, total int64) *Bar { + if p == nil { + return nil + } + + if total <= 0 { + return p.NewUnitCounter(msg, unit) + } + + p.mu.Lock() + defer p.mu.Unlock() + + style := barStyle(p.colors.Filler) + decorator := decor.Any(func(statistics decor.Statistics) string { + s := humanize.Comma(statistics.Current) + " / " + humanize.Comma(statistics.Total) + if unit != "" { + s += " " + english.PluralWord(int(statistics.Current), unit, "") + } + return s + }) + decorator = colorize(decorator, p.colors.Size) + return p.newBar(msg, total, style, decorator) +} diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 3d2840c13..6af6fdd1b 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -22,9 +22,6 @@ import ( "sync/atomic" "time" - humanize "github.com/dustin/go-humanize" - "github.com/dustin/go-humanize/english" - "github.com/fatih/color" mpb "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" @@ -63,26 +60,6 @@ func FromContext(ctx context.Context) *Progress { return nil } -const ( - msgLength = 22 - barWidth = 28 - boxWidth = 64 - refreshRate = 150 * time.Millisecond -) - -// NOTE: The implementation below is wildly more complicated than it should be. -// This is due to a bug in the mpb package, wherein it doesn't fully -// respect the render delay. -// -// https://github.com/vbauerster/mpb/issues/136 -// -// Until that bug is fixed, we have a messy workaround. The gist of it -// is that both the Progress.pc and Bar.bar are lazily initialized. -// The Progress.pc (progress container) is initialized on the first -// call to one of the Progress.NewX methods. The Bar.bar is initialized -// only after the render delay has expired. The details are ugly. -// Hopefully this can all be simplified once the mpb bug is fixed. - // New returns a new Progress instance, which is a container for progress bars. 
// The returned Progress instance is safe for concurrent use, and all of its // public methods can be safely invoked on a nil Progress. The caller is @@ -129,6 +106,25 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors // The caller is responsible for calling [Progress.Stop] to indicate // completion. type Progress struct { + // The implementation here may seem a bit convoluted. The gist of it is that + // both the Progress.pc and Bar.bar are lazily initialized. The Progress.pc + // (progress container) is initialized on the first call to one of the + // Progress.NewX methods. The Bar.bar is initialized only after the bar's own + // render delay has expired. The details are ugly. + // + // Why not just use the mpb package directly? There are three main reasons: + // + // 1. At the time of creating this package, the mpb package didn't correctly + // honor the render delay. See: https://github.com/vbauerster/mpb/issues/136 + // That bug has since been fixed, but... + // 2. The delayed initialization of the Bar.bar is useful for our purposes. + // In particular, we can set the render delay on a per-bar basis, which is + // not possible with the mpb package (its render delay is per Progress, not + // per Bar). + // 3. Having this wrapper around the mpb package allows us greater + // flexibility, e.g. if we ever want to swap out the mpb package for + // something else. + // mu guards ALL public methods. mu *sync.Mutex @@ -197,115 +193,6 @@ func (p *Progress) Stop() { p.pc.Wait() } -// NewUnitCounter returns a new indeterminate bar whose label -// metric is the plural of the provided unit. The caller is ultimately -// responsible for calling [Bar.Stop] on the returned Bar. However, -// the returned Bar is also added to the Progress's cleanup list, so -// it will be called automatically when the Progress is shut down, but that -// may be later than the actual conclusion of the spinner's work. 
-// -// bar := p.NewUnitCounter("Ingest records", "rec") -// defer bar.Stop() -// -// for i := 0; i < 100; i++ { -// bar.IncrBy(1) -// time.Sleep(100 * time.Millisecond) -// } -// -// This produces output similar to: -// -// Ingesting records ∙∙● 87 recs -// -// Note that the unit arg is automatically pluralized. -func (p *Progress) NewUnitCounter(msg, unit string) *Bar { - if p == nil { - return nil - } - - p.mu.Lock() - defer p.mu.Unlock() - - decorator := decor.Any(func(statistics decor.Statistics) string { - s := humanize.Comma(statistics.Current) - if unit != "" { - s += " " + english.PluralWord(int(statistics.Current), unit, "") - } - return s - }) - decorator = colorize(decorator, p.colors.Size) - - style := spinnerStyle(p.colors.Filler) - - return p.newBar(msg, -1, style, decorator) -} - -// NewUnitTotalCounter returns a new determinate bar whose label -// metric is the plural of the provided unit. The caller is ultimately -// responsible for calling [Bar.Stop] on the returned Bar. However, -// the returned Bar is also added to the Progress's cleanup list, so -// it will be called automatically when the Progress is shut down, but that -// may be later than the actual conclusion of the Bar's work. -// -// This produces output similar to: -// -// Ingesting sheets ∙∙∙∙∙● 4 / 16 sheets -// -// Note that the unit arg is automatically pluralized. 
-func (p *Progress) NewUnitTotalCounter(msg, unit string, total int64) *Bar { - if p == nil { - return nil - } - - if total <= 0 { - return p.NewUnitCounter(msg, unit) - } - - p.mu.Lock() - defer p.mu.Unlock() - - style := barStyle(p.colors.Filler) - decorator := decor.Any(func(statistics decor.Statistics) string { - s := humanize.Comma(statistics.Current) + " / " + humanize.Comma(statistics.Total) - if unit != "" { - s += " " + english.PluralWord(int(statistics.Current), unit, "") - } - return s - }) - decorator = colorize(decorator, p.colors.Size) - return p.newBar(msg, total, style, decorator) -} - -// NewByteCounter returns a new progress bar whose metric is the count -// of bytes processed. If the size is unknown, set arg size to -1. The caller -// is ultimately responsible for calling [Bar.Stop] on the returned Bar. -// However, the returned Bar is also added to the Progress's cleanup list, -// so it will be called automatically when the Progress is shut down, but that -// may be later than the actual conclusion of the Bar's work. -func (p *Progress) NewByteCounter(msg string, size int64) *Bar { - if p == nil { - return nil - } - - p.mu.Lock() - defer p.mu.Unlock() - - var style mpb.BarFillerBuilder - var counter decor.Decorator - var percent decor.Decorator - if size < 0 { - style = spinnerStyle(p.colors.Filler) - counter = decor.Current(decor.SizeB1024(0), "% .1f") - } else { - style = barStyle(p.colors.Filler) - counter = decor.Counters(decor.SizeB1024(0), "% .1f / % .1f") - percent = decor.NewPercentage(" %.1f", decor.WCSyncSpace) - percent = colorize(percent, p.colors.Percent) - } - counter = colorize(counter, p.colors.Size) - - return p.newBar(msg, size, style, counter, percent) -} - // newBar returns a new Bar. This function must only be called from // inside the mutex. 
func (p *Progress) newBar(msg string, total int64, @@ -324,7 +211,7 @@ func (p *Progress) newBar(msg string, total int64, } if p.pc == nil { - p.pcInit() // FIXME: delete this + p.pcInit() } if total < 0 { @@ -360,7 +247,7 @@ func (p *Progress) newBar(msg string, total int64, b.incrStash.Store(0) } - b.delayCh = renderDelayBar(b, p.delay) + b.delayCh = barRenderDelay(b, p.delay) p.bars = append(p.bars, b) return b @@ -448,9 +335,9 @@ func (b *Bar) Stop() { b.bar.Wait() } -// renderDelay returns a channel that will be closed after d, +// barRenderDelay returns a channel that will be closed after d, // at which point b will be initialized. -func renderDelayBar(b *Bar, d time.Duration) <-chan struct{} { +func barRenderDelay(b *Bar, d time.Duration) <-chan struct{} { ch := make(chan struct{}) t := time.NewTimer(d) go func() { @@ -462,72 +349,3 @@ func renderDelayBar(b *Bar, d time.Duration) <-chan struct{} { }() return ch } - -func colorize(decorator decor.Decorator, c *color.Color) decor.Decorator { - return decor.Meta(decorator, func(s string) string { - return c.Sprint(s) - }) -} - -// DefaultColors returns the default colors used for the progress bars. -func DefaultColors() *Colors { - return &Colors{ - Message: color.New(color.Faint), - Filler: color.New(color.FgGreen, color.Bold, color.Faint), - Size: color.New(color.Faint), - Percent: color.New(color.FgCyan, color.Faint), - } -} - -// Colors is the set of colors used for the progress bars. -type Colors struct { - Message *color.Color - Filler *color.Color - Size *color.Color - Percent *color.Color -} - -// EnableColor enables or disables color for the progress bars. 
-func (c *Colors) EnableColor(enable bool) { - if c == nil { - return - } - - if enable { - c.Message.EnableColor() - c.Filler.EnableColor() - c.Size.EnableColor() - c.Percent.EnableColor() - return - } - - c.Message.DisableColor() - c.Filler.DisableColor() - c.Size.DisableColor() - c.Percent.DisableColor() -} - -func spinnerStyle(c *color.Color) mpb.SpinnerStyleComposer { - // REVISIT: maybe use ascii chars only, in case it's a weird terminal? - frames := []string{"∙∙∙", "●∙∙", "●∙∙", "∙●∙", "∙●∙", "∙∙●", "∙∙●", "∙∙∙"} - style := mpb.SpinnerStyle(frames...) - if c != nil { - style = style.Meta(func(s string) string { - return c.Sprint(s) - }) - } - return style -} - -func barStyle(c *color.Color) mpb.BarStyleComposer { - clr := func(s string) string { - return c.Sprint(s) - } - - frames := []string{"∙", "●", "●", "●", "∙"} - return mpb.BarStyle(). - Lbound(" ").Rbound(" "). - Filler("∙").FillerMeta(clr). - Padding(" "). - Tip(frames...).TipMeta(clr) -} diff --git a/libsq/core/progress/style.go b/libsq/core/progress/style.go new file mode 100644 index 000000000..11fe87099 --- /dev/null +++ b/libsq/core/progress/style.go @@ -0,0 +1,85 @@ +package progress + +import ( + "time" + + "github.com/fatih/color" + mpb "github.com/vbauerster/mpb/v8" + "github.com/vbauerster/mpb/v8/decor" +) + +const ( + msgLength = 22 + barWidth = 28 + boxWidth = 64 + refreshRate = 150 * time.Millisecond +) + +// DefaultColors returns the default colors used for the progress bars. +func DefaultColors() *Colors { + return &Colors{ + Message: color.New(color.Faint), + Filler: color.New(color.FgGreen, color.Bold, color.Faint), + Size: color.New(color.Faint), + Percent: color.New(color.FgCyan, color.Faint), + } +} + +// Colors is the set of colors used for the progress bars. +type Colors struct { + Message *color.Color + Filler *color.Color + Size *color.Color + Percent *color.Color +} + +// EnableColor enables or disables color for the progress bars. 
+func (c *Colors) EnableColor(enable bool) { + if c == nil { + return + } + + if enable { + c.Message.EnableColor() + c.Filler.EnableColor() + c.Size.EnableColor() + c.Percent.EnableColor() + return + } + + c.Message.DisableColor() + c.Filler.DisableColor() + c.Size.DisableColor() + c.Percent.DisableColor() +} + +func colorize(decorator decor.Decorator, c *color.Color) decor.Decorator { + return decor.Meta(decorator, func(s string) string { + return c.Sprint(s) + }) +} + +func spinnerStyle(c *color.Color) mpb.SpinnerStyleComposer { + // REVISIT: maybe use ascii chars only, in case it's a weird terminal? + frames := []string{"∙∙∙", "●∙∙", "●∙∙", "∙●∙", "∙●∙", "∙∙●", "∙∙●", "∙∙∙"} + style := mpb.SpinnerStyle(frames...) + if c != nil { + style = style.Meta(func(s string) string { + return c.Sprint(s) + }) + } + return style +} + +func barStyle(c *color.Color) mpb.BarStyleComposer { + clr := func(s string) string { + return c.Sprint(s) + } + + frames := []string{"∙", "●", "●", "●", "∙"} + return mpb.BarStyle(). + Lbound(" ").Rbound(" "). + Filler("∙").FillerMeta(clr). + Padding(" "). 
+ Tip(frames...).TipMeta(clr) +} diff --git a/libsq/source/cache.go b/libsq/source/cache.go index e726a1ed1..4d7c995c0 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -7,11 +7,9 @@ import ( "path/filepath" "strings" - "github.com/neilotoole/sq/libsq/core/options" - - "github.com/neilotoole/sq/libsq/core/ioz/checksum" - "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" + "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/core/stringz" ) diff --git a/libsq/source/download.go b/libsq/source/download.go index 51338c6f5..52b3d3663 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -15,8 +15,6 @@ import ( "sync" "time" - "github.com/neilotoole/sq/libsq/core/options" - "golang.org/x/exp/maps" "github.com/neilotoole/sq/libsq/core/errz" @@ -26,6 +24,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/source/fetcher" ) diff --git a/libsq/source/files.go b/libsq/source/files.go index 50993f637..23639d533 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -12,20 +12,18 @@ import ( "sync" "time" - "github.com/neilotoole/sq/libsq/core/ioz/checksum" - - "github.com/neilotoole/sq/libsq/core/options" - "github.com/neilotoole/fscache" "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/ioz/lockfile" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" + 
"github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/stringz" ) From f4a519615e3369d2e203981f5d4f7b6ca07a4a59 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 8 Dec 2023 12:24:16 -0700 Subject: [PATCH 075/195] in a progress widget mess --- cli/cmd_x.go | 54 ++++++++++++++++++------ cli/output.go | 1 + drivers/csv/insert.go | 3 +- libsq/core/progress/bars.go | 75 +++++++++++++++++++++++++++++++++ libsq/core/progress/progress.go | 20 +++++++-- libsq/core/progress/style.go | 9 ++++ 6 files changed, 144 insertions(+), 18 deletions(-) diff --git a/cli/cmd_x.go b/cli/cmd_x.go index ecb4133b4..a1b8f4843 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -3,6 +3,8 @@ package cli import ( "bufio" "fmt" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/progress" "os" "time" @@ -60,20 +62,10 @@ func execXLockSrcCmd(cmd *cobra.Command, args []string) error { } fmt.Fprintf(ru.Out, "Cache lock acquired for %s\n", src.Handle) - fmt.Fprintln(ru.Out, "Press ENTER to release lock and exit.") - - done := make(chan struct{}) - go func() { - // Wait for ENTER on stdin - buf := bufio.NewReader(ru.Stdin) - fmt.Fprint(ru.Out, " > ") - _, _ = buf.ReadBytes('\n') - close(done) - }() select { - case <-done: - fmt.Fprintln(ru.Out, "ENTER received, releasing lock") + case <-pressEnter(): + fmt.Fprintln(ru.Out, "\nENTER received, releasing lock") case <-ctx.Done(): fmt.Fprintln(ru.Out, "\nContext done, releasing lock") } @@ -101,7 +93,43 @@ func newXDevTestCmd() *cobra.Command { func execXDevTestCmd(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() + log := lg.FromContext(ctx) ru := run.FromContext(ctx) _ = ru - return nil + + d := time.Second * 5 + pb := progress.FromContext(ctx) + bar := pb.NewTimeoutWaiter("Locking @sakila", time.Now().Add(d)) + defer bar.Stop() + + select { + //case <-pressEnter(): + // bar.Stop() + // pb.Stop() + // fmt.Fprintln(ru.Out, 
"\nENTER received") + case <-ctx.Done(): + //bar.Stop() + //pb.Stop() + fmt.Fprintln(ru.Out, "Context done") + case <-time.After(d + time.Second*5): + //bar.Stop() + log.Warn("timed out, about to print something") + fmt.Fprintln(ru.Out, "Really timed out") + log.Warn("done printing") + } + + //bar.EwmaIncrInt64(rand.Int63n(5)+1, time.Since(start)) + fmt.Fprintln(ru.Out, "exiting") + return ctx.Err() +} + +func pressEnter() <-chan struct{} { + done := make(chan struct{}) + go func() { + buf := bufio.NewReader(os.Stdin) + fmt.Fprintf(os.Stdout, "\nPress [ENTER] to continue\n\n > ") + _, _ = buf.ReadBytes('\n') + close(done) + }() + return done } diff --git a/cli/output.go b/cli/output.go index 4405ebcc2..6f388e5e7 100644 --- a/cli/output.go +++ b/cli/output.go @@ -466,6 +466,7 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer out2 = ioz.NotifyOnceWriter(out2, func() { lg.FromContext(ctx).Debug("Output stream is being written to; removing progress widget") prog.Stop() + lg.FromContext(ctx).Debug("Output stream is being written to; stop has returned") }) cmd.SetContext(progress.NewContext(ctx, prog)) } diff --git a/drivers/csv/insert.go b/drivers/csv/insert.go index ed12225da..7f1821a29 100644 --- a/drivers/csv/insert.go +++ b/drivers/csv/insert.go @@ -4,8 +4,6 @@ import ( "context" "encoding/csv" "errors" - "io" - "github.com/neilotoole/sq/libsq" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/kind" @@ -14,6 +12,7 @@ import ( "github.com/neilotoole/sq/libsq/core/record" "github.com/neilotoole/sq/libsq/core/sqlmodel" "github.com/neilotoole/sq/libsq/driver" + "io" ) // execInsert inserts the CSV records in readAheadRecs (followed by records diff --git a/libsq/core/progress/bars.go b/libsq/core/progress/bars.go index f78ed2679..b592bfd18 100644 --- a/libsq/core/progress/bars.go +++ b/libsq/core/progress/bars.go @@ -5,6 +5,7 @@ import ( "github.com/dustin/go-humanize/english" mpb 
"github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" + "time" // NewByteCounter returns a new progress bar whose metric is the count // of bytes processed. If the size is unknown, set arg size to -1. The caller // is ultimately responsible for calling [Bar.Stop] on the returned Bar. @@ -115,3 +116,77 @@ func (p *Progress) NewUnitTotalCounter(msg, unit string, total int64) *Bar { decorator = colorize(decorator, p.colors.Size) return p.newBar(msg, total, style, decorator) } + +func (p *Progress) NewTimeoutWaiter(msg string, expires time.Time) *Bar { + if p == nil { + return nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + style := spinnerStyle(p.colors.Filler) + decorator := decor.Any(func(statistics decor.Statistics) string { + remaining := expires.Sub(time.Now()) + switch { + case remaining > 0: + return p.colors.Size.Sprintf("timeout in %s", remaining.Round(time.Second)) + case remaining > -time.Second: + // We do the extra second to prevent a "flash" of + // the timeout message. 
+ return p.colors.Size.Sprint("timeout in 0s") + default: + return p.colors.Warning.Sprintf("timed out") + } + }) + + start := time.Now() + total := expires.Sub(start) + var lastUpdate time.Duration + _ = lastUpdate + b := p.newBar(msg, int64(total), style, decorator) + + //go func() { + // t := time.NewTimer(total) + // defer t.Stop() + // + // select { + // case <-t.C: + // return + // case <-b.p.ctx.Done(): + // return + // case <-b.delayCh: + // } + // + // if b.stopped { + // return + // } + // + // b.initBarOnce.Do(b.initBar) + // + // for { + // select { + // case <-t.C: + // return + // case <-b.p.ctx.Done(): + // return + // default: + // } + // + // if b.stopped { + // return + // } + // + // now := time.Now() + // delta := expires.Sub(now) + // amount := delta - lastUpdate + // lastUpdate += amount + // + // //b.bar.EwmaIncrement(amount) + // b.IncrBy(int(amount)) + // time.Sleep(refreshRate) + // } + //}() + + return b +} diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 6af6fdd1b..57acfa51c 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -163,6 +163,8 @@ func (p *Progress) Stop() { p.mu.Lock() defer p.mu.Unlock() + lg.FromContext(p.ctx).Debug("Stopping progress widget: enter") + if p.stopped { return } @@ -180,17 +182,21 @@ func (p *Progress) Stop() { for _, b := range p.bars { if b.bar != nil { + b.bar.SetTotal(-1, true) b.bar.Abort(true) } } for _, b := range p.bars { - if b.bar != nil { - b.bar.Wait() - } + b.doStop() + //if b.bar != nil { + // b.bar.Wait() + //} } p.pc.Wait() + lg.FromContext(p.ctx).Debug("Stopping progress widget: exit") + //time.Sleep(refreshRate * 2) } // newBar returns a new Bar. 
This function must only be called from @@ -321,6 +327,14 @@ func (b *Bar) Stop() { b.p.mu.Lock() defer b.p.mu.Unlock() + b.doStop() +} + +func (b *Bar) doStop() { + if b == nil { + return + } + if b.bar == nil { b.stopped = true return diff --git a/libsq/core/progress/style.go b/libsq/core/progress/style.go index 11fe87099..c0bbd8567 100644 --- a/libsq/core/progress/style.go +++ b/libsq/core/progress/style.go @@ -22,6 +22,8 @@ func DefaultColors() *Colors { Filler: color.New(color.FgGreen, color.Bold, color.Faint), Size: color.New(color.Faint), Percent: color.New(color.FgCyan, color.Faint), + Warning: color.New(color.FgYellow), + Error: color.New(color.FgRed, color.Bold), } } @@ -31,6 +33,8 @@ type Colors struct { Filler *color.Color Size *color.Color Percent *color.Color + Warning *color.Color + Error *color.Color } // EnableColor enables or disables color for the progress bars. @@ -44,6 +48,9 @@ func (c *Colors) EnableColor(enable bool) { c.Filler.EnableColor() c.Size.EnableColor() c.Percent.EnableColor() + c.Warning.EnableColor() + c.Error.EnableColor() + return } @@ -51,6 +58,8 @@ func (c *Colors) EnableColor(enable bool) { c.Filler.DisableColor() c.Size.DisableColor() c.Percent.DisableColor() + c.Warning.DisableColor() + c.Error.DisableColor() } func colorize(decorator decor.Decorator, c *color.Color) decor.Decorator { From a219d46f2eafeed8382877f5500e6ea1d4dffea4 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 8 Dec 2023 13:23:08 -0700 Subject: [PATCH 076/195] progress closer to working --- cli/cmd_x.go | 15 +++++----- cli/output.go | 1 - drivers/csv/insert.go | 3 +- drivers/xlsx/ingest.go | 1 + libsq/core/progress/bars.go | 49 ++------------------------------- libsq/core/progress/progress.go | 40 ++++++++++++++++++--------- libsq/driver/grips.go | 2 +- libsq/driver/record.go | 2 ++ 8 files changed, 44 insertions(+), 69 deletions(-) diff --git a/cli/cmd_x.go b/cli/cmd_x.go index a1b8f4843..a4e027f35 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -3,11 
+3,12 @@ package cli import ( "bufio" "fmt" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/progress" "os" "time" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/run" @@ -103,22 +104,22 @@ func execXDevTestCmd(cmd *cobra.Command, _ []string) error { defer bar.Stop() select { - //case <-pressEnter(): + // case <-pressEnter(): // bar.Stop() // pb.Stop() // fmt.Fprintln(ru.Out, "\nENTER received") case <-ctx.Done(): - //bar.Stop() - //pb.Stop() + // bar.Stop() + // pb.Stop() fmt.Fprintln(ru.Out, "Context done") case <-time.After(d + time.Second*5): - //bar.Stop() + // bar.Stop() log.Warn("timed out, about to print something") fmt.Fprintln(ru.Out, "Really timed out") log.Warn("done printing") } - //bar.EwmaIncrInt64(rand.Int63n(5)+1, time.Since(start)) + // bar.EwmaIncrInt64(rand.Int63n(5)+1, time.Since(start)) fmt.Fprintln(ru.Out, "exiting") return ctx.Err() } diff --git a/cli/output.go b/cli/output.go index 6f388e5e7..4405ebcc2 100644 --- a/cli/output.go +++ b/cli/output.go @@ -466,7 +466,6 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer out2 = ioz.NotifyOnceWriter(out2, func() { lg.FromContext(ctx).Debug("Output stream is being written to; removing progress widget") prog.Stop() - lg.FromContext(ctx).Debug("Output stream is being written to; stop has returned") }) cmd.SetContext(progress.NewContext(ctx, prog)) } diff --git a/drivers/csv/insert.go b/drivers/csv/insert.go index 7f1821a29..ed12225da 100644 --- a/drivers/csv/insert.go +++ b/drivers/csv/insert.go @@ -4,6 +4,8 @@ import ( "context" "encoding/csv" "errors" + "io" + "github.com/neilotoole/sq/libsq" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/kind" @@ -12,7 +14,6 @@ import ( "github.com/neilotoole/sq/libsq/core/record" "github.com/neilotoole/sq/libsq/core/sqlmodel" 
"github.com/neilotoole/sq/libsq/driver" - "io" ) // execInsert inserts the CSV records in readAheadRecs (followed by records diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index 83a4045a3..e265b2984 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -134,6 +134,7 @@ func ingestXLSX(ctx context.Context, src *source.Source, destGrip driver.Grip, x var ingestCount, skipped int for i := range sheetTbls { + time.Sleep(progress.DebugDelay) if sheetTbls[i] == nil { // tblDef can be nil if its sheet is empty (has no data). skipped++ diff --git a/libsq/core/progress/bars.go b/libsq/core/progress/bars.go index b592bfd18..dcdd195a5 100644 --- a/libsq/core/progress/bars.go +++ b/libsq/core/progress/bars.go @@ -1,11 +1,12 @@ package progress import ( + "time" + humanize "github.com/dustin/go-humanize" "github.com/dustin/go-humanize/english" mpb "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" - "time" // NewByteCounter returns a new progress bar whose metric is the count // of bytes processed. If the size is unknown, set arg size to -1. The caller // is ultimately responsible for calling [Bar.Stop] on the returned Bar. 
@@ -127,7 +128,7 @@ func (p *Progress) NewTimeoutWaiter(msg string, expires time.Time) *Bar { style := spinnerStyle(p.colors.Filler) decorator := decor.Any(func(statistics decor.Statistics) string { - remaining := expires.Sub(time.Now()) + remaining := time.Until(expires) switch { case remaining > 0: return p.colors.Size.Sprintf("timeout in %s", remaining.Round(time.Second)) @@ -142,51 +143,7 @@ func (p *Progress) NewTimeoutWaiter(msg string, expires time.Time) *Bar { start := time.Now() total := expires.Sub(start) - var lastUpdate time.Duration - _ = lastUpdate b := p.newBar(msg, int64(total), style, decorator) - //go func() { - // t := time.NewTimer(total) - // defer t.Stop() - // - // select { - // case <-t.C: - // return - // case <-b.p.ctx.Done(): - // return - // case <-b.delayCh: - // } - // - // if b.stopped { - // return - // } - // - // b.initBarOnce.Do(b.initBar) - // - // for { - // select { - // case <-t.C: - // return - // case <-b.p.ctx.Done(): - // return - // default: - // } - // - // if b.stopped { - // return - // } - // - // now := time.Now() - // delta := expires.Sub(now) - // amount := delta - lastUpdate - // lastUpdate += amount - // - // //b.bar.EwmaIncrement(amount) - // b.IncrBy(int(amount)) - // time.Sleep(refreshRate) - // } - //}() - return b } diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 57acfa51c..673e30f78 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -29,6 +29,13 @@ import ( "github.com/neilotoole/sq/libsq/core/stringz" ) +// DebugDelay is a duration that parts of the codebase sleep for to +// facilitate testing the progress impl. It should be removed before +// release. +// +// Deprecated: This is a temporary hack for testing. +const DebugDelay = time.Millisecond * 20 + type ctxKey struct{} // NewContext returns ctx with p added as a value. @@ -137,6 +144,7 @@ type Progress struct { pc *mpb.Progress // pcInit is the func that lazily initializes pc. 
+ // FIXME: Do we even need the lazily initialized pc now? pcInit func() // delay is the duration to wait before rendering a progress bar. @@ -153,50 +161,57 @@ type Progress struct { } // Stop waits for all bars to complete and finally shuts down the -// container. After this method has been called, there is no way -// to reuse the Progress instance. +// progress container. After this method has been called, there is +// no way to reuse the Progress instance. func (p *Progress) Stop() { if p == nil { return } p.mu.Lock() - defer p.mu.Unlock() - - lg.FromContext(p.ctx).Debug("Stopping progress widget: enter") + p.doStop() + p.mu.Unlock() +} +// doStop is probably needlessly complex, but at the time it was written, +// there was a bug in the mpb package (to do with delayed render and abort), +// and so was created an extra-paranoid workaround. +func (p *Progress) doStop() { if p.stopped { return } p.stopped = true - p.cancelFn() if p.pc == nil { + p.cancelFn() return } if len(p.bars) == 0 { + // p.pc.Wait() FIXME: Does this need to happen + p.cancelFn() return } for _, b := range p.bars { + // We abort each of the bars here, before we call b.doStop() below. + // In theory, this gives the bar abortion process a head start before + // b.bar.Wait() is invoked by b.doStop(). This may be completely + // unnecessary, but it doesn't seem to hurt. if b.bar != nil { - b.bar.SetTotal(-1, true) b.bar.Abort(true) } } for _, b := range p.bars { b.doStop() - //if b.bar != nil { - // b.bar.Wait() - //} } p.pc.Wait() - lg.FromContext(p.ctx).Debug("Stopping progress widget: exit") - //time.Sleep(refreshRate * 2) + // Important: we must call cancelFn after pc.Wait() or the bars + // may not be removed from the terminal. + p.cancelFn() } // newBar returns a new Bar. 
This function must only be called from @@ -341,7 +356,6 @@ func (b *Bar) doStop() { } if !b.stopped { - b.bar.SetTotal(-1, true) b.bar.Abort(true) } b.stopped = true diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 32f01cba7..7534b2886 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -234,7 +234,7 @@ func (gs *Grips) openIngestCache(ctx context.Context, src *source.Source, } if err = lock.Lock(ctx, time.Second*5); err != nil { - return nil, errz.Wrap(err, "acquire cache lock") + return nil, errz.Wrap(err, src.Handle+": acquire cache lock") } defer func() { diff --git a/libsq/driver/record.go b/libsq/driver/record.go index d2aa36bbe..67da2d2d8 100644 --- a/libsq/driver/record.go +++ b/libsq/driver/record.go @@ -468,6 +468,7 @@ func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, bi.written.Add(affected) pbar.IncrBy(int(affected)) + time.Sleep(progress.DebugDelay) if rec == nil { // recCh is closed (coincidentally exactly on the @@ -511,6 +512,7 @@ func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, bi.written.Add(affected) pbar.IncrBy(int(affected)) + time.Sleep(progress.DebugDelay) // We're done return From a3631207d4c7d98c691da47f100b4a200de91010 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 8 Dec 2023 13:36:30 -0700 Subject: [PATCH 077/195] more progress --- libsq/core/progress/bars.go | 23 +++++++++++++++-------- libsq/core/progress/style.go | 28 ++++++++++++++++------------ 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/libsq/core/progress/bars.go b/libsq/core/progress/bars.go index dcdd195a5..313bb3c17 100644 --- a/libsq/core/progress/bars.go +++ b/libsq/core/progress/bars.go @@ -118,6 +118,14 @@ func (p *Progress) NewUnitTotalCounter(msg, unit string, total int64) *Bar { return p.newBar(msg, total, style, decorator) } +// NewTimeoutWaiter returns a new indeterminate bar whose label is the +// amount of time remaining until expires. 
It produces output similar to: +// +// Locking @sakila ●∙∙ timeout in 7s +// +// The caller is ultimately responsible for calling [Bar.Stop] on +// the returned bar, although the bar will also be stopped when the +// parent Progress stops. func (p *Progress) NewTimeoutWaiter(msg string, expires time.Time) *Bar { if p == nil { return nil @@ -126,24 +134,23 @@ func (p *Progress) NewTimeoutWaiter(msg string, expires time.Time) *Bar { p.mu.Lock() defer p.mu.Unlock() - style := spinnerStyle(p.colors.Filler) + style := spinnerStyle(p.colors.Waiting) decorator := decor.Any(func(statistics decor.Statistics) string { remaining := time.Until(expires) switch { case remaining > 0: return p.colors.Size.Sprintf("timeout in %s", remaining.Round(time.Second)) case remaining > -time.Second: - // We do the extra second to prevent a "flash" of - // the timeout message. + // We do the extra second to prevent a "flash" of the timeout message, + // and it also prevents "timeout in -1s" etc. This situation should be + // rare; the caller should have already called Stop() on the Progress + // when the timeout happened, but we'll play it safe. return p.colors.Size.Sprint("timeout in 0s") default: return p.colors.Warning.Sprintf("timed out") } }) - start := time.Now() - total := expires.Sub(start) - b := p.newBar(msg, int64(total), style, decorator) - - return b + total := time.Until(expires) + return p.newBar(msg, int64(total), style, decorator) } diff --git a/libsq/core/progress/style.go b/libsq/core/progress/style.go index c0bbd8567..b0de77b08 100644 --- a/libsq/core/progress/style.go +++ b/libsq/core/progress/style.go @@ -18,23 +18,25 @@ const ( // DefaultColors returns the default colors used for the progress bars. 
func DefaultColors() *Colors { return &Colors{ - Message: color.New(color.Faint), + Error: color.New(color.FgRed, color.Bold), Filler: color.New(color.FgGreen, color.Bold, color.Faint), - Size: color.New(color.Faint), + Message: color.New(color.Faint), Percent: color.New(color.FgCyan, color.Faint), + Size: color.New(color.Faint), + Waiting: color.New(color.FgYellow, color.Faint), Warning: color.New(color.FgYellow), - Error: color.New(color.FgRed, color.Bold), } } // Colors is the set of colors used for the progress bars. type Colors struct { - Message *color.Color + Error *color.Color Filler *color.Color - Size *color.Color + Message *color.Color Percent *color.Color + Size *color.Color + Waiting *color.Color Warning *color.Color - Error *color.Color } // EnableColor enables or disables color for the progress bars. @@ -44,22 +46,24 @@ func (c *Colors) EnableColor(enable bool) { } if enable { - c.Message.EnableColor() + c.Error.EnableColor() c.Filler.EnableColor() - c.Size.EnableColor() + c.Message.EnableColor() c.Percent.EnableColor() + c.Size.EnableColor() + c.Waiting.EnableColor() c.Warning.EnableColor() - c.Error.EnableColor() return } - c.Message.DisableColor() + c.Error.DisableColor() c.Filler.DisableColor() - c.Size.DisableColor() + c.Message.DisableColor() c.Percent.DisableColor() + c.Size.DisableColor() + c.Waiting.DisableColor() c.Warning.DisableColor() - c.Error.DisableColor() } func colorize(decorator decor.Decorator, c *color.Color) decor.Decorator { From b7231981c19d21ef602d6bd391449239b19710bd Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 10 Dec 2023 08:30:29 -0700 Subject: [PATCH 078/195] wip: more progress --- cli/cli.go | 28 +++++++++++++---- cli/error.go | 23 +++++++++++--- cli/options.go | 1 + cli/output.go | 23 +++++++------- cli/run.go | 5 +-- cli/run/run.go | 18 +++++++++-- go.mod | 27 ++++++++-------- go.sum | 6 ++-- libsq/core/progress/progress.go | 15 +++++++-- libsq/core/progress/style.go | 2 +- libsq/core/stringz/stringz.go | 36 
++++++++++++++++++---- libsq/core/stringz/stringz_test.go | 49 +++++++++++++++++++++++++----- libsq/driver/grips.go | 11 ++++++- libsq/driver/ingest.go | 2 ++ libsq/source/cache.go | 16 ++++++++++ testh/tu/tutil.go | 2 +- 16 files changed, 200 insertions(+), 64 deletions(-) diff --git a/cli/cli.go b/cli/cli.go index fccca7a25..6fd15e513 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -60,21 +60,30 @@ var errNoMsg = errors.New("") func Execute(ctx context.Context, stdin *os.File, stdout, stderr io.Writer, args []string) error { ru, log, err := newRun(ctx, stdin, stdout, stderr, args) if err != nil { + // This may be unnecessary, but we are extra-paranoid about + // closing ru before exiting the program. + if closeErr := ru.Close(); closeErr != nil && log != nil { + log.Error("Failed to close run", lga.Err, closeErr) + } + if ru.LogCloser != nil { + _ = ru.LogCloser() + } printError(ctx, ru, err) return err } - defer ru.Close() // ok to call ru.Close on nil ru - ctx = lg.NewContext(ctx, log) - return ExecuteWith(ctx, ru, args) } // ExecuteWith invokes the cobra CLI framework, ultimately -// resulting in a command being executed. The caller must -// invoke ru.Close. +// resulting in a command being executed. This function always closes ru. func ExecuteWith(ctx context.Context, ru *run.Run, args []string) error { + defer func() { + if ru != nil && ru.LogCloser != nil { + _ = ru.LogCloser() + } + }() ctx = options.NewContext(ctx, ru.Config.Options) log := lg.FromContext(ctx) log.Info("EXECUTE", "args", strings.Join(args, " ")) @@ -122,6 +131,7 @@ func ExecuteWith(ctx context.Context, ru *run.Run, args []string) error { // cobra still returns cmd though. It should be // the root cmd. if cmd == nil || cmd.Name() != rootCmd.Name() { + lg.WarnIfCloseError(log, "Problem closing run", ru) // Not sure if this can happen anymore? Can prob delete? 
panic(fmt.Sprintf("bad cobra cmd state: %v", cmd)) } @@ -132,6 +142,7 @@ func ExecuteWith(ctx context.Context, ru *run.Run, args []string) error { // doesn't want the first args element. effectiveArgs := append([]string{"slq"}, args...) if effectiveArgs, err = preprocessFlagArgVars(effectiveArgs); err != nil { + lg.WarnIfCloseError(log, "Problem closing run", ru) return err } rootCmd.SetArgs(effectiveArgs) @@ -142,6 +153,7 @@ func ExecuteWith(ctx context.Context, ru *run.Run, args []string) error { // we redirect to "slq" cmd. effectiveArgs := append([]string{"slq"}, args...) if effectiveArgs, err = preprocessFlagArgVars(effectiveArgs); err != nil { + lg.WarnIfCloseError(log, "Problem closing run", ru) return err } rootCmd.SetArgs(effectiveArgs) @@ -159,8 +171,12 @@ func ExecuteWith(ctx context.Context, ru *run.Run, args []string) error { // Execute rootCmd; cobra will find the appropriate // sub-command, and ultimately execute that command. err = rootCmd.ExecuteContext(ctx) + lg.WarnIfCloseError(log, "Problem closing run", ru) if err != nil { - printError(ctx, ru, err) + ctx2 := rootCmd.Context() // FIXME: delete + _ = ctx2 + + printError(ctx2, ru, err) } return err diff --git a/cli/error.go b/cli/error.go index ad218b951..c6359b02d 100644 --- a/cli/error.go +++ b/cli/error.go @@ -4,10 +4,10 @@ import ( "context" "errors" "fmt" - "os" - + "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/spf13/cobra" "github.com/spf13/pflag" + "os" "github.com/neilotoole/sq/cli/flag" "github.com/neilotoole/sq/cli/output/format" @@ -25,7 +25,10 @@ import ( // ru or any of its fields are nil). 
func printError(ctx context.Context, ru *run.Run, err error) { log := lg.FromContext(ctx) - + log.Warn("printError called", lga.Err, err) // FIXME: delete + //debug.PrintStack() + //stack := errz.Stack(err) + //fmt.Fprintln(ru.Out, "printError stack", "stack", stack) if err == nil { log.Warn("printError called with nil error") return @@ -37,6 +40,7 @@ func printError(ctx context.Context, ru *run.Run, err error) { } switch { + // Friendlier messages for context errors. default: case errors.Is(err, context.Canceled): err = errz.New("canceled") @@ -89,8 +93,19 @@ func printError(ctx context.Context, ru *run.Run, err error) { opts, _ = ru.OptionsRegistry.Process(opts) } + // getPrinting requires a cleanup.Cleanup, so we get or create one. + var clnup *cleanup.Cleanup + if ru != nil && ru.Cleanup != nil { + clnup = ru.Cleanup + } else { + clnup = cleanup.New() + } // getPrinting works even if cmd is nil - pr, _, errOut := getPrinting(cmd, opts, os.Stdout, os.Stderr) + pr, _, errOut := getPrinting(cmd, clnup, opts, os.Stdout, os.Stderr) + // Execute the cleanup before we print the error. + if cleanErr := clnup.Run(); cleanErr != nil { + log.Error("Cleanup failed", lga.Err, cleanErr) + } if bootstrapIsFormatJSON(ru) { // The user wants JSON, either via defaults or flags. diff --git a/cli/options.go b/cli/options.go index 37b9c3b8c..5cef01174 100644 --- a/cli/options.go +++ b/cli/options.go @@ -188,6 +188,7 @@ func RegisterDefaultOpts(reg *options.Registry) { OptTuningFlushThreshold, driver.OptIngestHeader, driver.OptIngestCache, + source.OptCacheLockTimeout, driver.OptIngestColRename, driver.OptIngestSampleSize, csv.OptDelim, diff --git a/cli/output.go b/cli/output.go index 4405ebcc2..9f1eec5a7 100644 --- a/cli/output.go +++ b/cli/output.go @@ -2,6 +2,7 @@ package cli import ( "fmt" + "github.com/neilotoole/sq/libsq/core/cleanup" "io" "os" "strings" @@ -261,10 +262,9 @@ Note that this option is no-op if the rendered value is not an integer. 
// newWriters returns an output.Writers instance configured per defaults and/or // flags from cmd. The returned out2/errOut2 values may differ // from the out/errOut args (e.g. decorated to support colorization). -func newWriters(cmd *cobra.Command, o options.Options, out, errOut io.Writer, -) (w *output.Writers, out2, errOut2 io.Writer) { +func newWriters(cmd *cobra.Command, clnup *cleanup.Cleanup, o options.Options, out, errOut io.Writer) (w *output.Writers, out2, errOut2 io.Writer) { var pr *output.Printing - pr, out2, errOut2 = getPrinting(cmd, o, out, errOut) + pr, out2, errOut2 = getPrinting(cmd, clnup, o, out, errOut) log := logFrom(cmd) // Package tablew has writer impls for each of the writer interfaces, @@ -375,8 +375,7 @@ func getRecordWriterFunc(f format.Format) output.NewRecordWriterFunc { // be absolutely bulletproof, as it's called by all commands, as well // as by the error handling mechanism. So, be sure to always check // for nil cmd, nil cmd.Context, etc. -func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer, -) (pr *output.Printing, out2, errOut2 io.Writer) { +func getPrinting(cmd *cobra.Command, clnup *cleanup.Cleanup, opts options.Options, out, errOut io.Writer) (pr *output.Printing, out2, errOut2 io.Writer) { pr = output.NewPrinting() pr.FormatDatetime = timez.FormatFunc(OptDatetimeFormat.Get(opts)) @@ -422,10 +421,11 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer progColors.EnableColor(false) ctx := cmd.Context() renderDelay := OptProgressDelay.Get(opts) - prog := progress.New(ctx, errOut, renderDelay, progColors) + pb := progress.New(ctx, errOut, renderDelay, progColors) + clnup.Add(pb.Stop) // On first write to stdout, we remove the progress widget. 
- out2 = ioz.NotifyOnceWriter(out2, prog.Stop) - cmd.SetContext(progress.NewContext(ctx, prog)) + out2 = ioz.NotifyOnceWriter(out2, pb.Stop) + cmd.SetContext(progress.NewContext(ctx, pb)) } return pr, out2, errOut2 @@ -460,14 +460,15 @@ func getPrinting(cmd *cobra.Command, opts options.Options, out, errOut io.Writer ctx := cmd.Context() renderDelay := OptProgressDelay.Get(opts) - prog := progress.New(ctx, errOut2, renderDelay, progColors) + pb := progress.New(ctx, errOut2, renderDelay, progColors) + clnup.Add(pb.Stop) // On first write to stdout, we remove the progress widget. out2 = ioz.NotifyOnceWriter(out2, func() { lg.FromContext(ctx).Debug("Output stream is being written to; removing progress widget") - prog.Stop() + pb.Stop() }) - cmd.SetContext(progress.NewContext(ctx, prog)) + cmd.SetContext(progress.NewContext(ctx, pb)) } logFrom(cmd).Debug("Constructed output.Printing", lga.Val, pr) diff --git a/cli/run.go b/cli/run.go index 515af4e5b..c4e92c384 100644 --- a/cli/run.go +++ b/cli/run.go @@ -83,7 +83,8 @@ func newRun(ctx context.Context, stdin *os.File, stdout, stderr io.Writer, args args, ru.OptionsRegistry, upgrades) log, logHandler, logCloser, logErr := defaultLogging(ctx, args, ru.Config) - ru.Cleanup = cleanup.New().AddE(logCloser) + ru.Cleanup = cleanup.New() + ru.LogCloser = logCloser if logErr != nil { stderrLog, h := stderrLogger() _ = logbuf.Flush(ctx, h) @@ -270,7 +271,7 @@ func preRun(cmd *cobra.Command, ru *run.Run) error { if err != nil { return err } - ru.Writers, ru.Out, ru.ErrOut = newWriters(ru.Cmd, cmdOpts, ru.Out, ru.ErrOut) + ru.Writers, ru.Out, ru.ErrOut = newWriters(ru.Cmd, ru.Cleanup, cmdOpts, ru.Out, ru.ErrOut) return FinishRunInit(ctx, ru) } diff --git a/cli/run/run.go b/cli/run/run.go index 44ce55a92..9d72278f1 100644 --- a/cli/run/run.go +++ b/cli/run/run.go @@ -4,6 +4,7 @@ package run import ( "context" + "github.com/neilotoole/sq/libsq/core/lg" "io" "os" @@ -80,20 +81,31 @@ type Run struct { // the CLI uses to print output. 
Writers *output.Writers - // Cleanup holds cleanup functions. + // Cleanup holds cleanup functions, except log closing, which + // is held by LogCloser. Cleanup *cleanup.Cleanup + + // LogCloser contains any log-closing action (such as closing + // a log file). It may be nil. Execution of this function + // should be more-or-less the final cleanup action performed by the CLI, + // and absolutely must happen after all other cleanup actions. + LogCloser func() error } // Close should be invoked to dispose of any open resources // held by ru. If an error occurs during Close and ru.Log // is not nil, that error is logged at WARN level before -// being returned. +// being returned. Note that Run.LogCloser must be invoked separately. func (ru *Run) Close() error { if ru == nil { return nil } - return errz.Wrap(ru.Cleanup.Run(), "Close Run") + if ru.Cmd != nil { + lg.FromContext(ru.Cmd.Context()).Debug("Closing run") + } + + return errz.Wrap(ru.Cleanup.Run(), "close run") } // NewQueryContext returns a *libsq.QueryContext constructed from ru. 
diff --git a/go.mod b/go.mod index 504998b45..d69ab6dc7 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,17 @@ module github.com/neilotoole/sq go 1.21 +// See: https://github.com/vbauerster/mpb/issues/136 +require github.com/vbauerster/mpb/v8 v8.7.1-0.20231206170755-3a4a40c73c35 + +// See: https://github.com/djherbis/fscache/pull/21 +require github.com/neilotoole/fscache v0.0.0-20231203162946-c9808f16552e + +require github.com/djherbis/stream v1.4.0 // indirect + +// See: https://github.com/djherbis/stream/pull/11 +replace github.com/djherbis/stream v1.4.0 => github.com/neilotoole/djherbis-stream v0.0.0-20231203160853-609f47afedda + require ( github.com/Masterminds/sprig/v3 v3.2.3 github.com/a8m/tree v0.0.0-20230208161321-36ae24ddad15 @@ -18,7 +29,6 @@ require ( github.com/h2non/filetype v1.1.3 github.com/jackc/pgx/v5 v5.5.0 github.com/mattn/go-colorable v0.1.13 - github.com/mattn/go-isatty v0.0.20 github.com/mattn/go-runewidth v0.0.15 github.com/mattn/go-sqlite3 v1.14.18 github.com/microsoft/go-mssqldb v1.6.0 @@ -68,6 +78,7 @@ require ( github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kr/text v0.2.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/copystructure v1.0.0 // indirect github.com/mitchellh/reflectwalk v1.0.0 // indirect github.com/moby/term v0.5.0 // indirect @@ -88,17 +99,3 @@ require ( golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) - -// See: https://github.com/vbauerster/mpb/issues/136 -require github.com/vbauerster/mpb/v8 v8.7.1-0.20231205062852-da3162c67234 - -// See: https://github.com/djherbis/fscache/pull/21 -require github.com/neilotoole/fscache v0.0.0-20231203162946-c9808f16552e - -require ( - github.com/djherbis/stream v1.4.0 // indirect - github.com/mrz1836/go-sanitize v1.3.1 // indirect -) - -// See: 
https://github.com/djherbis/stream/pull/11 -replace github.com/djherbis/stream v1.4.0 => github.com/neilotoole/djherbis-stream v0.0.0-20231203160853-609f47afedda diff --git a/go.sum b/go.sum index 195e2f39d..6866c715d 100644 --- a/go.sum +++ b/go.sum @@ -113,8 +113,6 @@ github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 h1:RWengNIwukTxcDr9M+97sNutRR1RKhG96O6jWumTTnw= github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826/go.mod h1:TaXosZuwdSHYgviHp1DAtfrULt5eUgsSMsZf+YrPgl8= -github.com/mrz1836/go-sanitize v1.3.1 h1:bTxpzDXzGh9cp3XLTeVKgL2iLqEwCaLqqe+3BmpnCbo= -github.com/mrz1836/go-sanitize v1.3.1/go.mod h1:Js6Gq1uiarNReoOeOKxPXxNpKy1FRlbgDDZnJG4THdM= github.com/muesli/mango v0.1.0 h1:DZQK45d2gGbql1arsYA4vfg4d7I9Hfx5rX/GCmzsAvI= github.com/muesli/mango v0.1.0/go.mod h1:5XFpbC8jY5UUv89YQciiXNlbi+iJgt29VDC5xbzrLL4= github.com/muesli/mango-cobra v1.2.0 h1:DQvjzAM0PMZr85Iv9LIMaYISpTOliMEg+uMFtNbYvWg= @@ -186,8 +184,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/vbauerster/mpb/v8 v8.7.1-0.20231205062852-da3162c67234 h1:ZsOQFNOwxbDqlxHc9wUW2skA4QMXMZyCOVngFdbrzJE= -github.com/vbauerster/mpb/v8 v8.7.1-0.20231205062852-da3162c67234/go.mod h1:0RgdqeTpu6cDbdWeSaDvEvfgm9O598rBnRZ09HKaV0k= +github.com/vbauerster/mpb/v8 v8.7.1-0.20231206170755-3a4a40c73c35 h1:MMBVE5bui8tBLZz7L4K9MX+ZBQ4eMsrX1iMCg0Ex6Lo= +github.com/vbauerster/mpb/v8 v8.7.1-0.20231206170755-3a4a40c73c35/go.mod h1:0RgdqeTpu6cDbdWeSaDvEvfgm9O598rBnRZ09HKaV0k= 
github.com/xo/dburl v0.19.1 h1:z/K2i8zVf6aRwQ8Szz7MGEUw0VC2472D9SlBqdHDQCU= github.com/xo/dburl v0.19.1/go.mod h1:B7/G9FGungw6ighV8xJNwWYQPMfn3gsi2sn5SE8Bzco= github.com/xuri/efp v0.0.0-20230802181842-ad255f2331ca h1:uvPMDVyP7PXMMioYdyPH+0O+Ta/UO1WFfNYMO3Wz0eg= diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 673e30f78..92a6b710f 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -171,6 +171,7 @@ func (p *Progress) Stop() { p.mu.Lock() p.doStop() p.mu.Unlock() + lg.FromContext(p.ctx).Debug("Stopped progress widget") } // doStop is probably needlessly complex, but at the time it was written, @@ -239,11 +240,12 @@ func (p *Progress) newBar(msg string, total int64, total = 0 } + // We want the bar message to be a consistent width. switch { case len(msg) < msgLength: msg += strings.Repeat(" ", msgLength-len(msg)) case len(msg) > msgLength: - msg = stringz.TrimLenMiddle(msg, msgLength) + msg = stringz.Ellipsify(msg, msgLength) } b := &Bar{ @@ -343,6 +345,7 @@ func (b *Bar) Stop() { defer b.p.mu.Unlock() b.doStop() + lg.FromContext(b.p.ctx).Debug("Stopped progress bar") } func (b *Bar) doStop() { @@ -355,9 +358,15 @@ func (b *Bar) doStop() { return } - if !b.stopped { - b.bar.Abort(true) + if b.stopped { + return } + + //if !b.stopped { + // b.bar.Abort(true) + //} + b.bar.SetTotal(-1, true) + b.bar.Abort(true) b.stopped = true b.bar.Wait() diff --git a/libsq/core/progress/style.go b/libsq/core/progress/style.go index b0de77b08..bbae96fde 100644 --- a/libsq/core/progress/style.go +++ b/libsq/core/progress/style.go @@ -9,7 +9,7 @@ import ( ) const ( - msgLength = 22 + msgLength = 28 barWidth = 28 boxWidth = 64 refreshRate = 150 * time.Millisecond diff --git a/libsq/core/stringz/stringz.go b/libsq/core/stringz/stringz.go index 424fb8615..2db3ad4ab 100644 --- a/libsq/core/stringz/stringz.go +++ b/libsq/core/stringz/stringz.go @@ -457,21 +457,45 @@ func TrimLen(s string, maxLen int) string { 
return s[:maxLen] } -// TrimLenMiddle returns s but with a maximum length of maxLen, +// Ellipsify shortens s to a length of maxLen by cutting the middle and +// inserting an ellipsis rune "…".This is the actual ellipsis rune, not +// three periods. For very short strings, the ellipsis may be elided. +// +// Be warned, Ellipsify may not be unicode-safe. Use at your own risk. +// +// See also: [EllipsifyASCII]. +func Ellipsify(s string, width int) string { + const e = "…" + if width <= 0 { + return "" + } + length := len(s) + + if length <= width { + return s + } + + trimLen := ((width + 1) / 2) - 1 + return s[:trimLen+1-(width%2)] + e + s[len(s)-trimLen:] +} + +// EllipsifyASCII returns s but with a maximum length of maxLen, // with the middle of s replaced with "...". If maxLen is a small // number, the ellipsis may be shorter, e.g. a single char. // This func is only tested with ASCII chars; results are not // guaranteed for multibyte runes. -func TrimLenMiddle(s string, maxLen int) string { +// +// See also: [Ellipsify]. +func EllipsifyASCII(s string, width int) string { length := len(s) - if maxLen <= 0 { + if width <= 0 { return "" } - if length <= maxLen { + if length <= width { return s } - switch maxLen { + switch width { case 1: return s[0:1] case 2: @@ -485,7 +509,7 @@ func TrimLenMiddle(s string, maxLen int) string { default: } - trimLen := ((maxLen + 1) / 2) - 2 + trimLen := ((width + 1) / 2) - 2 return s[:trimLen] + "..." 
+ s[len(s)-trimLen:] } diff --git a/libsq/core/stringz/stringz_test.go b/libsq/core/stringz/stringz_test.go index 6319cd0ed..859c17f39 100644 --- a/libsq/core/stringz/stringz_test.go +++ b/libsq/core/stringz/stringz_test.go @@ -1,14 +1,13 @@ package stringz_test import ( - "strconv" - "strings" - "testing" - "github.com/samber/lo" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "strconv" + "strings" + "testing" "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/testh/tu" @@ -557,10 +556,46 @@ func TestShellEscape(t *testing.T) { } } -// TestTrimLenMiddle tests TrimLenMiddle. It verifies that +func TestEllipsify(t *testing.T) { + testCases := []struct { + input string + maxLen int + want string + }{ + {input: "", maxLen: 0, want: ""}, + {input: "", maxLen: 1, want: ""}, + {input: "abc", maxLen: 1, want: "…"}, + {input: "abc", maxLen: 2, want: "a…"}, + {input: "abcdefghijk", maxLen: 1, want: "…"}, + {input: "abcdefghijk", maxLen: 2, want: "a…"}, + {input: "abcdefghijk", maxLen: 3, want: "a…k"}, + {input: "abcdefghijk", maxLen: 4, want: "ab…k"}, + {input: "abcdefghijk", maxLen: 5, want: "ab…jk"}, + {input: "abcdefghijk", maxLen: 6, want: "abc…jk"}, + {input: "abcdefghijk", maxLen: 7, want: "abc…ijk"}, + {input: "abcdefghijk", maxLen: 8, want: "abcd…ijk"}, + {input: "abcdefghijk", maxLen: 9, want: "abcd…hijk"}, + } + + for i, tc := range testCases { + tc := tc + t.Run(tu.Name(i, tc.input, tc.maxLen), func(t *testing.T) { + got := stringz.Ellipsify(tc.input, tc.maxLen) + t.Logf("%12q --> %12q", tc.input, got) + assert.Equal(t, tc.want, got) + }) + } + + t.Run("test negative", func(t *testing.T) { + got := stringz.Ellipsify("abc", -1) + require.Equal(t, "", got) + }) +} + +// TestEllipsifyASCII tests EllipsifyASCII. It verifies that // the function trims the middle of a string, leaving the // start and end intact. 
-func TestTrimLenMiddle(t *testing.T) { +func TestEllipsifyASCII(t *testing.T) { testCases := []struct { input string maxLen int @@ -582,7 +617,7 @@ func TestTrimLenMiddle(t *testing.T) { for i, tc := range testCases { tc := tc t.Run(tu.Name(i, tc.input, tc.maxLen), func(t *testing.T) { - got := stringz.TrimLenMiddle(tc.input, tc.maxLen) + got := stringz.EllipsifyASCII(tc.input, tc.maxLen) require.True(t, len(got) <= tc.maxLen) require.Equal(t, tc.want, got) }) diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 7534b2886..5d4c8d84e 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -2,6 +2,7 @@ package driver import ( "context" + "github.com/neilotoole/sq/libsq/core/progress" "log/slog" "path/filepath" "strings" @@ -233,7 +234,15 @@ func (gs *Grips) openIngestCache(ctx context.Context, src *source.Source, return nil, err } - if err = lock.Lock(ctx, time.Second*5); err != nil { + lockTimeout := source.OptCacheLockTimeout.Get(options.FromContext(ctx)) + bar := progress.FromContext(ctx).NewTimeoutWaiter( + src.Handle+": acquire lock", + time.Now().Add(lockTimeout), + ) + + err = lock.Lock(ctx, lockTimeout) + bar.Stop() + if err != nil { return nil, errz.Wrap(err, src.Handle+": acquire cache lock") } diff --git a/libsq/driver/ingest.go b/libsq/driver/ingest.go index 5b415a0b1..8785dd644 100644 --- a/libsq/driver/ingest.go +++ b/libsq/driver/ingest.go @@ -26,6 +26,8 @@ to detect the header.`, ) // OptIngestCache specifies whether ingested data is cached or not. +// +// REVISIT: Maybe rename ingest.cache simply to "cache"? 
var OptIngestCache = options.NewBool( "ingest.cache", "", diff --git a/libsq/source/cache.go b/libsq/source/cache.go index 4d7c995c0..5d4954579 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" @@ -13,6 +14,21 @@ import ( "github.com/neilotoole/sq/libsq/core/stringz" ) +// OptCacheLockTimeout is the time allowed to acquire cache lock. +// +// See also: [driver.OptIngestCache]. +var OptCacheLockTimeout = options.NewDuration( + "cache.lock.timeout", + "", + 0, + time.Second*5, + "Wait timeout to acquire cache lock", + `Wait timeout to acquire cache lock. During this period, retry will occur +if the lock is already held by another process. If zero, no retry occurs.`, + options.TagSource, + options.TagSQL, +) + // CacheDirFor gets the cache dir for handle. It is not guaranteed // that the returned dir exists or is accessible. 
func (fs *Files) CacheDirFor(src *Source) (dir string, err error) { diff --git a/testh/tu/tutil.go b/testh/tu/tutil.go index 543284df0..26d7e50c3 100644 --- a/testh/tu/tutil.go +++ b/testh/tu/tutil.go @@ -208,7 +208,7 @@ func Name(args ...any) string { s = strings.ReplaceAll(s, "/", "_") s = strings.ReplaceAll(s, ":", "_") s = strings.ReplaceAll(s, `\`, "_") - s = stringz.TrimLenMiddle(s, 28) // we don't want it to be too long + s = stringz.EllipsifyASCII(s, 28) // we don't want it to be too long parts = append(parts, s) } From c48d4226be5b557cc7e84340a53b1b9d8156c408 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 10 Dec 2023 09:33:43 -0700 Subject: [PATCH 079/195] wip: more progress --- cli/cli.go | 2 + cli/run.go | 2 + libsq/core/progress/progress.go | 79 ++++++++++++++++++++++----------- 3 files changed, 56 insertions(+), 27 deletions(-) diff --git a/cli/cli.go b/cli/cli.go index 6fd15e513..60609787d 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -168,9 +168,11 @@ func ExecuteWith(ctx context.Context, ru *run.Run, args []string) error { } } + rootCmd.SetContext(ctx) // Execute rootCmd; cobra will find the appropriate // sub-command, and ultimately execute that command. 
err = rootCmd.ExecuteContext(ctx) + log.Warn("Closing run", lga.Err, err) lg.WarnIfCloseError(log, "Problem closing run", ru) if err != nil { ctx2 := rootCmd.Context() // FIXME: delete diff --git a/cli/run.go b/cli/run.go index c4e92c384..3554981b5 100644 --- a/cli/run.go +++ b/cli/run.go @@ -84,7 +84,9 @@ func newRun(ctx context.Context, stdin *os.File, stdout, stderr io.Writer, args log, logHandler, logCloser, logErr := defaultLogging(ctx, args, ru.Config) ru.Cleanup = cleanup.New() + // FIXME: re-enable log closing ru.LogCloser = logCloser + _ = logCloser if logErr != nil { stderrLog, h := stderrLogger() _ = logbuf.Flush(ctx, h) diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 92a6b710f..38dca69ff 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -75,17 +75,17 @@ func FromContext(ctx context.Context) *Progress { // The Progress is lazily initialized, and thus the delay clock doesn't // start ticking until the first call to one of the Progress.NewX methods. 
func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors) *Progress { - lg.FromContext(ctx).Debug("New progress widget", "delay", delay) - - var cancelFn context.CancelFunc - ctx, cancelFn = context.WithCancel(ctx) + log := lg.FromContext(ctx) + log.Debug("New progress widget", "delay", delay) if colors == nil { colors = DefaultColors() } + pCtx, cancelFn := context.WithCancel(lg.NewContext(context.Background(), log)) + p := &Progress{ - ctx: ctx, + ctx: pCtx, mu: &sync.Mutex{}, colors: colors, cancelFn: cancelFn, @@ -93,6 +93,13 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors delay: delay, } + go func() { + <-ctx.Done() + lg.FromContext(pCtx).Warn("Main context canceled") + p.Stop() + lg.FromContext(pCtx).Warn("Main context trigger returned") + }() + p.pcInit = func() { opts := []mpb.ContainerOption{ mpb.WithOutput(out), @@ -168,10 +175,13 @@ func (p *Progress) Stop() { return } + lg.FromContext(p.ctx).Warn("Stopping progress widget") p.mu.Lock() p.doStop() + <-p.ctx.Done() p.mu.Unlock() - lg.FromContext(p.ctx).Debug("Stopped progress widget") + + lg.FromContext(p.ctx).Warn("Stopped progress widget") } // doStop is probably needlessly complex, but at the time it was written, @@ -185,6 +195,7 @@ func (p *Progress) doStop() { p.stopped = true if p.pc == nil { + p.pcInit = nil p.cancelFn() return } @@ -201,15 +212,21 @@ func (p *Progress) doStop() { // b.bar.Wait() is invoked by b.doStop(). This may be completely // unnecessary, but it doesn't seem to hurt. if b.bar != nil { - b.bar.Abort(true) + if !b.bar.Aborted() { + b.bar.Abort(true) + } + } } for _, b := range p.bars { b.doStop() + <-b.barStopped } + lg.FromContext(p.ctx).Warn("progress: p.pc.Wait()") p.pc.Wait() + lg.FromContext(p.ctx).Warn("progress: p.pc.Wait() DONE") // Important: we must call cancelFn after pc.Wait() or the bars // may not be removed from the terminal. 
p.cancelFn() @@ -252,11 +269,19 @@ func (p *Progress) newBar(msg string, total int64, p: p, incrStash: &atomic.Int64{}, initBarOnce: &sync.Once{}, + barStopOnce: &sync.Once{}, + barStopped: make(chan struct{}), } b.initBar = func() { - if b.stopped || p.stopped { + if p.stopped { + return + } + select { + case <-b.barStopped: return + default: } + b.bar = p.pc.New(total, style, mpb.BarWidth(barWidth), @@ -297,13 +322,14 @@ type Bar struct { initBarOnce *sync.Once initBar func() + barStopOnce *sync.Once + barStopped chan struct{} + delayCh <-chan struct{} // incrStash holds the increment count until the // bar is fully initialized. incrStash *atomic.Int64 - - stopped bool } // IncrBy increments progress by amount of n. It is safe to @@ -316,11 +342,13 @@ func (b *Bar) IncrBy(n int) { b.p.mu.Lock() defer b.p.mu.Unlock() - if b.stopped || b.p.stopped { + if b.p.stopped { return } select { + case <-b.barStopped: + return case <-b.p.ctx.Done(): return case <-b.delayCh: @@ -345,7 +373,7 @@ func (b *Bar) Stop() { defer b.p.mu.Unlock() b.doStop() - lg.FromContext(b.p.ctx).Debug("Stopped progress bar") + <-b.barStopped } func (b *Bar) doStop() { @@ -353,23 +381,20 @@ func (b *Bar) doStop() { return } - if b.bar == nil { - b.stopped = true - return - } - - if b.stopped { - return - } + b.barStopOnce.Do(func() { + if b.bar == nil { + close(b.barStopped) + return + } - //if !b.stopped { - // b.bar.Abort(true) - //} - b.bar.SetTotal(-1, true) - b.bar.Abort(true) - b.stopped = true + if !b.bar.Aborted() && !b.bar.Completed() { + b.bar.Abort(true) + } - b.bar.Wait() + b.bar.Wait() + close(b.barStopped) + lg.FromContext(b.p.ctx).Debug("Stopped progress bar") + }) } // barRenderDelay returns a channel that will be closed after d, From 45769bc30a25b70a701f24ab8c01791280443fc5 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 10 Dec 2023 11:16:30 -0700 Subject: [PATCH 080/195] wip: more progress --- cli/error.go | 9 +- cli/output.go | 11 +- cli/run/run.go | 3 +- 
libsq/core/progress/io.go | 4 +- libsq/core/progress/progress.go | 164 ++++++++++++++--------------- libsq/core/stringz/stringz_test.go | 7 +- libsq/driver/grips.go | 3 +- 7 files changed, 105 insertions(+), 96 deletions(-) diff --git a/cli/error.go b/cli/error.go index c6359b02d..a8db0a173 100644 --- a/cli/error.go +++ b/cli/error.go @@ -4,10 +4,11 @@ import ( "context" "errors" "fmt" + "os" + "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/spf13/cobra" "github.com/spf13/pflag" - "os" "github.com/neilotoole/sq/cli/flag" "github.com/neilotoole/sq/cli/output/format" @@ -26,9 +27,9 @@ import ( func printError(ctx context.Context, ru *run.Run, err error) { log := lg.FromContext(ctx) log.Warn("printError called", lga.Err, err) // FIXME: delete - //debug.PrintStack() - //stack := errz.Stack(err) - //fmt.Fprintln(ru.Out, "printError stack", "stack", stack) + // debug.PrintStack() + // stack := errz.Stack(err) + // fmt.Fprintln(ru.Out, "printError stack", "stack", stack) if err == nil { log.Warn("printError called with nil error") return diff --git a/cli/output.go b/cli/output.go index 9f1eec5a7..b32a13ac4 100644 --- a/cli/output.go +++ b/cli/output.go @@ -2,12 +2,13 @@ package cli import ( "fmt" - "github.com/neilotoole/sq/libsq/core/cleanup" "io" "os" "strings" "time" + "github.com/neilotoole/sq/libsq/core/cleanup" + "github.com/fatih/color" colorable "github.com/mattn/go-colorable" wordwrap "github.com/mitchellh/go-wordwrap" @@ -262,7 +263,9 @@ Note that this option is no-op if the rendered value is not an integer. // newWriters returns an output.Writers instance configured per defaults and/or // flags from cmd. The returned out2/errOut2 values may differ // from the out/errOut args (e.g. decorated to support colorization). 
-func newWriters(cmd *cobra.Command, clnup *cleanup.Cleanup, o options.Options, out, errOut io.Writer) (w *output.Writers, out2, errOut2 io.Writer) { +func newWriters(cmd *cobra.Command, clnup *cleanup.Cleanup, o options.Options, + out, errOut io.Writer, +) (w *output.Writers, out2, errOut2 io.Writer) { var pr *output.Printing pr, out2, errOut2 = getPrinting(cmd, clnup, o, out, errOut) log := logFrom(cmd) @@ -375,7 +378,9 @@ func getRecordWriterFunc(f format.Format) output.NewRecordWriterFunc { // be absolutely bulletproof, as it's called by all commands, as well // as by the error handling mechanism. So, be sure to always check // for nil cmd, nil cmd.Context, etc. -func getPrinting(cmd *cobra.Command, clnup *cleanup.Cleanup, opts options.Options, out, errOut io.Writer) (pr *output.Printing, out2, errOut2 io.Writer) { +func getPrinting(cmd *cobra.Command, clnup *cleanup.Cleanup, opts options.Options, + out, errOut io.Writer, +) (pr *output.Printing, out2, errOut2 io.Writer) { pr = output.NewPrinting() pr.FormatDatetime = timez.FormatFunc(OptDatetimeFormat.Get(opts)) diff --git a/cli/run/run.go b/cli/run/run.go index 9d72278f1..02eac2839 100644 --- a/cli/run/run.go +++ b/cli/run/run.go @@ -4,10 +4,11 @@ package run import ( "context" - "github.com/neilotoole/sq/libsq/core/lg" "io" "os" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/config" diff --git a/libsq/core/progress/io.go b/libsq/core/progress/io.go index ba1832ee5..256625519 100644 --- a/libsq/core/progress/io.go +++ b/libsq/core/progress/io.go @@ -71,7 +71,7 @@ func (w *progWriter) Write(p []byte) (n int, err error) { w.b.Stop() return 0, w.ctx.Err() case <-w.delayCh: - w.b.initBarOnce.Do(w.b.initBar) + w.b.barInitOnce.Do(w.b.barInitFn) default: } @@ -167,7 +167,7 @@ func (r *progReader) Read(p []byte) (n int, err error) { r.b.Stop() return 0, r.ctx.Err() case <-r.delayCh: - r.b.initBarOnce.Do(r.b.initBar) + r.b.barInitOnce.Do(r.b.barInitFn) 
default: } diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 38dca69ff..ccb1dbbe2 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -85,12 +85,14 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors pCtx, cancelFn := context.WithCancel(lg.NewContext(context.Background(), log)) p := &Progress{ - ctx: pCtx, - mu: &sync.Mutex{}, - colors: colors, - cancelFn: cancelFn, - bars: make([]*Bar, 0), - delay: delay, + ctx: pCtx, + mu: &sync.Mutex{}, + colors: colors, + cancelFn: cancelFn, + bars: make([]*Bar, 0), + delay: delay, + stoppedCh: make(chan struct{}), + stopOnce: &sync.Once{}, } go func() { @@ -100,7 +102,7 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors lg.FromContext(pCtx).Warn("Main context trigger returned") }() - p.pcInit = func() { + p.pcInitFn = func() { opts := []mpb.ContainerOption{ mpb.WithOutput(out), mpb.WithWidth(boxWidth), @@ -109,10 +111,10 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors } p.pc = mpb.NewWithContext(ctx, opts...) - p.pcInit = nil + p.pcInitFn = nil } - p.pcInit() + p.pcInitFn() return p } @@ -146,20 +148,20 @@ type Progress struct { cancelFn context.CancelFunc // pc is the underlying progress container. It is lazily initialized - // by pcInit. Any method that accesses pc must be certain that - // pcInit has been called. + // by pcInitFn. Any method that accesses pc must be certain that + // pcInitFn has been called. pc *mpb.Progress - // pcInit is the func that lazily initializes pc. - // FIXME: Do we even need the lazily initialized pc now? - pcInit func() + // pcInitFn is the func that lazily initializes pc. + pcInitFn func() // delay is the duration to wait before rendering a progress bar. // This value is used for each bar created by this Progress. delay time.Duration - // stopped is set to true when Stop is called. 
- stopped bool + // stoppedCh is closed when the progress widget is stopped. + stoppedCh chan struct{} + stopOnce *sync.Once colors *Colors @@ -178,7 +180,6 @@ func (p *Progress) Stop() { lg.FromContext(p.ctx).Warn("Stopping progress widget") p.mu.Lock() p.doStop() - <-p.ctx.Done() p.mu.Unlock() lg.FromContext(p.ctx).Warn("Stopped progress widget") @@ -188,48 +189,48 @@ func (p *Progress) Stop() { // there was a bug in the mpb package (to do with delayed render and abort), // and so was created an extra-paranoid workaround. func (p *Progress) doStop() { - if p.stopped { - return - } + p.stopOnce.Do(func() { + defer close(p.stoppedCh) - p.stopped = true - - if p.pc == nil { - p.pcInit = nil - p.cancelFn() - return - } + if p.pc == nil { + p.pcInitFn = nil + p.cancelFn() + return + } - if len(p.bars) == 0 { - // p.pc.Wait() FIXME: Does this need to happen - p.cancelFn() - return - } + if len(p.bars) == 0 { + // p.pc.Wait() FIXME: Does this need to happen + p.cancelFn() + return + } - for _, b := range p.bars { - // We abort each of the bars here, before we call b.doStop() below. - // In theory, this gives the bar abortion process a head start before - // b.bar.Wait() is invoked by b.doStop(). This may be completely - // unnecessary, but it doesn't seem to hurt. - if b.bar != nil { - if !b.bar.Aborted() { - b.bar.Abort(true) + for _, b := range p.bars { + // We abort each of the bars here, before we call b.doStop() below. + // In theory, this gives the bar abortion process a head start before + // b.bar.Wait() is invoked by b.doStop(). This may be completely + // unnecessary, but it doesn't seem to hurt. 
+ if b.bar != nil { + if !b.bar.Aborted() { + b.bar.Abort(true) + } } + } + for _, b := range p.bars { + b.doStop() + <-b.barStoppedCh // Wait for bar to stop } - } - for _, b := range p.bars { - b.doStop() - <-b.barStopped - } + lg.FromContext(p.ctx).Warn("progress: p.pc.Wait()") + p.pc.Wait() + lg.FromContext(p.ctx).Warn("progress: p.pc.Wait() DONE") + // Important: we must call cancelFn after pc.Wait() or the bars + // may not be removed from the terminal. + p.cancelFn() + }) - lg.FromContext(p.ctx).Warn("progress: p.pc.Wait()") - p.pc.Wait() - lg.FromContext(p.ctx).Warn("progress: p.pc.Wait() DONE") - // Important: we must call cancelFn after pc.Wait() or the bars - // may not be removed from the terminal. - p.cancelFn() + <-p.stoppedCh + <-p.ctx.Done() } // newBar returns a new Bar. This function must only be called from @@ -241,16 +242,18 @@ func (p *Progress) newBar(msg string, total int64, return nil } - lg.FromContext(p.ctx).Debug("New bar", "msg", msg, "total", total) - select { + case <-p.stoppedCh: + return nil case <-p.ctx.Done(): return nil default: } + lg.FromContext(p.ctx).Debug("New bar", "msg", msg, "total", total) + if p.pc == nil { - p.pcInit() + p.pcInitFn() } if total < 0 { @@ -266,18 +269,17 @@ func (p *Progress) newBar(msg string, total int64, } b := &Bar{ - p: p, - incrStash: &atomic.Int64{}, - initBarOnce: &sync.Once{}, - barStopOnce: &sync.Once{}, - barStopped: make(chan struct{}), + p: p, + incrStash: &atomic.Int64{}, + barInitOnce: &sync.Once{}, + barStopOnce: &sync.Once{}, + barStoppedCh: make(chan struct{}), } - b.initBar = func() { - if p.stopped { - return - } + b.barInitFn = func() { select { - case <-b.barStopped: + case <-p.stoppedCh: + return + case <-b.barStoppedCh: return default: } @@ -306,7 +308,7 @@ func (p *Progress) newBar(msg string, total int64, // the bar is complete, the caller should invoke [Bar.Stop]. All // methods are safe to call on a nil Bar. 
type Bar struct { - // bar is nil until barInitOnce.Do(initBar) is called + // bar is nil until barInitOnce.Do(barInitFn) is called bar *mpb.Bar // p is never nil p *Progress @@ -319,11 +321,11 @@ type Bar struct { // Until that bug is fixed, the Bar is lazily initialized // after the render delay expires. - initBarOnce *sync.Once - initBar func() + barInitOnce *sync.Once + barInitFn func() - barStopOnce *sync.Once - barStopped chan struct{} + barStopOnce *sync.Once + barStoppedCh chan struct{} delayCh <-chan struct{} @@ -342,17 +344,15 @@ func (b *Bar) IncrBy(n int) { b.p.mu.Lock() defer b.p.mu.Unlock() - if b.p.stopped { - return - } - select { - case <-b.barStopped: + case <-b.p.stoppedCh: + return + case <-b.barStoppedCh: return case <-b.p.ctx.Done(): return case <-b.delayCh: - b.initBarOnce.Do(b.initBar) + b.barInitOnce.Do(b.barInitFn) if b.bar != nil { b.bar.IncrBy(n) } @@ -373,7 +373,7 @@ func (b *Bar) Stop() { defer b.p.mu.Unlock() b.doStop() - <-b.barStopped + <-b.barStoppedCh } func (b *Bar) doStop() { @@ -383,7 +383,7 @@ func (b *Bar) doStop() { b.barStopOnce.Do(func() { if b.bar == nil { - close(b.barStopped) + close(b.barStoppedCh) return } @@ -392,7 +392,7 @@ func (b *Bar) doStop() { } b.bar.Wait() - close(b.barStopped) + close(b.barStoppedCh) lg.FromContext(b.p.ctx).Debug("Stopped progress bar") }) } @@ -400,14 +400,14 @@ func (b *Bar) doStop() { // barRenderDelay returns a channel that will be closed after d, // at which point b will be initialized. 
func barRenderDelay(b *Bar, d time.Duration) <-chan struct{} { - ch := make(chan struct{}) + delayCh := make(chan struct{}) t := time.NewTimer(d) go func() { - defer close(ch) + defer close(delayCh) defer t.Stop() <-t.C - b.initBarOnce.Do(b.initBar) + b.barInitOnce.Do(b.barInitFn) }() - return ch + return delayCh } diff --git a/libsq/core/stringz/stringz_test.go b/libsq/core/stringz/stringz_test.go index 859c17f39..208c79960 100644 --- a/libsq/core/stringz/stringz_test.go +++ b/libsq/core/stringz/stringz_test.go @@ -1,13 +1,14 @@ package stringz_test import ( + "strconv" + "strings" + "testing" + "github.com/samber/lo" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "strconv" - "strings" - "testing" "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/testh/tu" diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 5d4c8d84e..2d389a44c 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -2,13 +2,14 @@ package driver import ( "context" - "github.com/neilotoole/sq/libsq/core/progress" "log/slog" "path/filepath" "strings" "sync" "time" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" From 3f829fcd4baf14be5272e1618a4ce9b8a03976d7 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Mon, 11 Dec 2023 06:44:09 -0700 Subject: [PATCH 081/195] Pausing work on progress stuff; works reasonably well --- cli/output.go | 8 +++- libsq/core/progress/progress.go | 85 ++++++++++++++++++++++----------- 2 files changed, 63 insertions(+), 30 deletions(-) diff --git a/cli/output.go b/cli/output.go index b32a13ac4..2692a4cf2 100644 --- a/cli/output.go +++ b/cli/output.go @@ -426,7 +426,7 @@ func getPrinting(cmd *cobra.Command, clnup *cleanup.Cleanup, opts options.Option progColors.EnableColor(false) ctx := cmd.Context() 
renderDelay := OptProgressDelay.Get(opts) - pb := progress.New(ctx, errOut, renderDelay, progColors) + pb := progress.New(ctx, errOut2, renderDelay, progColors) clnup.Add(pb.Stop) // On first write to stdout, we remove the progress widget. out2 = ioz.NotifyOnceWriter(out2, pb.Stop) @@ -473,6 +473,12 @@ func getPrinting(cmd *cobra.Command, clnup *cleanup.Cleanup, opts options.Option lg.FromContext(ctx).Debug("Output stream is being written to; removing progress widget") pb.Stop() }) + + // On first write to stderr, we remove the progress widget. + errOut2 = ioz.NotifyOnceWriter(errOut2, func() { + lg.FromContext(ctx).Debug("Error stream is being written to; removing progress widget") + pb.Stop() + }) // FIXME: delete cmd.SetContext(progress.NewContext(ctx, pb)) } diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index ccb1dbbe2..329d5007f 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -82,36 +82,55 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors colors = DefaultColors() } - pCtx, cancelFn := context.WithCancel(lg.NewContext(context.Background(), log)) - p := &Progress{ - ctx: pCtx, mu: &sync.Mutex{}, colors: colors, - cancelFn: cancelFn, bars: make([]*Bar, 0), delay: delay, stoppedCh: make(chan struct{}), stopOnce: &sync.Once{}, + refreshCh: make(chan any, 100), } + // Note that p.ctx is not the same as the arg ctx. This is a bit of a hack + // to ensure that p.Stop gets called when ctx is cancelled, but before + // the p.pc learns that its context is cancelled. This was done in an attempt + // to clean up the progress bars before the main context is cancelled (i.e. + // to remove bars when the user hits Ctrl-C). Alas, it's not working as + // hoped in that scenario. 
+ p.ctx, p.cancelFn = context.WithCancel(lg.NewContext(context.Background(), log)) go func() { <-ctx.Done() - lg.FromContext(pCtx).Warn("Main context canceled") + log.Debug("Stopping via go ctx done") p.Stop() - lg.FromContext(pCtx).Warn("Main context trigger returned") + <-p.stoppedCh + <-p.ctx.Done() }() p.pcInitFn = func() { opts := []mpb.ContainerOption{ mpb.WithOutput(out), mpb.WithWidth(boxWidth), - mpb.WithRefreshRate(refreshRate), - mpb.WithAutoRefresh(), // Needed for color in Windows, apparently + // mpb.WithRefreshRate(refreshRate), + mpb.WithManualRefresh(p.refreshCh), + // mpb.WithAutoRefresh(), // Needed for color in Windows, apparently } p.pc = mpb.NewWithContext(ctx, opts...) p.pcInitFn = nil + go func() { + for { + select { + case <-p.stoppedCh: + return + case <-p.ctx.Done(): + return + default: + p.refreshCh <- time.Now() + time.Sleep(refreshRate) + } + } + }() } p.pcInitFn() @@ -144,6 +163,13 @@ type Progress struct { // mu guards ALL public methods. mu *sync.Mutex + // stoppedCh is closed when the progress widget is stopped. + // This somewhat duplicates <-p.ctx.Done()... maybe it can be removed? + stoppedCh chan struct{} + stopOnce *sync.Once + + refreshCh chan any + ctx context.Context cancelFn context.CancelFunc @@ -159,10 +185,6 @@ type Progress struct { // This value is used for each bar created by this Progress. delay time.Duration - // stoppedCh is closed when the progress widget is stopped. - stoppedCh chan struct{} - stopOnce *sync.Once - colors *Colors // bars contains all bars that have been created on this Progress. @@ -177,12 +199,10 @@ func (p *Progress) Stop() { return } - lg.FromContext(p.ctx).Warn("Stopping progress widget") p.mu.Lock() p.doStop() + <-p.stoppedCh p.mu.Unlock() - - lg.FromContext(p.ctx).Warn("Stopped progress widget") } // doStop is probably needlessly complex, but at the time it was written, @@ -190,16 +210,19 @@ func (p *Progress) Stop() { // and so was created an extra-paranoid workaround. 
func (p *Progress) doStop() { p.stopOnce.Do(func() { - defer close(p.stoppedCh) - + p.pcInitFn = nil + lg.FromContext(p.ctx).Warn("Stopping progress widget") + defer lg.FromContext(p.ctx).Warn("Stopped progress widget") if p.pc == nil { - p.pcInitFn = nil + close(p.stoppedCh) + close(p.refreshCh) p.cancelFn() return } if len(p.bars) == 0 { - // p.pc.Wait() FIXME: Does this need to happen + close(p.stoppedCh) + close(p.refreshCh) p.cancelFn() return } @@ -210,9 +233,8 @@ func (p *Progress) doStop() { // b.bar.Wait() is invoked by b.doStop(). This may be completely // unnecessary, but it doesn't seem to hurt. if b.bar != nil { - if !b.bar.Aborted() { - b.bar.Abort(true) - } + b.bar.SetTotal(-1, true) + b.bar.Abort(true) } } @@ -221,9 +243,10 @@ func (p *Progress) doStop() { <-b.barStoppedCh // Wait for bar to stop } - lg.FromContext(p.ctx).Warn("progress: p.pc.Wait()") + p.refreshCh <- time.Now() + close(p.stoppedCh) + close(p.refreshCh) p.pc.Wait() - lg.FromContext(p.ctx).Warn("progress: p.pc.Wait() DONE") // Important: we must call cancelFn after pc.Wait() or the bars // may not be removed from the terminal. p.cancelFn() @@ -294,7 +317,8 @@ func (p *Progress) newBar(msg string, total int64, mpb.BarRemoveOnComplete(), ) b.bar.IncrBy(int(b.incrStash.Load())) - b.incrStash.Store(0) + b.incrStash = nil + // b.incrStash.Store(0) } b.delayCh = barRenderDelay(b, p.delay) @@ -382,16 +406,19 @@ func (b *Bar) doStop() { } b.barStopOnce.Do(func() { + lg.FromContext(b.p.ctx).Debug("Stopping progress bar") if b.bar == nil { close(b.barStoppedCh) return } - if !b.bar.Aborted() && !b.bar.Completed() { - b.bar.Abort(true) - } - + // We *probably* only need to call b.bar.Abort() here? 
+ b.bar.SetTotal(-1, true) + b.bar.Abort(true) + b.p.refreshCh <- time.Now() b.bar.Wait() + b.p.refreshCh <- time.Now() + close(b.barStoppedCh) lg.FromContext(b.p.ctx).Debug("Stopped progress bar") }) From ed86c4b51ac7d2ec1742887249416a2387592070 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Mon, 11 Dec 2023 11:26:55 -0700 Subject: [PATCH 082/195] wip: downloader --- libsq/core/ioz/checksum/checksum.go | 45 ++++++++++++--- libsq/source/download.go | 88 +++++++++++++++++++++++------ libsq/source/download_test.go | 52 +++++++++++++---- 3 files changed, 151 insertions(+), 34 deletions(-) diff --git a/libsq/core/ioz/checksum/checksum.go b/libsq/core/ioz/checksum/checksum.go index a206599a6..6f5136df4 100644 --- a/libsq/core/ioz/checksum/checksum.go +++ b/libsq/core/ioz/checksum/checksum.go @@ -149,6 +149,29 @@ func ForHTTPHeader(u string, header http.Header) Checksum { // request URL and the contents of the response's header. If the header // contains an Etag, that is used as the primary element. Otherwise, // other values such as Content-Length and Last-Modified are considered. +// +// There's some trickiness with Etag. Note that by default, stdlib http.Client +// will set sneakily set the "Accept-Encoding: gzip" header on GET requests. +// However, this doesn't happen for HEAD requests. So, comparing a GET and HEAD +// response for the same URL may result in different checksums, because the +// server will likely return a different Etag for the gzipped response. +// +// # With gzip +// Etag: "069dbf690a12d5eb853feb8e04aeb49e-ssl-df" +// +// # Without gzip +// Etag: "069dbf690a12d5eb853feb8e04aeb49e-ssl" +// +// Note the "-ssl-df" suffix on the gzipped response. The "df" suffix is +// for "deflate". +// +// The solution here might be to always explicitly set the gzip header on all +// requests. However, when gzip is not explicitly set, the stdlib client +// transparently handles gzip compression, including on the body read end. 
So,
+// ideally, we wouldn't change that part, so that we don't have to code for
+// both compressed and uncompressed responses.
+//
+// Our hack for now is to trim the "-df" suffix from the Etag.
 func ForHTTPResponse(resp *http.Response) Checksum {
 	if resp == nil {
 		return ""
@@ -156,20 +179,28 @@ func ForHTTPResponse(resp *http.Response) Checksum {
 
 	buf := bytes.Buffer{}
 	if resp.Request != nil && resp.Request.URL != nil {
-		buf.WriteString(resp.Request.URL.String())
+		buf.WriteString(resp.Request.URL.String() + "\n")
 	}
+	buf.WriteString(strconv.Itoa(int(resp.ContentLength)) + "\n")
 	header := resp.Header
 	if header != nil {
-		etag := header.Get("Etag")
+		buf.WriteString(header.Get("Content-Encoding") + "\n")
+		etag := strings.TrimSpace(header.Get("Etag"))
 		if etag != "" {
-			buf.WriteString(etag)
+			etag = strings.TrimSuffix(etag, "-df")
+			buf.WriteString(etag + "\n")
 		} else {
-			buf.WriteString(header.Get("Content-Type"))
-			buf.WriteString(header.Get("Content-Disposition"))
-			buf.WriteString(header.Get("Content-Length"))
-			buf.WriteString(header.Get("Last-Modified"))
+			buf.WriteString(header.Get("Content-Type") + "\n")
+			buf.WriteString(header.Get("Content-Disposition") + "\n")
+			buf.WriteString(header.Get("Content-Length") + "\n")
+			buf.WriteString(header.Get("Last-Modified") + "\n")
 		}
 	}
+	s := buf.String()
+	_ = s
+
+	fmt.Printf("\n\n%s\n\n", s)
+
 	return Checksum(Hash(buf.Bytes()))
 }
 
diff --git a/libsq/source/download.go b/libsq/source/download.go
index 52b3d3663..64d377f84 100644
--- a/libsq/source/download.go
+++ b/libsq/source/download.go
@@ -51,15 +51,16 @@ var OptHTTPSkipVerify = options.NewBool(
 	"Skip HTTPS TLS verification. Useful when downloading against self-signed certs.",
 )
 
-func newDownloader(c *http.Client, cacheDir, url string) *downloader {
+// newDownloader creates a new downloader using cacheDir for the given url. 
+func newDownloader(c *http.Client, cacheDir, dlURL string) *downloader { return &downloader{ c: c, cacheDir: cacheDir, - url: url, + url: dlURL, } } -// download is a helper for getting file contents from a URL, +// downloader is a helper for getting file contents from a URL, // and caching the file locally. The structure of cacheDir // is as follows: // @@ -74,15 +75,16 @@ func newDownloader(c *http.Client, cacheDir, url string) *downloader { // // - pid.lock is a lock file used to ensure that only one // process is downloading the file at a time. +// FIXME: are we using pid.lock, or will we share the parent cache lock? // // - header.txt is a dump of the HTTP response header, included for // debugging convenience. // // - checksum.txt contains a checksum:key pair, where the checksum is -// calculated using checksum.ForHTTPHeader, and the key is the path -// to the downloaded file, e.g. "dl/data.csv". +// calculated using checksum.ForHTTPResponse, and the key is the path +// to the downloaded file, e.g. "dl/actor.csv". // -// 67a47a0...a53e3e28154 dl/actor.csv +// 67a47a0 dl/actor.csv // // - The file is downloaded to dl/ instead of into the root // of cache dir, just to avoid the (remote) possibility of a name @@ -117,8 +119,10 @@ func (d *downloader) headerFile() string { return filepath.Join(d.cacheDir, "header.txt") } -// Download downloads the file at the URL to the download dir, and also writes -// the file to dest, and returns the file path of the downloaded file. +// Download downloads the file at the URL to the download dir, creating the +// checksum file on completion, and also writes the file to dest, and returns +// the file path of the downloaded file. +// // It is the caller's responsibility to close dest. If an appropriate file name // cannot be determined from the HTTP response, the file is named "download". 
// If the download fails at any stage, the download file is removed, but written @@ -152,6 +156,7 @@ func (d *downloader) Download(ctx context.Context, dest io.Writer) (written int6 if err != nil { return written, "", errz.Wrapf(err, "download new request failed for: %s", d.url) } + //setDefaultHTTPRequestHeaders(req) resp, err := d.c.Do(req) if err != nil { @@ -198,6 +203,13 @@ func (d *downloader) Download(ctx context.Context, dest io.Writer) (written int6 return written, "", errz.Wrapf(err, "failed to close download file: %s", fp) } + if resp.ContentLength == -1 { + // Sometimes the response won't have the content-length set, but we know + // it via the number of bytes read from the body. We explicitly set + // it here, because checksum.ForHTTPResponse uses it. + resp.ContentLength = written + } + sum := checksum.ForHTTPResponse(resp) if err = checksum.WriteFile(d.checksumFile(), sum, filepath.Join("dl", filename)); err != nil { lg.WarnIfFuncError(log, lgm.RemoveFile, func() error { return errz.Err(os.Remove(fp)) }) @@ -207,6 +219,11 @@ func (d *downloader) Download(ctx context.Context, dest io.Writer) (written int6 return written, fp, nil } +func setDefaultHTTPRequestHeaders(req *http.Request) { + req.Header.Set("User-Agent", "sq") // FIXME: this should be set on the http.Client + req.Header.Set("Accept-Encoding", "gzip") +} + func (d *downloader) writeHeaderFile(resp *http.Response) error { b, err := httputil.DumpResponse(resp, false) if err != nil { @@ -227,6 +244,23 @@ func (d *downloader) writeHeaderFile(resp *http.Response) error { return nil } +// ClearCache clears the cache dir. 
+func (d *downloader) ClearCache(ctx context.Context) error { + d.mu.Lock() + defer d.mu.Unlock() + + log := d.log(lg.FromContext(ctx)) + if err := os.RemoveAll(d.cacheDir); err != nil { + log.Error("Failed to clear cache dir", lga.Dir, d.cacheDir, lga.Err, err) + return errz.Wrapf(err, "failed to clear cache dir: %s", d.cacheDir) + } + + log.Info("Cleared cache dir", lga.Dir, d.cacheDir) + return nil +} + +// Cached returns true if the file is cached locally, and if so, also returns +// the checksum and file path of the cached file. func (d *downloader) Cached(ctx context.Context) (ok bool, sum checksum.Checksum, fp string) { d.mu.Lock() defer d.mu.Unlock() @@ -244,7 +278,7 @@ func (d *downloader) Cached(ctx context.Context) (ok bool, sum checksum.Checksum } if _, err = os.Stat(d.checksumFile()); err != nil { - log.Debug("not cached: can't stat download checksum file") + log.Debug("not cached: can't stat download checksum file", lga.File, d.checksumFile()) return false, "", "" } @@ -266,8 +300,7 @@ func (d *downloader) Cached(ctx context.Context) (ok bool, sum checksum.Checksum return false, "", "" } - downloadFile := filepath.Join(dlDir, key) - + downloadFile := filepath.Join(d.cacheDir, key) if _, err = os.Stat(downloadFile); err != nil { log.Debug("not cached: can't stat download file referenced in checksum file", lga.File, key) return false, "", "" @@ -277,15 +310,37 @@ func (d *downloader) Cached(ctx context.Context) (ok bool, sum checksum.Checksum return true, sum, downloadFile } -// fetchHTTPHeader fetches the HTTP header for u. First HEAD is used, and +// CachedIsCurrent returns true if the file is cached locally and if its +// stored checksum matches the checksum of the remote file. 
+func (d *downloader) CachedIsCurrent(ctx context.Context) (ok bool, err error) { + ok, sum, _ := d.Cached(ctx) + if !ok { + return false, errz.Errorf("not cached: %s", d.url) + } + + resp, err := fetchHTTPResponse(ctx, d.c, d.url) + if err != nil { + return false, errz.Wrap(err, "check remote header") + } + + remoteSum := checksum.ForHTTPResponse(resp) + if sum != remoteSum { + return false, nil + } + + return true, nil +} + +// fetchHTTPResponse fetches the HTTP response for u. First HEAD is tried, and // if that's not allowed (http.StatusMethodNotAllowed), then GET is used. -func fetchHTTPHeader(ctx context.Context, u string) (header http.Header, err error) { +func fetchHTTPResponse(ctx context.Context, c *http.Client, u string) (resp *http.Response, err error) { req, err := http.NewRequestWithContext(ctx, http.MethodHead, u, nil) if err != nil { return nil, errz.Err(err) } + setDefaultHTTPRequestHeaders(req) - resp, err := http.DefaultClient.Do(req) + resp, err = c.Do(req) if err != nil { return nil, errz.Err(err) } @@ -297,7 +352,7 @@ func fetchHTTPHeader(ctx context.Context, u string) (header http.Header, err err default: return nil, errz.Errorf("unexpected HTTP status (%s) for HEAD: %s", resp.Status, u) case http.StatusOK: - return resp.Header, nil + return resp, nil case http.StatusMethodNotAllowed: } @@ -309,6 +364,7 @@ func fetchHTTPHeader(ctx context.Context, u string) (header http.Header, err err if err != nil { return nil, errz.Err(err) } + //setDefaultHTTPRequestHeaders(req) resp, err = http.DefaultClient.Do(req) if err != nil { @@ -322,7 +378,7 @@ func fetchHTTPHeader(ctx context.Context, u string) (header http.Header, err err return nil, errz.Errorf("unexpected HTTP status (%s) for GET: %s", resp.Status, u) } - return resp.Header, nil + return resp, nil } func getRemoteChecksum(ctx context.Context, u string) (string, error) { diff --git a/libsq/source/download_test.go b/libsq/source/download_test.go index 93b3504d4..07116b0a9 100644 --- 
a/libsq/source/download_test.go +++ b/libsq/source/download_test.go @@ -3,8 +3,11 @@ package source import ( "bytes" "context" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" "net/http" "net/http/httptest" + "net/url" + "path" "path/filepath" "strconv" "testing" @@ -33,10 +36,11 @@ const ( urlPaymentLargeCSV = "https://sqio-public.s3.amazonaws.com/testdata/payment-large.gen.csv" urlActorCSV = "https://sq.io/testdata/actor.csv" sizeActorCSV = int64(7641) + sizeGzipActorCSV = int64(1968) ) func TestFetchHTTPHeader_sqio(t *testing.T) { - header, err := fetchHTTPHeader(context.Background(), urlActorCSV) + header, err := fetchHTTPResponse(context.Background(), urlActorCSV) require.NoError(t, err) require.NotNil(t, header) @@ -45,25 +49,49 @@ func TestFetchHTTPHeader_sqio(t *testing.T) { func TestDownloader_Download(t *testing.T) { ctx := lg.NewContext(context.Background(), slogt.New(t)) + const dlURL = urlActorCSV + const wantContentLength = sizeActorCSV + u, err := url.Parse(dlURL) + require.NoError(t, err) + wantFilename := path.Base(u.Path) + require.Equal(t, "actor.csv", wantFilename) cacheDir, err := filepath.Abs(filepath.Join("testdata", "downloader", "cache-dir-1")) require.NoError(t, err) t.Logf("cacheDir: %s", cacheDir) - dl := newDownloader(http.DefaultClient, cacheDir, urlActorCSV) + + dl := newDownloader(http.DefaultClient, cacheDir, dlURL) + require.NoError(t, dl.ClearCache(ctx)) + buf := &bytes.Buffer{} - written, fp, err := dl.Download(ctx, buf) + written, cachedFp, err := dl.Download(ctx, buf) require.NoError(t, err) - require.FileExists(t, fp) - require.Equal(t, sizeActorCSV, written) - require.Equal(t, sizeActorCSV, int64(buf.Len())) + require.FileExists(t, cachedFp) + require.Equal(t, wantContentLength, written) + require.Equal(t, wantContentLength, int64(buf.Len())) s := tu.ReadFileToString(t, dl.headerFile()) - t.Logf("header.txt\n\n" + s) + t.Logf("header.txt\n\n" + s + "\n") s = tu.ReadFileToString(t, dl.checksumFile()) - 
t.Logf("checksum.txt\n\n" + s) + t.Logf("checksum.txt\n\n" + s + "\n") - // TODO + gotSums, err := checksum.ReadFile(dl.checksumFile()) + require.NoError(t, err) + + isCached, cachedSum, cachedFp := dl.Cached(ctx) + require.True(t, isCached) + wantKey := filepath.Join("dl", wantFilename) + wantFp, err := filepath.Abs(filepath.Join(dl.cacheDir, wantKey)) + require.NoError(t, err) + require.Equal(t, wantFp, cachedFp) + fileSum, ok := gotSums[wantKey] + require.True(t, ok) + require.Equal(t, cachedSum, fileSum) + + isCurrent, err := dl.CachedIsCurrent(ctx) + require.NoError(t, err) + require.True(t, isCurrent) } func TestFetchHTTPHeader_HEAD_fallback_GET(t *testing.T) { @@ -83,7 +111,9 @@ func TestFetchHTTPHeader_HEAD_fallback_GET(t *testing.T) { u := srvr.URL - header, err := fetchHTTPHeader(context.Background(), u) + resp, err := fetchHTTPResponse(context.Background(), u) assert.NoError(t, err) - assert.NotNil(t, header) + assert.NotNil(t, resp) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, len(b), int(resp.ContentLength)) } From 5eede1d8b2696dbe3818b03c554fb96c368946c3 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Mon, 11 Dec 2023 12:42:23 -0700 Subject: [PATCH 083/195] Pausing work on downloader; will try different impl --- libsq/source/download_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libsq/source/download_test.go b/libsq/source/download_test.go index 07116b0a9..a9093478e 100644 --- a/libsq/source/download_test.go +++ b/libsq/source/download_test.go @@ -40,7 +40,7 @@ const ( ) func TestFetchHTTPHeader_sqio(t *testing.T) { - header, err := fetchHTTPResponse(context.Background(), urlActorCSV) + header, err := fetchHTTPResponse(context.Background(), http.DefaultClient, urlActorCSV) require.NoError(t, err) require.NotNil(t, header) @@ -111,7 +111,7 @@ func TestFetchHTTPHeader_HEAD_fallback_GET(t *testing.T) { u := srvr.URL - resp, err := fetchHTTPResponse(context.Background(), u) + resp, err := 
fetchHTTPResponse(context.Background(), http.DefaultClient, u) assert.NoError(t, err) assert.NotNil(t, resp) require.Equal(t, http.StatusOK, resp.StatusCode) From 86df48dd79ce1be9060e8dc12b17c6ccce4f904f Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 12 Dec 2023 04:46:43 -0700 Subject: [PATCH 084/195] Forked httpcache internally --- go.mod | 4 + go.sum | 8 + libsq/core/ioz/httpcache/LICENSE.txt | 7 + libsq/core/ioz/httpcache/README.md | 42 + .../core/ioz/httpcache/diskcache/diskcache.go | 63 + .../ioz/httpcache/diskcache/diskcache_test.go | 18 + libsq/core/ioz/httpcache/httpcache.go | 584 +++++++ libsq/core/ioz/httpcache/httpcache_test.go | 1475 +++++++++++++++++ libsq/core/ioz/httpcache/memcache/memcache.go | 62 + .../ioz/httpcache/memcache/memcache_test.go | 24 + libsq/core/ioz/httpcache/test/test.go | 36 + libsq/core/ioz/httpcache/test/test_test.go | 12 + 12 files changed, 2335 insertions(+) create mode 100644 libsq/core/ioz/httpcache/LICENSE.txt create mode 100644 libsq/core/ioz/httpcache/README.md create mode 100644 libsq/core/ioz/httpcache/diskcache/diskcache.go create mode 100644 libsq/core/ioz/httpcache/diskcache/diskcache_test.go create mode 100644 libsq/core/ioz/httpcache/httpcache.go create mode 100644 libsq/core/ioz/httpcache/httpcache_test.go create mode 100644 libsq/core/ioz/httpcache/memcache/memcache.go create mode 100644 libsq/core/ioz/httpcache/memcache/memcache_test.go create mode 100644 libsq/core/ioz/httpcache/test/test.go create mode 100644 libsq/core/ioz/httpcache/test/test_test.go diff --git a/go.mod b/go.mod index d69ab6dc7..789748bec 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( github.com/a8m/tree v0.0.0-20230208161321-36ae24ddad15 github.com/alessio/shellescape v1.4.2 github.com/antlr4-go/antlr/v4 v4.13.0 + github.com/bitcomplete/httpcache v0.0.0-20220528171057-1f4a71bbffc5 github.com/c2h5oh/datasize v0.0.0-20220606134207-859f65c6625b github.com/dustin/go-humanize v1.0.1 
github.com/ecnepsnai/osquery v1.0.1 @@ -40,6 +41,7 @@ require ( github.com/neilotoole/slogt v1.1.0 github.com/nightlyone/lockfile v1.0.0 github.com/otiai10/copy v1.14.0 + github.com/peterbourgon/diskv v2.0.1+incompatible github.com/ryboe/q v1.0.20 github.com/samber/lo v1.39.0 github.com/segmentio/encoding v0.3.7 @@ -66,10 +68,12 @@ require ( github.com/Masterminds/semver/v3 v3.2.0 // indirect github.com/VividCortex/ewma v1.2.0 // indirect github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d // indirect + github.com/bradfitz/gomemcache v0.0.0-20230905024940-24af94b03874 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/djherbis/atime v1.1.0 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect + github.com/google/btree v1.0.1 // indirect github.com/huandu/xstrings v1.3.3 // indirect github.com/imdario/mergo v0.3.11 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index 6866c715d..ec24d70b0 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,10 @@ github.com/alessio/shellescape v1.4.2 h1:MHPfaU+ddJ0/bYWpgIeUnQUqKrlJ1S7BfEYPM4u github.com/alessio/shellescape v1.4.2/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= +github.com/bitcomplete/httpcache v0.0.0-20220528171057-1f4a71bbffc5 h1:W0eIasTpyKf6lU9jf5WhOpA0GcmnusfoL68U4VjSiwE= +github.com/bitcomplete/httpcache v0.0.0-20220528171057-1f4a71bbffc5/go.mod h1:bV6DTY4iwX8E6H3G//Ug6G3GmbLoFteBgBgmM9HYZDw= +github.com/bradfitz/gomemcache v0.0.0-20230905024940-24af94b03874 h1:N7oVaKyGp8bttX0bfZGmcGkjz7DLQXhAn3DNd3T0ous= +github.com/bradfitz/gomemcache v0.0.0-20230905024940-24af94b03874/go.mod 
h1:r5xuitiExdLAJ09PR7vBVENGvp4ZuTBeWTGtxuX3K+c= github.com/c2h5oh/datasize v0.0.0-20220606134207-859f65c6625b h1:6+ZFm0flnudZzdSE0JxlhR2hKnGPcNB35BjQf4RYQDY= github.com/c2h5oh/datasize v0.0.0-20220606134207-859f65c6625b/go.mod h1:S/7n9copUssQ56c7aAgHqftWO4LTf4xY6CGWt8Bc+3M= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= @@ -63,6 +67,8 @@ github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0kt github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= +github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= +github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -137,6 +143,8 @@ github.com/otiai10/copy v1.14.0 h1:dCI/t1iTdYGtkvCuBG2BgR6KZa83PTclw4U5n2wAllU= github.com/otiai10/copy v1.14.0/go.mod h1:ECfuL02W+/FkTWZWgQqXPWZgW9oeKCSQ5qVfSc4qc4w= github.com/otiai10/mint v1.5.1 h1:XaPLeE+9vGbuyEHem1JNk3bYc7KKqyI/na0/mLd/Kks= github.com/otiai10/mint v1.5.1/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM= +github.com/peterbourgon/diskv v2.0.1+incompatible h1:UBdAOUP5p4RWqPBg048CAvpKN+vxiaj6gdUUzhl4XmI= +github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod 
h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= diff --git a/libsq/core/ioz/httpcache/LICENSE.txt b/libsq/core/ioz/httpcache/LICENSE.txt new file mode 100644 index 000000000..81316beb0 --- /dev/null +++ b/libsq/core/ioz/httpcache/LICENSE.txt @@ -0,0 +1,7 @@ +Copyright © 2012 Greg Jones (greg.jones@gmail.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
\ No newline at end of file diff --git a/libsq/core/ioz/httpcache/README.md b/libsq/core/ioz/httpcache/README.md new file mode 100644 index 000000000..58cda222f --- /dev/null +++ b/libsq/core/ioz/httpcache/README.md @@ -0,0 +1,42 @@ +# httpcache + +[![GoDoc](https://godoc.org/github.com/bitcomplete/httpcache?status.svg)](https://godoc.org/github.com/bitcomplete/httpcache) + +Package httpcache provides a http.RoundTripper implementation that works as a +mostly [RFC 7234](https://tools.ietf.org/html/rfc7234) compliant cache for http +responses. This incarnation of the library is an active fork of +[github.com/gregjones/httpcache](https://github.com/gregjones/httpcache) which +is unmaintained. + +It is only suitable for use as a 'private' cache (i.e. for a web-browser or an +API-client and not for a shared proxy). + +## Cache Backends + +- The built-in 'memory' cache stores responses in an in-memory map. - + [`github.com/bitcomplete/httpcache/diskcache`](https://github.com/bitcomplete/httpcache/tree/master/diskcache) + provides a filesystem-backed cache using the + [diskv](https://github.com/peterbourgon/diskv) library. - + [`github.com/bitcomplete/httpcache/memcache`](https://github.com/bitcomplete/httpcache/tree/master/memcache) + provides memcache implementations, for both App Engine and 'normal' memcache + servers. - + [`sourcegraph.com/sourcegraph/s3cache`](https://sourcegraph.com/github.com/sourcegraph/s3cache) + uses Amazon S3 for storage. - + [`github.com/bitcomplete/httpcache/leveldbcache`](https://github.com/bitcomplete/httpcache/tree/master/leveldbcache) + provides a filesystem-backed cache using + [leveldb](https://github.com/syndtr/goleveldb/leveldb). - + [`github.com/die-net/lrucache`](https://github.com/die-net/lrucache) provides an + in-memory cache that will evict least-recently used entries. 
- + [`github.com/die-net/lrucache/twotier`](https://github.com/die-net/lrucache/tree/master/twotier) + allows caches to be combined, for example to use lrucache above with a + persistent disk-cache. - + [`github.com/birkelund/boltdbcache`](https://github.com/birkelund/boltdbcache) + provides a BoltDB implementation (based on the + [bbolt](https://github.com/coreos/bbolt) fork). + +If you implement any other backend and wish it to be linked here, please send a +PR editing this file. + +## License + +- [MIT License](LICENSE.txt) diff --git a/libsq/core/ioz/httpcache/diskcache/diskcache.go b/libsq/core/ioz/httpcache/diskcache/diskcache.go new file mode 100644 index 000000000..4dd96e128 --- /dev/null +++ b/libsq/core/ioz/httpcache/diskcache/diskcache.go @@ -0,0 +1,63 @@ +// Package diskcache provides an implementation of httpcache.Cache that uses the diskv package +// to supplement an in-memory map with persistent storage +// +package diskcache + +import ( + "bytes" + "context" + "crypto/md5" + "encoding/hex" + "io" + + "github.com/peterbourgon/diskv" +) + +// Cache is an implementation of httpcache.Cache that supplements the in-memory map with persistent storage +type Cache struct { + d *diskv.Diskv +} + +// Get returns the response corresponding to key if present +func (c *Cache) Get(ctx context.Context, key string) (resp []byte, ok bool) { + key = keyToFilename(key) + resp, err := c.d.Read(key) + if err != nil { + return []byte{}, false + } + return resp, true +} + +// Set saves a response to the cache as key +func (c *Cache) Set(ctx context.Context, key string, resp []byte) { + key = keyToFilename(key) + _ = c.d.WriteStream(key, bytes.NewReader(resp), true) +} + +// Delete removes the response with key from the cache +func (c *Cache) Delete(ctx context.Context, key string) { + key = keyToFilename(key) + _ = c.d.Erase(key) +} + +func keyToFilename(key string) string { + h := md5.New() + _, _ = io.WriteString(h, key) + return 
hex.EncodeToString(h.Sum(nil)) +} + +// New returns a new Cache that will store files in basePath +func New(basePath string) *Cache { + return &Cache{ + d: diskv.New(diskv.Options{ + BasePath: basePath, + CacheSizeMax: 100 * 1024 * 1024, // 100MB + }), + } +} + +// NewWithDiskv returns a new Cache using the provided Diskv as underlying +// storage. +func NewWithDiskv(d *diskv.Diskv) *Cache { + return &Cache{d} +} diff --git a/libsq/core/ioz/httpcache/diskcache/diskcache_test.go b/libsq/core/ioz/httpcache/diskcache/diskcache_test.go new file mode 100644 index 000000000..3fe82273d --- /dev/null +++ b/libsq/core/ioz/httpcache/diskcache/diskcache_test.go @@ -0,0 +1,18 @@ +package diskcache + +import ( + "github.com/neilotoole/sq/libsq/core/ioz/httpcache/test" + "io/ioutil" + "os" + "testing" +) + +func TestDiskCache(t *testing.T) { + tempDir, err := ioutil.TempDir("", "httpcache") + if err != nil { + t.Fatalf("TempDir: %v", err) + } + defer os.RemoveAll(tempDir) + + test.Cache(t, New(tempDir)) +} diff --git a/libsq/core/ioz/httpcache/httpcache.go b/libsq/core/ioz/httpcache/httpcache.go new file mode 100644 index 000000000..51d283740 --- /dev/null +++ b/libsq/core/ioz/httpcache/httpcache.go @@ -0,0 +1,584 @@ +// Package httpcache provides a http.RoundTripper implementation that works as a +// mostly RFC-compliant cache for http responses. +// +// It is only suitable for use as a 'private' cache (i.e. for a web-browser or an API-client +// and not for a shared proxy). +// +package httpcache + +import ( + "bufio" + "bytes" + "context" + "errors" + "io" + "io/ioutil" + "net/http" + "net/http/httputil" + "strings" + "sync" + "time" +) + +const ( + stale = iota + fresh + transparent + // XFromCache is the header added to responses that are returned from the cache + XFromCache = "X-From-Cache" +) + +// A Cache interface is used by the Transport to store and retrieve responses. 
+type Cache interface { + // Get returns the []byte representation of a cached response and a bool + // set to true if the value isn't empty + Get(ctx context.Context, key string) (responseBytes []byte, ok bool) + // Set stores the []byte representation of a response against a key + Set(ctx context.Context, key string, responseBytes []byte) + // Delete removes the value associated with the key + Delete(ctx context.Context, key string) +} + +type KeyFunc func(req *http.Request) string + +// DefaultKeyFunc returns the cache key for req +var DefaultKeyFunc = func(req *http.Request) string { + if req.Method == http.MethodGet { + return req.URL.String() + } else { + return req.Method + " " + req.URL.String() + } +} + +// CachedResponse returns the cached http.Response for req if present, and nil +// otherwise. +func CachedResponse(ctx context.Context, c Cache, key string, req *http.Request) (resp *http.Response, err error) { + cachedVal, ok := c.Get(ctx, key) + if !ok { + return + } + + b := bytes.NewBuffer(cachedVal) + return http.ReadResponse(bufio.NewReader(b), req) +} + +// MemoryCache is an implementation of Cache that stores responses in an in-memory map. 
+type MemoryCache struct { + mu sync.RWMutex + items map[string][]byte +} + +// Get returns the []byte representation of the response and true if present, false if not +func (c *MemoryCache) Get(ctx context.Context, key string) (resp []byte, ok bool) { + c.mu.RLock() + resp, ok = c.items[key] + c.mu.RUnlock() + return resp, ok +} + +// Set saves response resp to the cache with key +func (c *MemoryCache) Set(ctx context.Context, key string, resp []byte) { + c.mu.Lock() + c.items[key] = resp + c.mu.Unlock() +} + +// Delete removes key from the cache +func (c *MemoryCache) Delete(ctx context.Context, key string) { + c.mu.Lock() + delete(c.items, key) + c.mu.Unlock() +} + +// NewMemoryCache returns a new Cache that will store items in an in-memory map +func NewMemoryCache() *MemoryCache { + c := &MemoryCache{items: map[string][]byte{}} + return c +} + +// TransportOpt is a configuration option for creating a new Transport +type TransportOpt func(t *Transport) + +// MarkCachedResponsesOpt configures a transport by setting MarkCachedResponses to true +func MarkCachedResponsesOpt(markCachedResponses bool) TransportOpt { + return func(t *Transport) { + t.MarkCachedResponses = markCachedResponses + } +} + +// KeyFuncOpt configures a transport by setting its KeyFunc to the one given +func KeyFuncOpt(keyFunc KeyFunc) TransportOpt { + return func(t *Transport) { + t.KeyFunc = keyFunc + } +} + +// Transport is an implementation of http.RoundTripper that will return values from a cache +// where possible (avoiding a network request) and will additionally add validators (etag/if-modified-since) +// to repeated requests allowing servers to return 304 / Not Modified +type Transport struct { + // The RoundTripper interface actually used to make requests + // If nil, http.DefaultTransport is used + Transport http.RoundTripper + Cache Cache + // If true, responses returned from the cache will be given an extra header, X-From-Cache + MarkCachedResponses bool + // A function to generate 
a cache key for the given request + KeyFunc KeyFunc +} + +// NewTransport returns a new Transport with the provided Cache and options. If +// KeyFunc is not specified in opts then DefaultKeyFunc is used. +func NewTransport(c Cache, opts ...TransportOpt) *Transport { + t := &Transport{ + Cache: c, + KeyFunc: DefaultKeyFunc, + MarkCachedResponses: true, + } + for _, opt := range opts { + opt(t) + } + return t +} + +// Client returns an *http.Client that caches responses. +func (t *Transport) Client() *http.Client { + return &http.Client{Transport: t} +} + +// varyMatches will return false unless all of the cached values for the headers listed in Vary +// match the new request +func varyMatches(cachedResp *http.Response, req *http.Request) bool { + for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") { + header = http.CanonicalHeaderKey(header) + if header != "" && req.Header.Get(header) != cachedResp.Header.Get("X-Varied-"+header) { + return false + } + } + return true +} + +// RoundTrip takes a Request and returns a Response +// +// If there is a fresh Response already in cache, then it will be returned without connecting to +// the server. +// +// If there is a stale Response, then any validators it contains will be set on the new request +// to give the server a chance to respond with NotModified. If this happens, then the cached Response +// will be returned. 
+func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { + cacheKey := t.KeyFunc(req) + cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" + var cachedResp *http.Response + if cacheable { + cachedResp, err = CachedResponse(req.Context(), t.Cache, cacheKey, req) + } else { + // Need to invalidate an existing value + t.Cache.Delete(req.Context(), cacheKey) + } + + transport := t.Transport + if transport == nil { + transport = http.DefaultTransport + } + + if cacheable && cachedResp != nil && err == nil { + if t.MarkCachedResponses { + cachedResp.Header.Set(XFromCache, "1") + } + + if varyMatches(cachedResp, req) { + // Can only use cached value if the new request doesn't Vary significantly + freshness := getFreshness(cachedResp.Header, req.Header) + if freshness == fresh { + return cachedResp, nil + } + + if freshness == stale { + var req2 *http.Request + // Add validators if caller hasn't already done so + etag := cachedResp.Header.Get("etag") + if etag != "" && req.Header.Get("etag") == "" { + req2 = cloneRequest(req) + req2.Header.Set("if-none-match", etag) + } + lastModified := cachedResp.Header.Get("last-modified") + if lastModified != "" && req.Header.Get("last-modified") == "" { + if req2 == nil { + req2 = cloneRequest(req) + } + req2.Header.Set("if-modified-since", lastModified) + } + if req2 != nil { + req = req2 + } + } + } + + resp, err = transport.RoundTrip(req) + if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified { + // Replace the 304 response with the one from cache, but update with some new headers + endToEndHeaders := getEndToEndHeaders(resp.Header) + for _, header := range endToEndHeaders { + cachedResp.Header[header] = resp.Header[header] + } + resp = cachedResp + } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) && + req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) { + // In case of transport failure and 
stale-if-error activated, returns cached content + // when available + return cachedResp, nil + } else { + if err != nil || resp.StatusCode != http.StatusOK { + t.Cache.Delete(req.Context(), cacheKey) + } + if err != nil { + return nil, err + } + } + } else { + reqCacheControl := parseCacheControl(req.Header) + if _, ok := reqCacheControl["only-if-cached"]; ok { + resp = newGatewayTimeoutResponse(req) + } else { + resp, err = transport.RoundTrip(req) + if err != nil { + return nil, err + } + } + } + + if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) { + for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") { + varyKey = http.CanonicalHeaderKey(varyKey) + fakeHeader := "X-Varied-" + varyKey + reqValue := req.Header.Get(varyKey) + if reqValue != "" { + resp.Header.Set(fakeHeader, reqValue) + } + } + switch req.Method { + case "GET": + // Delay caching until EOF is reached. + resp.Body = &cachingReadCloser{ + R: resp.Body, + OnEOF: func(r io.Reader) { + resp := *resp + resp.Body = ioutil.NopCloser(r) + respBytes, err := httputil.DumpResponse(&resp, true) + if err == nil { + t.Cache.Set(req.Context(), cacheKey, respBytes) + } + }, + } + default: + respBytes, err := httputil.DumpResponse(resp, true) + if err == nil { + t.Cache.Set(req.Context(), cacheKey, respBytes) + } + } + } else { + t.Cache.Delete(req.Context(), cacheKey) + } + return resp, nil +} + +// ErrNoDateHeader indicates that the HTTP headers contained no Date header. +var ErrNoDateHeader = errors.New("no Date header") + +// Date parses and returns the value of the Date header. 
+func Date(respHeaders http.Header) (date time.Time, err error) { + dateHeader := respHeaders.Get("date") + if dateHeader == "" { + err = ErrNoDateHeader + return + } + + return time.Parse(time.RFC1123, dateHeader) +} + +type realClock struct{} + +func (c *realClock) since(d time.Time) time.Duration { + return time.Since(d) +} + +type timer interface { + since(d time.Time) time.Duration +} + +var clock timer = &realClock{} + +// getFreshness will return one of fresh/stale/transparent based on the cache-control +// values of the request and the response +// +// fresh indicates the response can be returned +// stale indicates that the response needs validating before it is returned +// transparent indicates the response should not be used to fulfil the request +// +// Because this is only a private cache, 'public' and 'private' in cache-control aren't +// significant. Similarly, smax-age isn't used. +func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) { + respCacheControl := parseCacheControl(respHeaders) + reqCacheControl := parseCacheControl(reqHeaders) + if _, ok := reqCacheControl["no-cache"]; ok { + return transparent + } + if _, ok := respCacheControl["no-cache"]; ok { + return stale + } + if _, ok := reqCacheControl["only-if-cached"]; ok { + return fresh + } + + date, err := Date(respHeaders) + if err != nil { + return stale + } + currentAge := clock.since(date) + + var lifetime time.Duration + var zeroDuration time.Duration + + // If a response includes both an Expires header and a max-age directive, + // the max-age directive overrides the Expires header, even if the Expires header is more restrictive. 
+ if maxAge, ok := respCacheControl["max-age"]; ok { + lifetime, err = time.ParseDuration(maxAge + "s") + if err != nil { + lifetime = zeroDuration + } + } else { + expiresHeader := respHeaders.Get("Expires") + if expiresHeader != "" { + expires, err := time.Parse(time.RFC1123, expiresHeader) + if err != nil { + lifetime = zeroDuration + } else { + lifetime = expires.Sub(date) + } + } + } + + if maxAge, ok := reqCacheControl["max-age"]; ok { + // the client is willing to accept a response whose age is no greater than the specified time in seconds + lifetime, err = time.ParseDuration(maxAge + "s") + if err != nil { + lifetime = zeroDuration + } + } + if minfresh, ok := reqCacheControl["min-fresh"]; ok { + // the client wants a response that will still be fresh for at least the specified number of seconds. + minfreshDuration, err := time.ParseDuration(minfresh + "s") + if err == nil { + currentAge = time.Duration(currentAge + minfreshDuration) + } + } + + if maxstale, ok := reqCacheControl["max-stale"]; ok { + // Indicates that the client is willing to accept a response that has exceeded its expiration time. + // If max-stale is assigned a value, then the client is willing to accept a response that has exceeded + // its expiration time by no more than the specified number of seconds. + // If no value is assigned to max-stale, then the client is willing to accept a stale response of any age. + // + // Responses served only because of a max-stale value are supposed to have a Warning header added to them, + // but that seems like a hassle, and is it actually useful? If so, then there needs to be a different + // return-value available here. 
+ if maxstale == "" { + return fresh + } + maxstaleDuration, err := time.ParseDuration(maxstale + "s") + if err == nil { + currentAge = time.Duration(currentAge - maxstaleDuration) + } + } + + if lifetime > currentAge { + return fresh + } + + return stale +} + +// Returns true if either the request or the response includes the stale-if-error +// cache control extension: https://tools.ietf.org/html/rfc5861 +func canStaleOnError(respHeaders, reqHeaders http.Header) bool { + respCacheControl := parseCacheControl(respHeaders) + reqCacheControl := parseCacheControl(reqHeaders) + + var err error + lifetime := time.Duration(-1) + + if staleMaxAge, ok := respCacheControl["stale-if-error"]; ok { + if staleMaxAge != "" { + lifetime, err = time.ParseDuration(staleMaxAge + "s") + if err != nil { + return false + } + } else { + return true + } + } + if staleMaxAge, ok := reqCacheControl["stale-if-error"]; ok { + if staleMaxAge != "" { + lifetime, err = time.ParseDuration(staleMaxAge + "s") + if err != nil { + return false + } + } else { + return true + } + } + + if lifetime >= 0 { + date, err := Date(respHeaders) + if err != nil { + return false + } + currentAge := clock.since(date) + if lifetime > currentAge { + return true + } + } + + return false +} + +func getEndToEndHeaders(respHeaders http.Header) []string { + // These headers are always hop-by-hop + hopByHopHeaders := map[string]struct{}{ + "Connection": {}, + "Keep-Alive": {}, + "Proxy-Authenticate": {}, + "Proxy-Authorization": {}, + "Te": {}, + "Trailers": {}, + "Transfer-Encoding": {}, + "Upgrade": {}, + } + + for _, extra := range strings.Split(respHeaders.Get("connection"), ",") { + // any header listed in connection, if present, is also considered hop-by-hop + if strings.Trim(extra, " ") != "" { + hopByHopHeaders[http.CanonicalHeaderKey(extra)] = struct{}{} + } + } + endToEndHeaders := []string{} + for respHeader := range respHeaders { + if _, ok := hopByHopHeaders[respHeader]; !ok { + endToEndHeaders = 
append(endToEndHeaders, respHeader) + } + } + return endToEndHeaders +} + +func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) { + if _, ok := respCacheControl["no-store"]; ok { + return false + } + if _, ok := reqCacheControl["no-store"]; ok { + return false + } + return true +} + +func newGatewayTimeoutResponse(req *http.Request) *http.Response { + var braw bytes.Buffer + braw.WriteString("HTTP/1.1 504 Gateway Timeout\r\n\r\n") + resp, err := http.ReadResponse(bufio.NewReader(&braw), req) + if err != nil { + panic(err) + } + return resp +} + +// cloneRequest returns a clone of the provided *http.Request. +// The clone is a shallow copy of the struct and its Header map. +// (This function copyright goauth2 authors: https://code.google.com/p/goauth2) +func cloneRequest(r *http.Request) *http.Request { + // shallow copy of the struct + r2 := new(http.Request) + *r2 = *r + if ctx := r.Context(); ctx != nil { + r2 = r2.WithContext(ctx) + } + // deep copy of the Header + r2.Header = make(http.Header) + for k, s := range r.Header { + r2.Header[k] = s + } + return r2 +} + +type cacheControl map[string]string + +func parseCacheControl(headers http.Header) cacheControl { + cc := cacheControl{} + ccHeader := headers.Get("Cache-Control") + for _, part := range strings.Split(ccHeader, ",") { + part = strings.Trim(part, " ") + if part == "" { + continue + } + if strings.ContainsRune(part, '=') { + keyval := strings.Split(part, "=") + cc[strings.Trim(keyval[0], " ")] = strings.Trim(keyval[1], ",") + } else { + cc[part] = "" + } + } + return cc +} + +// headerAllCommaSepValues returns all comma-separated values (each +// with whitespace trimmed) for header name in headers. According to +// Section 4.2 of the HTTP/1.1 spec +// (http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2), +// values from multiple occurrences of a header should be concatenated, if +// the header's value is a comma-separated list. 
+func headerAllCommaSepValues(headers http.Header, name string) []string { + var vals []string + for _, val := range headers[http.CanonicalHeaderKey(name)] { + fields := strings.Split(val, ",") + for i, f := range fields { + fields[i] = strings.TrimSpace(f) + } + vals = append(vals, fields...) + } + return vals +} + +// cachingReadCloser is a wrapper around ReadCloser R that calls OnEOF +// handler with a full copy of the content read from R when EOF is +// reached. +type cachingReadCloser struct { + // Underlying ReadCloser. + R io.ReadCloser + // OnEOF is called with a copy of the content of R when EOF is reached. + OnEOF func(io.Reader) + + buf bytes.Buffer // buf stores a copy of the content of R. +} + +// Read reads the next len(p) bytes from R or until R is drained. The +// return value n is the number of bytes read. If R has no data to +// return, err is io.EOF and OnEOF is called with a full copy of what +// has been read so far. +func (r *cachingReadCloser) Read(p []byte) (n int, err error) { + n, err = r.R.Read(p) + r.buf.Write(p[:n]) + if err == io.EOF { + r.OnEOF(bytes.NewReader(r.buf.Bytes())) + } + return n, err +} + +func (r *cachingReadCloser) Close() error { + return r.R.Close() +} + +// NewMemoryCacheTransport returns a new Transport using the in-memory cache implementation +func NewMemoryCacheTransport(opts ...TransportOpt) *Transport { + c := NewMemoryCache() + t := NewTransport(c, opts...) + return t +} diff --git a/libsq/core/ioz/httpcache/httpcache_test.go b/libsq/core/ioz/httpcache/httpcache_test.go new file mode 100644 index 000000000..9fb3b8f47 --- /dev/null +++ b/libsq/core/ioz/httpcache/httpcache_test.go @@ -0,0 +1,1475 @@ +package httpcache + +import ( + "bytes" + "errors" + "flag" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + "time" +) + +var s struct { + server *httptest.Server + client http.Client + transport *Transport + done chan struct{} // Closed to unlock infinite handlers. 
+} + +type fakeClock struct { + elapsed time.Duration +} + +func (c *fakeClock) since(t time.Time) time.Duration { + return c.elapsed +} + +func TestMain(m *testing.M) { + flag.Parse() + setup() + code := m.Run() + teardown() + os.Exit(code) +} + +func setup() { + tp := NewMemoryCacheTransport() + client := http.Client{Transport: tp} + s.transport = tp + s.client = client + s.done = make(chan struct{}) + + mux := http.NewServeMux() + s.server = httptest.NewServer(mux) + + mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + })) + + mux.HandleFunc("/method", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + _, _ = w.Write([]byte(r.Method)) + })) + + mux.HandleFunc("/range", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + lm := "Fri, 14 Dec 2010 01:01:50 GMT" + if r.Header.Get("if-modified-since") == lm { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("last-modified", lm) + if r.Header.Get("range") == "bytes=4-9" { + w.WriteHeader(http.StatusPartialContent) + _, _ = w.Write([]byte(" text ")) + return + } + _, _ = w.Write([]byte("Some text content")) + })) + + mux.HandleFunc("/nostore", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "no-store") + })) + + mux.HandleFunc("/etag", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + etag := "124567" + if r.Header.Get("if-none-match") == etag { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("etag", etag) + })) + + mux.HandleFunc("/lastmodified", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + lm := "Fri, 14 Dec 2010 01:01:50 GMT" + if r.Header.Get("if-modified-since") == lm { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("last-modified", lm) + })) + + mux.HandleFunc("/varyaccept", http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("Vary", "Accept") + _, _ = w.Write([]byte("Some text content")) + })) + + mux.HandleFunc("/doublevary", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("Vary", "Accept, Accept-Language") + _, _ = w.Write([]byte("Some text content")) + })) + mux.HandleFunc("/2varyheaders", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "text/plain") + w.Header().Add("Vary", "Accept") + w.Header().Add("Vary", "Accept-Language") + _, _ = w.Write([]byte("Some text content")) + })) + mux.HandleFunc("/varyunused", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("Vary", "X-Madeup-Header") + _, _ = w.Write([]byte("Some text content")) + })) + + mux.HandleFunc("/cachederror", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + etag := "abc" + if r.Header.Get("if-none-match") == etag { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("etag", etag) + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("Not found")) + })) + + updateFieldsCounter := 0 + mux.HandleFunc("/updatefields", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Counter", strconv.Itoa(updateFieldsCounter)) + w.Header().Set("Etag", `"e"`) + updateFieldsCounter++ + if r.Header.Get("if-none-match") != "" { + w.WriteHeader(http.StatusNotModified) + return + } + _, _ = w.Write([]byte("Some text content")) + })) + + // Take 3 seconds to return 200 OK (for testing client timeouts). 
+ mux.HandleFunc("/3seconds", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(3 * time.Second) + })) + + mux.HandleFunc("/infinite", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for { + select { + case <-s.done: + return + default: + _, _ = w.Write([]byte{0}) + } + } + })) +} + +func teardown() { + close(s.done) + s.server.Close() +} + +func resetTest() { + s.transport.Cache = NewMemoryCache() + clock = &realClock{} +} + +// TestCacheableMethod ensures that uncacheable method does not get stored +// in cache and get incorrectly used for a following cacheable method request. +func TestCacheableMethod(t *testing.T) { + resetTest() + { + req, err := http.NewRequest("POST", s.server.URL+"/method", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), "POST"; got != want { + t.Errorf("got %q, want %q", got, want) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) + } + } + { + req, err := http.NewRequest("GET", s.server.URL+"/method", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), "GET"; got != want { + t.Errorf("got wrong body %q, want %q", got, want) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) + } + if resp.Header.Get(XFromCache) != "" { + t.Errorf("XFromCache header isn't blank") + } + } +} + +func TestDontServeHeadResponseToGetRequest(t *testing.T) { + resetTest() + url := s.server.URL + "/" + 
req, err := http.NewRequest(http.MethodHead, url, nil) + if err != nil { + t.Fatal(err) + } + _, err = s.client.Do(req) + if err != nil { + t.Fatal(err) + } + req, err = http.NewRequest(http.MethodGet, url, nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.Header.Get(XFromCache) != "" { + t.Errorf("Cache should not match") + } +} + +func TestDontStorePartialRangeInCache(t *testing.T) { + resetTest() + { + req, err := http.NewRequest("GET", s.server.URL+"/range", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("range", "bytes=4-9") + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), " text "; got != want { + t.Errorf("got %q, want %q", got, want) + } + if resp.StatusCode != http.StatusPartialContent { + t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) + } + } + { + req, err := http.NewRequest("GET", s.server.URL+"/range", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), "Some text content"; got != want { + t.Errorf("got %q, want %q", got, want) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) + } + if resp.Header.Get(XFromCache) != "" { + t.Error("XFromCache header isn't blank") + } + } + { + req, err := http.NewRequest("GET", s.server.URL+"/range", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + 
t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), "Some text content"; got != want { + t.Errorf("got %q, want %q", got, want) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) + } + if resp.Header.Get(XFromCache) != "1" { + t.Errorf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } + { + req, err := http.NewRequest("GET", s.server.URL+"/range", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("range", "bytes=4-9") + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), " text "; got != want { + t.Errorf("got %q, want %q", got, want) + } + if resp.StatusCode != http.StatusPartialContent { + t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) + } + } +} + +func TestCacheOnlyIfBodyRead(t *testing.T) { + resetTest() + { + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + // We do not read the body + resp.Body.Close() + } + { + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatalf("XFromCache header isn't blank") + } + } +} + +func TestOnlyReadBodyOnDemand(t *testing.T) { + resetTest() + + req, err := http.NewRequest("GET", s.server.URL+"/infinite", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) // This shouldn't hang forever. 
+ if err != nil { + t.Fatal(err) + } + buf := make([]byte, 10) // Only partially read the body. + _, err = resp.Body.Read(buf) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() +} + +func TestGetOnlyIfCachedHit(t *testing.T) { + resetTest() + { + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("cache-control", "only-if-cached") + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) + } + } +} + +func TestGetOnlyIfCachedMiss(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("cache-control", "only-if-cached") + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + if resp.StatusCode != http.StatusGatewayTimeout { + t.Fatalf("response status code isn't 504 GatewayTimeout: %v", resp.StatusCode) + } +} + +func TestGetNoStoreRequest(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("Cache-Control", "no-store") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't 
blank") + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } +} + +func TestGetNoStoreResponse(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/nostore", nil) + if err != nil { + t.Fatal(err) + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } +} + +func TestGetWithEtag(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/etag", nil) + if err != nil { + t.Fatal(err) + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + // additional assertions to verify that 304 response is converted properly + if resp.StatusCode != http.StatusOK { + t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) + } + if _, ok := resp.Header["Connection"]; ok { + t.Fatalf("Connection header isn't absent") + } + } +} + +func TestGetWithLastModified(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/lastmodified", nil) + if err != nil { + t.Fatal(err) + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != 
"" { + t.Fatal("XFromCache header isn't blank") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } +} + +func TestGetWithVary(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/varyaccept", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", "text/plain") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get("Vary") != "Accept" { + t.Fatalf(`Vary header isn't "Accept": %v`, resp.Header.Get("Vary")) + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } + req.Header.Set("Accept", "text/html") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + req.Header.Set("Accept", "") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } +} + +func TestGetWithDoubleVary(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/doublevary", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", "text/plain") + req.Header.Set("Accept-Language", "da, en-gb;q=0.8, en;q=0.7") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get("Vary") == "" { + t.Fatalf(`Vary header is blank`) + } + _, err = 
ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } + req.Header.Set("Accept-Language", "") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + req.Header.Set("Accept-Language", "da") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } +} + +func TestGetWith2VaryHeaders(t *testing.T) { + resetTest() + // Tests that multiple Vary headers' comma-separated lists are + // merged. See https://github.com/gregjones/httpcache/issues/27. + const ( + accept = "text/plain" + acceptLanguage = "da, en-gb;q=0.8, en;q=0.7" + ) + req, err := http.NewRequest("GET", s.server.URL+"/2varyheaders", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", accept) + req.Header.Set("Accept-Language", acceptLanguage) + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get("Vary") == "" { + t.Fatalf(`Vary header is blank`) + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } + req.Header.Set("Accept-Language", "") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + req.Header.Set("Accept-Language", "da") + { + resp, 
err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + req.Header.Set("Accept-Language", acceptLanguage) + req.Header.Set("Accept", "") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + req.Header.Set("Accept", "image/png") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } +} + +func TestGetVaryUnused(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/varyunused", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", "text/plain") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get("Vary") == "" { + t.Fatalf(`Vary header is blank`) + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } +} + +func TestUpdateFields(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/updatefields", nil) + if err != nil { + t.Fatal(err) + } + var counter, counter2 string + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + counter = 
resp.Header.Get("x-counter") + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + counter2 = resp.Header.Get("x-counter") + } + if counter == counter2 { + t.Fatalf(`both "x-counter" values are equal: %v %v`, counter, counter2) + } +} + +// This tests the fix for https://github.com/gregjones/httpcache/issues/74. +// Previously, after validating a cached response, its StatusCode +// was incorrectly being replaced. +func TestCachedErrorsKeepStatus(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/cachederror", nil) + if err != nil { + t.Fatal(err) + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + _, _ = io.Copy(ioutil.Discard, resp.Body) + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("Status code isn't 404: %d", resp.StatusCode) + } + } +} + +func TestParseCacheControl(t *testing.T) { + resetTest() + h := http.Header{} + for range parseCacheControl(h) { + t.Fatal("cacheControl should be empty") + } + + h.Set("cache-control", "no-cache") + { + cc := parseCacheControl(h) + if _, ok := cc["foo"]; ok { + t.Error(`Value "foo" shouldn't exist`) + } + noCache, ok := cc["no-cache"] + if !ok { + t.Fatalf(`"no-cache" value isn't set`) + } + if noCache != "" { + t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) + } + } + h.Set("cache-control", "no-cache, max-age=3600") + { + cc := parseCacheControl(h) + noCache, ok := cc["no-cache"] + if !ok { + t.Fatalf(`"no-cache" value isn't set`) + } + if noCache != "" { + t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) + } + if cc["max-age"] != "3600" { + t.Fatalf(`"max-age" value isn't 
"3600": %v`, cc["max-age"]) + } + } +} + +func TestNoCacheRequestExpiration(t *testing.T) { + resetTest() + respHeaders := http.Header{} + respHeaders.Set("Cache-Control", "max-age=7200") + + reqHeaders := http.Header{} + reqHeaders.Set("Cache-Control", "no-cache") + if getFreshness(respHeaders, reqHeaders) != transparent { + t.Fatal("freshness isn't transparent") + } +} + +func TestNoCacheResponseExpiration(t *testing.T) { + resetTest() + respHeaders := http.Header{} + respHeaders.Set("Cache-Control", "no-cache") + respHeaders.Set("Expires", "Wed, 19 Apr 3000 11:43:00 GMT") + + reqHeaders := http.Header{} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestReqMustRevalidate(t *testing.T) { + resetTest() + // not paying attention to request setting max-stale means never returning stale + // responses, so always acting as if must-revalidate is set + respHeaders := http.Header{} + + reqHeaders := http.Header{} + reqHeaders.Set("Cache-Control", "must-revalidate") + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestRespMustRevalidate(t *testing.T) { + resetTest() + respHeaders := http.Header{} + respHeaders.Set("Cache-Control", "must-revalidate") + + reqHeaders := http.Header{} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestFreshExpiration(t *testing.T) { + resetTest() + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) + + reqHeaders := http.Header{} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + clock = &fakeClock{elapsed: 3 * time.Second} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestMaxAge(t *testing.T) { + resetTest() + now := time.Now() + 
respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("cache-control", "max-age=2") + + reqHeaders := http.Header{} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + clock = &fakeClock{elapsed: 3 * time.Second} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestMaxAgeZero(t *testing.T) { + resetTest() + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("cache-control", "max-age=0") + + reqHeaders := http.Header{} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestBothMaxAge(t *testing.T) { + resetTest() + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("cache-control", "max-age=2") + + reqHeaders := http.Header{} + reqHeaders.Set("cache-control", "max-age=0") + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestMinFreshWithExpires(t *testing.T) { + resetTest() + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) + + reqHeaders := http.Header{} + reqHeaders.Set("cache-control", "min-fresh=1") + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + reqHeaders = http.Header{} + reqHeaders.Set("cache-control", "min-fresh=2") + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestEmptyMaxStale(t *testing.T) { + resetTest() + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("cache-control", "max-age=20") + + reqHeaders := http.Header{} + reqHeaders.Set("cache-control", 
"max-stale") + clock = &fakeClock{elapsed: 10 * time.Second} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + clock = &fakeClock{elapsed: 60 * time.Second} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } +} + +func TestMaxStaleValue(t *testing.T) { + resetTest() + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("cache-control", "max-age=10") + + reqHeaders := http.Header{} + reqHeaders.Set("cache-control", "max-stale=20") + clock = &fakeClock{elapsed: 5 * time.Second} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + clock = &fakeClock{elapsed: 15 * time.Second} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + clock = &fakeClock{elapsed: 30 * time.Second} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func containsHeader(headers []string, header string) bool { + for _, v := range headers { + if http.CanonicalHeaderKey(v) == http.CanonicalHeaderKey(header) { + return true + } + } + return false +} + +func TestGetEndToEndHeaders(t *testing.T) { + resetTest() + var ( + headers http.Header + end2end []string + ) + + headers = http.Header{} + headers.Set("content-type", "text/html") + headers.Set("te", "deflate") + + end2end = getEndToEndHeaders(headers) + if !containsHeader(end2end, "content-type") { + t.Fatal(`doesn't contain "content-type" header`) + } + if containsHeader(end2end, "te") { + t.Fatal(`doesn't contain "te" header`) + } + + headers = http.Header{} + headers.Set("connection", "content-type") + headers.Set("content-type", "text/csv") + headers.Set("te", "deflate") + end2end = getEndToEndHeaders(headers) + if containsHeader(end2end, "connection") { + t.Fatal(`doesn't contain "connection" header`) + } + if containsHeader(end2end, "content-type") { + 
t.Fatal(`doesn't contain "content-type" header`) + } + if containsHeader(end2end, "te") { + t.Fatal(`doesn't contain "te" header`) + } + + headers = http.Header{} + end2end = getEndToEndHeaders(headers) + if len(end2end) != 0 { + t.Fatal(`non-zero end2end headers`) + } + + headers = http.Header{} + headers.Set("connection", "content-type") + end2end = getEndToEndHeaders(headers) + if len(end2end) != 0 { + t.Fatal(`non-zero end2end headers`) + } +} + +type transportMock struct { + response *http.Response + err error +} + +func (t transportMock) RoundTrip(req *http.Request) (resp *http.Response, err error) { + return t.response, t.err +} + +func TestStaleIfErrorRequest(t *testing.T) { + resetTest() + now := time.Now() + tmock := transportMock{ + response: &http.Response{ + Status: http.StatusText(http.StatusOK), + StatusCode: http.StatusOK, + Header: http.Header{ + "Date": []string{now.Format(time.RFC1123)}, + "Cache-Control": []string{"no-cache"}, + }, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), + }, + err: nil, + } + tp := NewMemoryCacheTransport() + tp.Transport = &tmock + + // First time, response is cached on success + r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) + r.Header.Set("Cache-Control", "stale-if-error") + resp, err := tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + // On failure, response is returned from the cache + tmock.response = nil + tmock.err = errors.New("some error") + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } +} + +func TestStaleIfErrorRequestLifetime(t *testing.T) { + resetTest() + now := time.Now() + tmock := transportMock{ + response: &http.Response{ + Status: http.StatusText(http.StatusOK), + StatusCode: http.StatusOK, + Header: http.Header{ + "Date": []string{now.Format(time.RFC1123)}, + "Cache-Control": 
[]string{"no-cache"}, + }, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), + }, + err: nil, + } + tp := NewMemoryCacheTransport() + tp.Transport = &tmock + + // First time, response is cached on success + r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) + r.Header.Set("Cache-Control", "stale-if-error=100") + resp, err := tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + // On failure, response is returned from the cache + tmock.response = nil + tmock.err = errors.New("some error") + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + + // Same for http errors + tmock.response = &http.Response{StatusCode: http.StatusInternalServerError} + tmock.err = nil + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + + // If failure last more than max stale, error is returned + clock = &fakeClock{elapsed: 200 * time.Second} + _, err = tp.RoundTrip(r) + if err != tmock.err { + t.Fatalf("got err %v, want %v", err, tmock.err) + } +} + +func TestStaleIfErrorResponse(t *testing.T) { + resetTest() + now := time.Now() + tmock := transportMock{ + response: &http.Response{ + Status: http.StatusText(http.StatusOK), + StatusCode: http.StatusOK, + Header: http.Header{ + "Date": []string{now.Format(time.RFC1123)}, + "Cache-Control": []string{"no-cache, stale-if-error"}, + }, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), + }, + err: nil, + } + tp := NewMemoryCacheTransport() + tp.Transport = &tmock + + // First time, response is cached on success + r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) + resp, err := tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } 
+ + // On failure, response is returned from the cache + tmock.response = nil + tmock.err = errors.New("some error") + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } +} + +func TestStaleIfErrorResponseLifetime(t *testing.T) { + resetTest() + now := time.Now() + tmock := transportMock{ + response: &http.Response{ + Status: http.StatusText(http.StatusOK), + StatusCode: http.StatusOK, + Header: http.Header{ + "Date": []string{now.Format(time.RFC1123)}, + "Cache-Control": []string{"no-cache, stale-if-error=100"}, + }, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), + }, + err: nil, + } + tp := NewMemoryCacheTransport() + tp.Transport = &tmock + + // First time, response is cached on success + r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) + resp, err := tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + // On failure, response is returned from the cache + tmock.response = nil + tmock.err = errors.New("some error") + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + + // If failure last more than max stale, error is returned + clock = &fakeClock{elapsed: 200 * time.Second} + _, err = tp.RoundTrip(r) + if err != tmock.err { + t.Fatalf("got err %v, want %v", err, tmock.err) + } +} + +// This tests the fix for https://github.com/gregjones/httpcache/issues/74. +// Previously, after a stale response was used after encountering an error, +// its StatusCode was being incorrectly replaced. 
+func TestStaleIfErrorKeepsStatus(t *testing.T) { + resetTest() + now := time.Now() + tmock := transportMock{ + response: &http.Response{ + Status: http.StatusText(http.StatusNotFound), + StatusCode: http.StatusNotFound, + Header: http.Header{ + "Date": []string{now.Format(time.RFC1123)}, + "Cache-Control": []string{"no-cache"}, + }, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), + }, + err: nil, + } + tp := NewMemoryCacheTransport() + tp.Transport = &tmock + + // First time, response is cached on success + r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) + r.Header.Set("Cache-Control", "stale-if-error") + resp, err := tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + // On failure, response is returned from the cache + tmock.response = nil + tmock.err = errors.New("some error") + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("Status wasn't 404: %d", resp.StatusCode) + } +} + +// Test that http.Client.Timeout is respected when cache transport is used. +// That is so as long as request cancellation is propagated correctly. +// In the past, that required CancelRequest to be implemented correctly, +// but modern http.Client uses Request.Cancel (or request context) instead, +// so we don't have to do anything. +func TestClientTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping timeout test in short mode") // Because it takes at least 3 seconds to run. 
+ } + resetTest() + client := &http.Client{ + Transport: NewMemoryCacheTransport(), + Timeout: time.Second, + } + started := time.Now() + resp, err := client.Get(s.server.URL + "/3seconds") + taken := time.Since(started) + if err == nil { + t.Error("got nil error, want timeout error") + } + if resp != nil { + t.Error("got non-nil resp, want nil resp") + } + if taken >= 2*time.Second { + t.Error("client.Do took 2+ seconds, want < 2 seconds") + } +} diff --git a/libsq/core/ioz/httpcache/memcache/memcache.go b/libsq/core/ioz/httpcache/memcache/memcache.go new file mode 100644 index 000000000..bb9f79376 --- /dev/null +++ b/libsq/core/ioz/httpcache/memcache/memcache.go @@ -0,0 +1,62 @@ +//go:build !appengine + +// Package memcache provides an implementation of httpcache.Cache that uses +// gomemcache to store cached responses. +// +// When built for Google App Engine, this package will provide an +// implementation that uses App Engine's memcache service. See the +// appengine.go file in this package for details. +package memcache + +import ( + "context" + + "github.com/bradfitz/gomemcache/memcache" +) + +// Cache is an implementation of httpcache.Cache that caches responses in a +// memcache server. +type Cache struct { + *memcache.Client +} + +// cacheKey modifies an httpcache key for use in memcache. Specifically, it +// prefixes keys to avoid collision with other data stored in memcache. +func cacheKey(key string) string { + return "httpcache:" + key +} + +// Get returns the response corresponding to key if present. +func (c *Cache) Get(ctx context.Context, key string) (resp []byte, ok bool) { + item, err := c.Client.Get(cacheKey(key)) + if err != nil { + return nil, false + } + return item.Value, true +} + +// Set saves a response to the cache as key. 
+func (c *Cache) Set(ctx context.Context, key string, resp []byte) { + item := &memcache.Item{ + Key: cacheKey(key), + Value: resp, + } + _ = c.Client.Set(item) +} + +// Delete removes the response with key from the cache. +func (c *Cache) Delete(ctx context.Context, key string) { + _ = c.Client.Delete(cacheKey(key)) +} + +// New returns a new Cache using the provided memcache server(s) with equal +// weight. If a server is listed multiple times, it gets a proportional amount +// of weight. +func New(server ...string) *Cache { + return NewWithClient(memcache.New(server...)) +} + +// NewWithClient returns a new Cache with the given memcache client. +func NewWithClient(client *memcache.Client) *Cache { + return &Cache{client} +} diff --git a/libsq/core/ioz/httpcache/memcache/memcache_test.go b/libsq/core/ioz/httpcache/memcache/memcache_test.go new file mode 100644 index 000000000..b33ae8e25 --- /dev/null +++ b/libsq/core/ioz/httpcache/memcache/memcache_test.go @@ -0,0 +1,24 @@ +//go:build !appengine + +package memcache + +import ( + "net" + "testing" + + "github.com/neilotoole/sq/libsq/core/ioz/httpcache/test" +) + +const testServer = "localhost:11211" + +func TestMemCache(t *testing.T) { + conn, err := net.Dial("tcp", testServer) + if err != nil { + // TODO: rather than skip the test, fall back to a faked memcached server + t.Skipf("skipping test; no server running at %s", testServer) + } + _, _ = conn.Write([]byte("flush_all\r\n")) // flush memcache + conn.Close() + + test.Cache(t, New(testServer)) +} diff --git a/libsq/core/ioz/httpcache/test/test.go b/libsq/core/ioz/httpcache/test/test.go new file mode 100644 index 000000000..533c60aaf --- /dev/null +++ b/libsq/core/ioz/httpcache/test/test.go @@ -0,0 +1,36 @@ +package test + +import ( + "bytes" + "context" + "testing" + + "github.com/neilotoole/sq/libsq/core/ioz/httpcache" +) + +// Cache excercises a httpcache.Cache implementation. 
+func Cache(t *testing.T, cache httpcache.Cache) { + key := "testKey" + _, ok := cache.Get(context.Background(), key) + if ok { + t.Fatal("retrieved key before adding it") + } + + val := []byte("some bytes") + cache.Set(context.Background(), key, val) + + retVal, ok := cache.Get(context.Background(), key) + if !ok { + t.Fatal("could not retrieve an element we just added") + } + if !bytes.Equal(retVal, val) { + t.Fatal("retrieved a different value than what we put in") + } + + cache.Delete(context.Background(), key) + + _, ok = cache.Get(context.Background(), key) + if ok { + t.Fatal("deleted key still present") + } +} diff --git a/libsq/core/ioz/httpcache/test/test_test.go b/libsq/core/ioz/httpcache/test/test_test.go new file mode 100644 index 000000000..4f02f62a1 --- /dev/null +++ b/libsq/core/ioz/httpcache/test/test_test.go @@ -0,0 +1,12 @@ +package test_test + +import ( + "testing" + + "github.com/neilotoole/sq/libsq/core/ioz/httpcache" + "github.com/neilotoole/sq/libsq/core/ioz/httpcache/test" +) + +func TestMemoryCache(t *testing.T) { + test.Cache(t, httpcache.NewMemoryCache()) +} From 8f3122b21980d34c6a61a96b638120a365afe657 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 12 Dec 2023 04:48:20 -0700 Subject: [PATCH 085/195] git work --- libsq/core/ioz/httpcacheog/LICENSE.txt | 7 + libsq/core/ioz/httpcacheog/README.md | 42 + .../ioz/httpcacheog/diskcache/diskcache.go | 63 + .../httpcacheog/diskcache/diskcache_test.go | 18 + libsq/core/ioz/httpcacheog/httpcache.go | 584 +++++++ libsq/core/ioz/httpcacheog/httpcache_test.go | 1475 +++++++++++++++++ .../core/ioz/httpcacheog/memcache/memcache.go | 62 + .../ioz/httpcacheog/memcache/memcache_test.go | 24 + libsq/core/ioz/httpcacheog/test/test.go | 36 + libsq/core/ioz/httpcacheog/test/test_test.go | 12 + 10 files changed, 2323 insertions(+) create mode 100644 libsq/core/ioz/httpcacheog/LICENSE.txt create mode 100644 libsq/core/ioz/httpcacheog/README.md create mode 100644 
libsq/core/ioz/httpcacheog/diskcache/diskcache.go create mode 100644 libsq/core/ioz/httpcacheog/diskcache/diskcache_test.go create mode 100644 libsq/core/ioz/httpcacheog/httpcache.go create mode 100644 libsq/core/ioz/httpcacheog/httpcache_test.go create mode 100644 libsq/core/ioz/httpcacheog/memcache/memcache.go create mode 100644 libsq/core/ioz/httpcacheog/memcache/memcache_test.go create mode 100644 libsq/core/ioz/httpcacheog/test/test.go create mode 100644 libsq/core/ioz/httpcacheog/test/test_test.go diff --git a/libsq/core/ioz/httpcacheog/LICENSE.txt b/libsq/core/ioz/httpcacheog/LICENSE.txt new file mode 100644 index 000000000..81316beb0 --- /dev/null +++ b/libsq/core/ioz/httpcacheog/LICENSE.txt @@ -0,0 +1,7 @@ +Copyright © 2012 Greg Jones (greg.jones@gmail.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
\ No newline at end of file diff --git a/libsq/core/ioz/httpcacheog/README.md b/libsq/core/ioz/httpcacheog/README.md new file mode 100644 index 000000000..58cda222f --- /dev/null +++ b/libsq/core/ioz/httpcacheog/README.md @@ -0,0 +1,42 @@ +# httpcache + +[![GoDoc](https://godoc.org/github.com/bitcomplete/httpcache?status.svg)](https://godoc.org/github.com/bitcomplete/httpcache) + +Package httpcache provides a http.RoundTripper implementation that works as a +mostly [RFC 7234](https://tools.ietf.org/html/rfc7234) compliant cache for http +responses. This incarnation of the library is an active fork of +[github.com/gregjones/httpcache](https://github.com/gregjones/httpcache) which +is unmaintained. + +It is only suitable for use as a 'private' cache (i.e. for a web-browser or an +API-client and not for a shared proxy). + +## Cache Backends + +- The built-in 'memory' cache stores responses in an in-memory map. - + [`github.com/bitcomplete/httpcache/diskcache`](https://github.com/bitcomplete/httpcache/tree/master/diskcache) + provides a filesystem-backed cache using the + [diskv](https://github.com/peterbourgon/diskv) library. - + [`github.com/bitcomplete/httpcache/memcache`](https://github.com/bitcomplete/httpcache/tree/master/memcache) + provides memcache implementations, for both App Engine and 'normal' memcache + servers. - + [`sourcegraph.com/sourcegraph/s3cache`](https://sourcegraph.com/github.com/sourcegraph/s3cache) + uses Amazon S3 for storage. - + [`github.com/bitcomplete/httpcache/leveldbcache`](https://github.com/bitcomplete/httpcache/tree/master/leveldbcache) + provides a filesystem-backed cache using + [leveldb](https://github.com/syndtr/goleveldb/leveldb). - + [`github.com/die-net/lrucache`](https://github.com/die-net/lrucache) provides an + in-memory cache that will evict least-recently used entries. 
- + [`github.com/die-net/lrucache/twotier`](https://github.com/die-net/lrucache/tree/master/twotier) + allows caches to be combined, for example to use lrucache above with a + persistent disk-cache. - + [`github.com/birkelund/boltdbcache`](https://github.com/birkelund/boltdbcache) + provides a BoltDB implementation (based on the + [bbolt](https://github.com/coreos/bbolt) fork). + +If you implement any other backend and wish it to be linked here, please send a +PR editing this file. + +## License + +- [MIT License](LICENSE.txt) diff --git a/libsq/core/ioz/httpcacheog/diskcache/diskcache.go b/libsq/core/ioz/httpcacheog/diskcache/diskcache.go new file mode 100644 index 000000000..4dd96e128 --- /dev/null +++ b/libsq/core/ioz/httpcacheog/diskcache/diskcache.go @@ -0,0 +1,63 @@ +// Package diskcache provides an implementation of httpcache.Cache that uses the diskv package +// to supplement an in-memory map with persistent storage +// +package diskcache + +import ( + "bytes" + "context" + "crypto/md5" + "encoding/hex" + "io" + + "github.com/peterbourgon/diskv" +) + +// Cache is an implementation of httpcache.Cache that supplements the in-memory map with persistent storage +type Cache struct { + d *diskv.Diskv +} + +// Get returns the response corresponding to key if present +func (c *Cache) Get(ctx context.Context, key string) (resp []byte, ok bool) { + key = keyToFilename(key) + resp, err := c.d.Read(key) + if err != nil { + return []byte{}, false + } + return resp, true +} + +// Set saves a response to the cache as key +func (c *Cache) Set(ctx context.Context, key string, resp []byte) { + key = keyToFilename(key) + _ = c.d.WriteStream(key, bytes.NewReader(resp), true) +} + +// Delete removes the response with key from the cache +func (c *Cache) Delete(ctx context.Context, key string) { + key = keyToFilename(key) + _ = c.d.Erase(key) +} + +func keyToFilename(key string) string { + h := md5.New() + _, _ = io.WriteString(h, key) + return 
hex.EncodeToString(h.Sum(nil)) +} + +// New returns a new Cache that will store files in basePath +func New(basePath string) *Cache { + return &Cache{ + d: diskv.New(diskv.Options{ + BasePath: basePath, + CacheSizeMax: 100 * 1024 * 1024, // 100MB + }), + } +} + +// NewWithDiskv returns a new Cache using the provided Diskv as underlying +// storage. +func NewWithDiskv(d *diskv.Diskv) *Cache { + return &Cache{d} +} diff --git a/libsq/core/ioz/httpcacheog/diskcache/diskcache_test.go b/libsq/core/ioz/httpcacheog/diskcache/diskcache_test.go new file mode 100644 index 000000000..8bf743697 --- /dev/null +++ b/libsq/core/ioz/httpcacheog/diskcache/diskcache_test.go @@ -0,0 +1,18 @@ +package diskcache + +import ( + "github.com/neilotoole/sq/libsq/core/ioz/httpcacheog/test" + "io/ioutil" + "os" + "testing" +) + +func TestDiskCache(t *testing.T) { + tempDir, err := ioutil.TempDir("", "httpcache") + if err != nil { + t.Fatalf("TempDir: %v", err) + } + defer os.RemoveAll(tempDir) + + test.Cache(t, New(tempDir)) +} diff --git a/libsq/core/ioz/httpcacheog/httpcache.go b/libsq/core/ioz/httpcacheog/httpcache.go new file mode 100644 index 000000000..51d283740 --- /dev/null +++ b/libsq/core/ioz/httpcacheog/httpcache.go @@ -0,0 +1,584 @@ +// Package httpcache provides a http.RoundTripper implementation that works as a +// mostly RFC-compliant cache for http responses. +// +// It is only suitable for use as a 'private' cache (i.e. for a web-browser or an API-client +// and not for a shared proxy). +// +package httpcache + +import ( + "bufio" + "bytes" + "context" + "errors" + "io" + "io/ioutil" + "net/http" + "net/http/httputil" + "strings" + "sync" + "time" +) + +const ( + stale = iota + fresh + transparent + // XFromCache is the header added to responses that are returned from the cache + XFromCache = "X-From-Cache" +) + +// A Cache interface is used by the Transport to store and retrieve responses. 
+type Cache interface { + // Get returns the []byte representation of a cached response and a bool + set to true if the value isn't empty + Get(ctx context.Context, key string) (responseBytes []byte, ok bool) + // Set stores the []byte representation of a response against a key + Set(ctx context.Context, key string, responseBytes []byte) + // Delete removes the value associated with the key + Delete(ctx context.Context, key string) +} + +type KeyFunc func(req *http.Request) string + +// DefaultKeyFunc returns the cache key for req +var DefaultKeyFunc = func(req *http.Request) string { + if req.Method == http.MethodGet { + return req.URL.String() + } else { + return req.Method + " " + req.URL.String() + } +} + +// CachedResponse returns the cached http.Response for req if present, and nil +// otherwise. +func CachedResponse(ctx context.Context, c Cache, key string, req *http.Request) (resp *http.Response, err error) { + cachedVal, ok := c.Get(ctx, key) + if !ok { + return + } + + b := bytes.NewBuffer(cachedVal) + return http.ReadResponse(bufio.NewReader(b), req) +} + +// MemoryCache is an implementation of Cache that stores responses in an in-memory map. 
+type MemoryCache struct { + mu sync.RWMutex + items map[string][]byte +} + +// Get returns the []byte representation of the response and true if present, false if not +func (c *MemoryCache) Get(ctx context.Context, key string) (resp []byte, ok bool) { + c.mu.RLock() + resp, ok = c.items[key] + c.mu.RUnlock() + return resp, ok +} + +// Set saves response resp to the cache with key +func (c *MemoryCache) Set(ctx context.Context, key string, resp []byte) { + c.mu.Lock() + c.items[key] = resp + c.mu.Unlock() +} + +// Delete removes key from the cache +func (c *MemoryCache) Delete(ctx context.Context, key string) { + c.mu.Lock() + delete(c.items, key) + c.mu.Unlock() +} + +// NewMemoryCache returns a new Cache that will store items in an in-memory map +func NewMemoryCache() *MemoryCache { + c := &MemoryCache{items: map[string][]byte{}} + return c +} + +// TransportOpt is a configuration option for creating a new Transport +type TransportOpt func(t *Transport) + +// MarkCachedResponsesOpt configures a transport by setting MarkCachedResponses to true +func MarkCachedResponsesOpt(markCachedResponses bool) TransportOpt { + return func(t *Transport) { + t.MarkCachedResponses = markCachedResponses + } +} + +// KeyFuncOpt configures a transport by setting its KeyFunc to the one given +func KeyFuncOpt(keyFunc KeyFunc) TransportOpt { + return func(t *Transport) { + t.KeyFunc = keyFunc + } +} + +// Transport is an implementation of http.RoundTripper that will return values from a cache +// where possible (avoiding a network request) and will additionally add validators (etag/if-modified-since) +// to repeated requests allowing servers to return 304 / Not Modified +type Transport struct { + // The RoundTripper interface actually used to make requests + // If nil, http.DefaultTransport is used + Transport http.RoundTripper + Cache Cache + // If true, responses returned from the cache will be given an extra header, X-From-Cache + MarkCachedResponses bool + // A function to generate 
a cache key for the given request + KeyFunc KeyFunc +} + +// NewTransport returns a new Transport with the provided Cache and options. If +// KeyFunc is not specified in opts then DefaultKeyFunc is used. +func NewTransport(c Cache, opts ...TransportOpt) *Transport { + t := &Transport{ + Cache: c, + KeyFunc: DefaultKeyFunc, + MarkCachedResponses: true, + } + for _, opt := range opts { + opt(t) + } + return t +} + +// Client returns an *http.Client that caches responses. +func (t *Transport) Client() *http.Client { + return &http.Client{Transport: t} +} + +// varyMatches will return false unless all of the cached values for the headers listed in Vary +// match the new request +func varyMatches(cachedResp *http.Response, req *http.Request) bool { + for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") { + header = http.CanonicalHeaderKey(header) + if header != "" && req.Header.Get(header) != cachedResp.Header.Get("X-Varied-"+header) { + return false + } + } + return true +} + +// RoundTrip takes a Request and returns a Response +// +// If there is a fresh Response already in cache, then it will be returned without connecting to +// the server. +// +// If there is a stale Response, then any validators it contains will be set on the new request +// to give the server a chance to respond with NotModified. If this happens, then the cached Response +// will be returned. 
+func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { + cacheKey := t.KeyFunc(req) + cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" + var cachedResp *http.Response + if cacheable { + cachedResp, err = CachedResponse(req.Context(), t.Cache, cacheKey, req) + } else { + // Need to invalidate an existing value + t.Cache.Delete(req.Context(), cacheKey) + } + + transport := t.Transport + if transport == nil { + transport = http.DefaultTransport + } + + if cacheable && cachedResp != nil && err == nil { + if t.MarkCachedResponses { + cachedResp.Header.Set(XFromCache, "1") + } + + if varyMatches(cachedResp, req) { + // Can only use cached value if the new request doesn't Vary significantly + freshness := getFreshness(cachedResp.Header, req.Header) + if freshness == fresh { + return cachedResp, nil + } + + if freshness == stale { + var req2 *http.Request + // Add validators if caller hasn't already done so + etag := cachedResp.Header.Get("etag") + if etag != "" && req.Header.Get("etag") == "" { + req2 = cloneRequest(req) + req2.Header.Set("if-none-match", etag) + } + lastModified := cachedResp.Header.Get("last-modified") + if lastModified != "" && req.Header.Get("last-modified") == "" { + if req2 == nil { + req2 = cloneRequest(req) + } + req2.Header.Set("if-modified-since", lastModified) + } + if req2 != nil { + req = req2 + } + } + } + + resp, err = transport.RoundTrip(req) + if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified { + // Replace the 304 response with the one from cache, but update with some new headers + endToEndHeaders := getEndToEndHeaders(resp.Header) + for _, header := range endToEndHeaders { + cachedResp.Header[header] = resp.Header[header] + } + resp = cachedResp + } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) && + req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) { + // In case of transport failure and 
stale-if-error activated, returns cached content + // when available + return cachedResp, nil + } else { + if err != nil || resp.StatusCode != http.StatusOK { + t.Cache.Delete(req.Context(), cacheKey) + } + if err != nil { + return nil, err + } + } + } else { + reqCacheControl := parseCacheControl(req.Header) + if _, ok := reqCacheControl["only-if-cached"]; ok { + resp = newGatewayTimeoutResponse(req) + } else { + resp, err = transport.RoundTrip(req) + if err != nil { + return nil, err + } + } + } + + if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) { + for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") { + varyKey = http.CanonicalHeaderKey(varyKey) + fakeHeader := "X-Varied-" + varyKey + reqValue := req.Header.Get(varyKey) + if reqValue != "" { + resp.Header.Set(fakeHeader, reqValue) + } + } + switch req.Method { + case "GET": + // Delay caching until EOF is reached. + resp.Body = &cachingReadCloser{ + R: resp.Body, + OnEOF: func(r io.Reader) { + resp := *resp + resp.Body = ioutil.NopCloser(r) + respBytes, err := httputil.DumpResponse(&resp, true) + if err == nil { + t.Cache.Set(req.Context(), cacheKey, respBytes) + } + }, + } + default: + respBytes, err := httputil.DumpResponse(resp, true) + if err == nil { + t.Cache.Set(req.Context(), cacheKey, respBytes) + } + } + } else { + t.Cache.Delete(req.Context(), cacheKey) + } + return resp, nil +} + +// ErrNoDateHeader indicates that the HTTP headers contained no Date header. +var ErrNoDateHeader = errors.New("no Date header") + +// Date parses and returns the value of the Date header. 
+func Date(respHeaders http.Header) (date time.Time, err error) { + dateHeader := respHeaders.Get("date") + if dateHeader == "" { + err = ErrNoDateHeader + return + } + + return time.Parse(time.RFC1123, dateHeader) +} + +type realClock struct{} + +func (c *realClock) since(d time.Time) time.Duration { + return time.Since(d) +} + +type timer interface { + since(d time.Time) time.Duration +} + +var clock timer = &realClock{} + +// getFreshness will return one of fresh/stale/transparent based on the cache-control +// values of the request and the response +// +// fresh indicates the response can be returned +// stale indicates that the response needs validating before it is returned +// transparent indicates the response should not be used to fulfil the request +// +// Because this is only a private cache, 'public' and 'private' in cache-control aren't +// significant. Similarly, smax-age isn't used. +func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) { + respCacheControl := parseCacheControl(respHeaders) + reqCacheControl := parseCacheControl(reqHeaders) + if _, ok := reqCacheControl["no-cache"]; ok { + return transparent + } + if _, ok := respCacheControl["no-cache"]; ok { + return stale + } + if _, ok := reqCacheControl["only-if-cached"]; ok { + return fresh + } + + date, err := Date(respHeaders) + if err != nil { + return stale + } + currentAge := clock.since(date) + + var lifetime time.Duration + var zeroDuration time.Duration + + // If a response includes both an Expires header and a max-age directive, + // the max-age directive overrides the Expires header, even if the Expires header is more restrictive. 
+ if maxAge, ok := respCacheControl["max-age"]; ok { + lifetime, err = time.ParseDuration(maxAge + "s") + if err != nil { + lifetime = zeroDuration + } + } else { + expiresHeader := respHeaders.Get("Expires") + if expiresHeader != "" { + expires, err := time.Parse(time.RFC1123, expiresHeader) + if err != nil { + lifetime = zeroDuration + } else { + lifetime = expires.Sub(date) + } + } + } + + if maxAge, ok := reqCacheControl["max-age"]; ok { + // the client is willing to accept a response whose age is no greater than the specified time in seconds + lifetime, err = time.ParseDuration(maxAge + "s") + if err != nil { + lifetime = zeroDuration + } + } + if minfresh, ok := reqCacheControl["min-fresh"]; ok { + // the client wants a response that will still be fresh for at least the specified number of seconds. + minfreshDuration, err := time.ParseDuration(minfresh + "s") + if err == nil { + currentAge = time.Duration(currentAge + minfreshDuration) + } + } + + if maxstale, ok := reqCacheControl["max-stale"]; ok { + // Indicates that the client is willing to accept a response that has exceeded its expiration time. + // If max-stale is assigned a value, then the client is willing to accept a response that has exceeded + // its expiration time by no more than the specified number of seconds. + // If no value is assigned to max-stale, then the client is willing to accept a stale response of any age. + // + // Responses served only because of a max-stale value are supposed to have a Warning header added to them, + // but that seems like a hassle, and is it actually useful? If so, then there needs to be a different + // return-value available here. 
+ if maxstale == "" { + return fresh + } + maxstaleDuration, err := time.ParseDuration(maxstale + "s") + if err == nil { + currentAge = time.Duration(currentAge - maxstaleDuration) + } + } + + if lifetime > currentAge { + return fresh + } + + return stale +} + +// Returns true if either the request or the response includes the stale-if-error +// cache control extension: https://tools.ietf.org/html/rfc5861 +func canStaleOnError(respHeaders, reqHeaders http.Header) bool { + respCacheControl := parseCacheControl(respHeaders) + reqCacheControl := parseCacheControl(reqHeaders) + + var err error + lifetime := time.Duration(-1) + + if staleMaxAge, ok := respCacheControl["stale-if-error"]; ok { + if staleMaxAge != "" { + lifetime, err = time.ParseDuration(staleMaxAge + "s") + if err != nil { + return false + } + } else { + return true + } + } + if staleMaxAge, ok := reqCacheControl["stale-if-error"]; ok { + if staleMaxAge != "" { + lifetime, err = time.ParseDuration(staleMaxAge + "s") + if err != nil { + return false + } + } else { + return true + } + } + + if lifetime >= 0 { + date, err := Date(respHeaders) + if err != nil { + return false + } + currentAge := clock.since(date) + if lifetime > currentAge { + return true + } + } + + return false +} + +func getEndToEndHeaders(respHeaders http.Header) []string { + // These headers are always hop-by-hop + hopByHopHeaders := map[string]struct{}{ + "Connection": {}, + "Keep-Alive": {}, + "Proxy-Authenticate": {}, + "Proxy-Authorization": {}, + "Te": {}, + "Trailers": {}, + "Transfer-Encoding": {}, + "Upgrade": {}, + } + + for _, extra := range strings.Split(respHeaders.Get("connection"), ",") { + // any header listed in connection, if present, is also considered hop-by-hop + if strings.Trim(extra, " ") != "" { + hopByHopHeaders[http.CanonicalHeaderKey(extra)] = struct{}{} + } + } + endToEndHeaders := []string{} + for respHeader := range respHeaders { + if _, ok := hopByHopHeaders[respHeader]; !ok { + endToEndHeaders = 
append(endToEndHeaders, respHeader) + } + } + return endToEndHeaders +} + +func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) { + if _, ok := respCacheControl["no-store"]; ok { + return false + } + if _, ok := reqCacheControl["no-store"]; ok { + return false + } + return true +} + +func newGatewayTimeoutResponse(req *http.Request) *http.Response { + var braw bytes.Buffer + braw.WriteString("HTTP/1.1 504 Gateway Timeout\r\n\r\n") + resp, err := http.ReadResponse(bufio.NewReader(&braw), req) + if err != nil { + panic(err) + } + return resp +} + +// cloneRequest returns a clone of the provided *http.Request. +// The clone is a shallow copy of the struct and its Header map. +// (This function copyright goauth2 authors: https://code.google.com/p/goauth2) +func cloneRequest(r *http.Request) *http.Request { + // shallow copy of the struct + r2 := new(http.Request) + *r2 = *r + if ctx := r.Context(); ctx != nil { + r2 = r2.WithContext(ctx) + } + // deep copy of the Header + r2.Header = make(http.Header) + for k, s := range r.Header { + r2.Header[k] = s + } + return r2 +} + +type cacheControl map[string]string + +func parseCacheControl(headers http.Header) cacheControl { + cc := cacheControl{} + ccHeader := headers.Get("Cache-Control") + for _, part := range strings.Split(ccHeader, ",") { + part = strings.Trim(part, " ") + if part == "" { + continue + } + if strings.ContainsRune(part, '=') { + keyval := strings.Split(part, "=") + cc[strings.Trim(keyval[0], " ")] = strings.Trim(keyval[1], ",") + } else { + cc[part] = "" + } + } + return cc +} + +// headerAllCommaSepValues returns all comma-separated values (each +// with whitespace trimmed) for header name in headers. According to +// Section 4.2 of the HTTP/1.1 spec +// (http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2), +// values from multiple occurrences of a header should be concatenated, if +// the header's value is a comma-separated list. 
+func headerAllCommaSepValues(headers http.Header, name string) []string { + var vals []string + for _, val := range headers[http.CanonicalHeaderKey(name)] { + fields := strings.Split(val, ",") + for i, f := range fields { + fields[i] = strings.TrimSpace(f) + } + vals = append(vals, fields...) + } + return vals +} + +// cachingReadCloser is a wrapper around ReadCloser R that calls OnEOF +// handler with a full copy of the content read from R when EOF is +// reached. +type cachingReadCloser struct { + // Underlying ReadCloser. + R io.ReadCloser + // OnEOF is called with a copy of the content of R when EOF is reached. + OnEOF func(io.Reader) + + buf bytes.Buffer // buf stores a copy of the content of R. +} + +// Read reads the next len(p) bytes from R or until R is drained. The +// return value n is the number of bytes read. If R has no data to +// return, err is io.EOF and OnEOF is called with a full copy of what +// has been read so far. +func (r *cachingReadCloser) Read(p []byte) (n int, err error) { + n, err = r.R.Read(p) + r.buf.Write(p[:n]) + if err == io.EOF { + r.OnEOF(bytes.NewReader(r.buf.Bytes())) + } + return n, err +} + +func (r *cachingReadCloser) Close() error { + return r.R.Close() +} + +// NewMemoryCacheTransport returns a new Transport using the in-memory cache implementation +func NewMemoryCacheTransport(opts ...TransportOpt) *Transport { + c := NewMemoryCache() + t := NewTransport(c, opts...) + return t +} diff --git a/libsq/core/ioz/httpcacheog/httpcache_test.go b/libsq/core/ioz/httpcacheog/httpcache_test.go new file mode 100644 index 000000000..9fb3b8f47 --- /dev/null +++ b/libsq/core/ioz/httpcacheog/httpcache_test.go @@ -0,0 +1,1475 @@ +package httpcache + +import ( + "bytes" + "errors" + "flag" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + "time" +) + +var s struct { + server *httptest.Server + client http.Client + transport *Transport + done chan struct{} // Closed to unlock infinite handlers. 
+} + +type fakeClock struct { + elapsed time.Duration +} + +func (c *fakeClock) since(t time.Time) time.Duration { + return c.elapsed +} + +func TestMain(m *testing.M) { + flag.Parse() + setup() + code := m.Run() + teardown() + os.Exit(code) +} + +func setup() { + tp := NewMemoryCacheTransport() + client := http.Client{Transport: tp} + s.transport = tp + s.client = client + s.done = make(chan struct{}) + + mux := http.NewServeMux() + s.server = httptest.NewServer(mux) + + mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + })) + + mux.HandleFunc("/method", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + _, _ = w.Write([]byte(r.Method)) + })) + + mux.HandleFunc("/range", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + lm := "Fri, 14 Dec 2010 01:01:50 GMT" + if r.Header.Get("if-modified-since") == lm { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("last-modified", lm) + if r.Header.Get("range") == "bytes=4-9" { + w.WriteHeader(http.StatusPartialContent) + _, _ = w.Write([]byte(" text ")) + return + } + _, _ = w.Write([]byte("Some text content")) + })) + + mux.HandleFunc("/nostore", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "no-store") + })) + + mux.HandleFunc("/etag", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + etag := "124567" + if r.Header.Get("if-none-match") == etag { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("etag", etag) + })) + + mux.HandleFunc("/lastmodified", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + lm := "Fri, 14 Dec 2010 01:01:50 GMT" + if r.Header.Get("if-modified-since") == lm { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("last-modified", lm) + })) + + mux.HandleFunc("/varyaccept", http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("Vary", "Accept") + _, _ = w.Write([]byte("Some text content")) + })) + + mux.HandleFunc("/doublevary", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("Vary", "Accept, Accept-Language") + _, _ = w.Write([]byte("Some text content")) + })) + mux.HandleFunc("/2varyheaders", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "text/plain") + w.Header().Add("Vary", "Accept") + w.Header().Add("Vary", "Accept-Language") + _, _ = w.Write([]byte("Some text content")) + })) + mux.HandleFunc("/varyunused", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("Vary", "X-Madeup-Header") + _, _ = w.Write([]byte("Some text content")) + })) + + mux.HandleFunc("/cachederror", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + etag := "abc" + if r.Header.Get("if-none-match") == etag { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("etag", etag) + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("Not found")) + })) + + updateFieldsCounter := 0 + mux.HandleFunc("/updatefields", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Counter", strconv.Itoa(updateFieldsCounter)) + w.Header().Set("Etag", `"e"`) + updateFieldsCounter++ + if r.Header.Get("if-none-match") != "" { + w.WriteHeader(http.StatusNotModified) + return + } + _, _ = w.Write([]byte("Some text content")) + })) + + // Take 3 seconds to return 200 OK (for testing client timeouts). 
+ mux.HandleFunc("/3seconds", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(3 * time.Second) + })) + + mux.HandleFunc("/infinite", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for { + select { + case <-s.done: + return + default: + _, _ = w.Write([]byte{0}) + } + } + })) +} + +func teardown() { + close(s.done) + s.server.Close() +} + +func resetTest() { + s.transport.Cache = NewMemoryCache() + clock = &realClock{} +} + +// TestCacheableMethod ensures that uncacheable method does not get stored +// in cache and get incorrectly used for a following cacheable method request. +func TestCacheableMethod(t *testing.T) { + resetTest() + { + req, err := http.NewRequest("POST", s.server.URL+"/method", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), "POST"; got != want { + t.Errorf("got %q, want %q", got, want) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) + } + } + { + req, err := http.NewRequest("GET", s.server.URL+"/method", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), "GET"; got != want { + t.Errorf("got wrong body %q, want %q", got, want) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) + } + if resp.Header.Get(XFromCache) != "" { + t.Errorf("XFromCache header isn't blank") + } + } +} + +func TestDontServeHeadResponseToGetRequest(t *testing.T) { + resetTest() + url := s.server.URL + "/" + 
req, err := http.NewRequest(http.MethodHead, url, nil) + if err != nil { + t.Fatal(err) + } + _, err = s.client.Do(req) + if err != nil { + t.Fatal(err) + } + req, err = http.NewRequest(http.MethodGet, url, nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.Header.Get(XFromCache) != "" { + t.Errorf("Cache should not match") + } +} + +func TestDontStorePartialRangeInCache(t *testing.T) { + resetTest() + { + req, err := http.NewRequest("GET", s.server.URL+"/range", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("range", "bytes=4-9") + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), " text "; got != want { + t.Errorf("got %q, want %q", got, want) + } + if resp.StatusCode != http.StatusPartialContent { + t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) + } + } + { + req, err := http.NewRequest("GET", s.server.URL+"/range", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), "Some text content"; got != want { + t.Errorf("got %q, want %q", got, want) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) + } + if resp.Header.Get(XFromCache) != "" { + t.Error("XFromCache header isn't blank") + } + } + { + req, err := http.NewRequest("GET", s.server.URL+"/range", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + 
t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), "Some text content"; got != want { + t.Errorf("got %q, want %q", got, want) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) + } + if resp.Header.Get(XFromCache) != "1" { + t.Errorf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } + { + req, err := http.NewRequest("GET", s.server.URL+"/range", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("range", "bytes=4-9") + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), " text "; got != want { + t.Errorf("got %q, want %q", got, want) + } + if resp.StatusCode != http.StatusPartialContent { + t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) + } + } +} + +func TestCacheOnlyIfBodyRead(t *testing.T) { + resetTest() + { + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + // We do not read the body + resp.Body.Close() + } + { + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatalf("XFromCache header isn't blank") + } + } +} + +func TestOnlyReadBodyOnDemand(t *testing.T) { + resetTest() + + req, err := http.NewRequest("GET", s.server.URL+"/infinite", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) // This shouldn't hang forever. 
+ if err != nil { + t.Fatal(err) + } + buf := make([]byte, 10) // Only partially read the body. + _, err = resp.Body.Read(buf) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() +} + +func TestGetOnlyIfCachedHit(t *testing.T) { + resetTest() + { + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("cache-control", "only-if-cached") + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) + } + } +} + +func TestGetOnlyIfCachedMiss(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("cache-control", "only-if-cached") + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + if resp.StatusCode != http.StatusGatewayTimeout { + t.Fatalf("response status code isn't 504 GatewayTimeout: %v", resp.StatusCode) + } +} + +func TestGetNoStoreRequest(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("Cache-Control", "no-store") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't 
blank") + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } +} + +func TestGetNoStoreResponse(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/nostore", nil) + if err != nil { + t.Fatal(err) + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } +} + +func TestGetWithEtag(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/etag", nil) + if err != nil { + t.Fatal(err) + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + // additional assertions to verify that 304 response is converted properly + if resp.StatusCode != http.StatusOK { + t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) + } + if _, ok := resp.Header["Connection"]; ok { + t.Fatalf("Connection header isn't absent") + } + } +} + +func TestGetWithLastModified(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/lastmodified", nil) + if err != nil { + t.Fatal(err) + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != 
"" { + t.Fatal("XFromCache header isn't blank") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } +} + +func TestGetWithVary(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/varyaccept", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", "text/plain") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get("Vary") != "Accept" { + t.Fatalf(`Vary header isn't "Accept": %v`, resp.Header.Get("Vary")) + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } + req.Header.Set("Accept", "text/html") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + req.Header.Set("Accept", "") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } +} + +func TestGetWithDoubleVary(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/doublevary", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", "text/plain") + req.Header.Set("Accept-Language", "da, en-gb;q=0.8, en;q=0.7") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get("Vary") == "" { + t.Fatalf(`Vary header is blank`) + } + _, err = 
ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } + req.Header.Set("Accept-Language", "") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + req.Header.Set("Accept-Language", "da") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } +} + +func TestGetWith2VaryHeaders(t *testing.T) { + resetTest() + // Tests that multiple Vary headers' comma-separated lists are + // merged. See https://github.com/gregjones/httpcache/issues/27. + const ( + accept = "text/plain" + acceptLanguage = "da, en-gb;q=0.8, en;q=0.7" + ) + req, err := http.NewRequest("GET", s.server.URL+"/2varyheaders", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", accept) + req.Header.Set("Accept-Language", acceptLanguage) + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get("Vary") == "" { + t.Fatalf(`Vary header is blank`) + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } + req.Header.Set("Accept-Language", "") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + req.Header.Set("Accept-Language", "da") + { + resp, 
err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + req.Header.Set("Accept-Language", acceptLanguage) + req.Header.Set("Accept", "") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + req.Header.Set("Accept", "image/png") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } +} + +func TestGetVaryUnused(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/varyunused", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", "text/plain") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get("Vary") == "" { + t.Fatalf(`Vary header is blank`) + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } +} + +func TestUpdateFields(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/updatefields", nil) + if err != nil { + t.Fatal(err) + } + var counter, counter2 string + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + counter = 
resp.Header.Get("x-counter") + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + counter2 = resp.Header.Get("x-counter") + } + if counter == counter2 { + t.Fatalf(`both "x-counter" values are equal: %v %v`, counter, counter2) + } +} + +// This tests the fix for https://github.com/gregjones/httpcache/issues/74. +// Previously, after validating a cached response, its StatusCode +// was incorrectly being replaced. +func TestCachedErrorsKeepStatus(t *testing.T) { + resetTest() + req, err := http.NewRequest("GET", s.server.URL+"/cachederror", nil) + if err != nil { + t.Fatal(err) + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + _, _ = io.Copy(ioutil.Discard, resp.Body) + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("Status code isn't 404: %d", resp.StatusCode) + } + } +} + +func TestParseCacheControl(t *testing.T) { + resetTest() + h := http.Header{} + for range parseCacheControl(h) { + t.Fatal("cacheControl should be empty") + } + + h.Set("cache-control", "no-cache") + { + cc := parseCacheControl(h) + if _, ok := cc["foo"]; ok { + t.Error(`Value "foo" shouldn't exist`) + } + noCache, ok := cc["no-cache"] + if !ok { + t.Fatalf(`"no-cache" value isn't set`) + } + if noCache != "" { + t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) + } + } + h.Set("cache-control", "no-cache, max-age=3600") + { + cc := parseCacheControl(h) + noCache, ok := cc["no-cache"] + if !ok { + t.Fatalf(`"no-cache" value isn't set`) + } + if noCache != "" { + t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) + } + if cc["max-age"] != "3600" { + t.Fatalf(`"max-age" value isn't 
"3600": %v`, cc["max-age"]) + } + } +} + +func TestNoCacheRequestExpiration(t *testing.T) { + resetTest() + respHeaders := http.Header{} + respHeaders.Set("Cache-Control", "max-age=7200") + + reqHeaders := http.Header{} + reqHeaders.Set("Cache-Control", "no-cache") + if getFreshness(respHeaders, reqHeaders) != transparent { + t.Fatal("freshness isn't transparent") + } +} + +func TestNoCacheResponseExpiration(t *testing.T) { + resetTest() + respHeaders := http.Header{} + respHeaders.Set("Cache-Control", "no-cache") + respHeaders.Set("Expires", "Wed, 19 Apr 3000 11:43:00 GMT") + + reqHeaders := http.Header{} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestReqMustRevalidate(t *testing.T) { + resetTest() + // not paying attention to request setting max-stale means never returning stale + // responses, so always acting as if must-revalidate is set + respHeaders := http.Header{} + + reqHeaders := http.Header{} + reqHeaders.Set("Cache-Control", "must-revalidate") + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestRespMustRevalidate(t *testing.T) { + resetTest() + respHeaders := http.Header{} + respHeaders.Set("Cache-Control", "must-revalidate") + + reqHeaders := http.Header{} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestFreshExpiration(t *testing.T) { + resetTest() + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) + + reqHeaders := http.Header{} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + clock = &fakeClock{elapsed: 3 * time.Second} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestMaxAge(t *testing.T) { + resetTest() + now := time.Now() + 
respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("cache-control", "max-age=2") + + reqHeaders := http.Header{} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + clock = &fakeClock{elapsed: 3 * time.Second} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestMaxAgeZero(t *testing.T) { + resetTest() + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("cache-control", "max-age=0") + + reqHeaders := http.Header{} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestBothMaxAge(t *testing.T) { + resetTest() + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("cache-control", "max-age=2") + + reqHeaders := http.Header{} + reqHeaders.Set("cache-control", "max-age=0") + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestMinFreshWithExpires(t *testing.T) { + resetTest() + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) + + reqHeaders := http.Header{} + reqHeaders.Set("cache-control", "min-fresh=1") + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + reqHeaders = http.Header{} + reqHeaders.Set("cache-control", "min-fresh=2") + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestEmptyMaxStale(t *testing.T) { + resetTest() + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("cache-control", "max-age=20") + + reqHeaders := http.Header{} + reqHeaders.Set("cache-control", 
"max-stale") + clock = &fakeClock{elapsed: 10 * time.Second} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + clock = &fakeClock{elapsed: 60 * time.Second} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } +} + +func TestMaxStaleValue(t *testing.T) { + resetTest() + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("cache-control", "max-age=10") + + reqHeaders := http.Header{} + reqHeaders.Set("cache-control", "max-stale=20") + clock = &fakeClock{elapsed: 5 * time.Second} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + clock = &fakeClock{elapsed: 15 * time.Second} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + clock = &fakeClock{elapsed: 30 * time.Second} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func containsHeader(headers []string, header string) bool { + for _, v := range headers { + if http.CanonicalHeaderKey(v) == http.CanonicalHeaderKey(header) { + return true + } + } + return false +} + +func TestGetEndToEndHeaders(t *testing.T) { + resetTest() + var ( + headers http.Header + end2end []string + ) + + headers = http.Header{} + headers.Set("content-type", "text/html") + headers.Set("te", "deflate") + + end2end = getEndToEndHeaders(headers) + if !containsHeader(end2end, "content-type") { + t.Fatal(`doesn't contain "content-type" header`) + } + if containsHeader(end2end, "te") { + t.Fatal(`doesn't contain "te" header`) + } + + headers = http.Header{} + headers.Set("connection", "content-type") + headers.Set("content-type", "text/csv") + headers.Set("te", "deflate") + end2end = getEndToEndHeaders(headers) + if containsHeader(end2end, "connection") { + t.Fatal(`doesn't contain "connection" header`) + } + if containsHeader(end2end, "content-type") { + 
t.Fatal(`doesn't contain "content-type" header`) + } + if containsHeader(end2end, "te") { + t.Fatal(`doesn't contain "te" header`) + } + + headers = http.Header{} + end2end = getEndToEndHeaders(headers) + if len(end2end) != 0 { + t.Fatal(`non-zero end2end headers`) + } + + headers = http.Header{} + headers.Set("connection", "content-type") + end2end = getEndToEndHeaders(headers) + if len(end2end) != 0 { + t.Fatal(`non-zero end2end headers`) + } +} + +type transportMock struct { + response *http.Response + err error +} + +func (t transportMock) RoundTrip(req *http.Request) (resp *http.Response, err error) { + return t.response, t.err +} + +func TestStaleIfErrorRequest(t *testing.T) { + resetTest() + now := time.Now() + tmock := transportMock{ + response: &http.Response{ + Status: http.StatusText(http.StatusOK), + StatusCode: http.StatusOK, + Header: http.Header{ + "Date": []string{now.Format(time.RFC1123)}, + "Cache-Control": []string{"no-cache"}, + }, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), + }, + err: nil, + } + tp := NewMemoryCacheTransport() + tp.Transport = &tmock + + // First time, response is cached on success + r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) + r.Header.Set("Cache-Control", "stale-if-error") + resp, err := tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + // On failure, response is returned from the cache + tmock.response = nil + tmock.err = errors.New("some error") + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } +} + +func TestStaleIfErrorRequestLifetime(t *testing.T) { + resetTest() + now := time.Now() + tmock := transportMock{ + response: &http.Response{ + Status: http.StatusText(http.StatusOK), + StatusCode: http.StatusOK, + Header: http.Header{ + "Date": []string{now.Format(time.RFC1123)}, + "Cache-Control": 
[]string{"no-cache"}, + }, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), + }, + err: nil, + } + tp := NewMemoryCacheTransport() + tp.Transport = &tmock + + // First time, response is cached on success + r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) + r.Header.Set("Cache-Control", "stale-if-error=100") + resp, err := tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + // On failure, response is returned from the cache + tmock.response = nil + tmock.err = errors.New("some error") + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + + // Same for http errors + tmock.response = &http.Response{StatusCode: http.StatusInternalServerError} + tmock.err = nil + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + + // If failure last more than max stale, error is returned + clock = &fakeClock{elapsed: 200 * time.Second} + _, err = tp.RoundTrip(r) + if err != tmock.err { + t.Fatalf("got err %v, want %v", err, tmock.err) + } +} + +func TestStaleIfErrorResponse(t *testing.T) { + resetTest() + now := time.Now() + tmock := transportMock{ + response: &http.Response{ + Status: http.StatusText(http.StatusOK), + StatusCode: http.StatusOK, + Header: http.Header{ + "Date": []string{now.Format(time.RFC1123)}, + "Cache-Control": []string{"no-cache, stale-if-error"}, + }, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), + }, + err: nil, + } + tp := NewMemoryCacheTransport() + tp.Transport = &tmock + + // First time, response is cached on success + r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) + resp, err := tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } 
+ + // On failure, response is returned from the cache + tmock.response = nil + tmock.err = errors.New("some error") + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } +} + +func TestStaleIfErrorResponseLifetime(t *testing.T) { + resetTest() + now := time.Now() + tmock := transportMock{ + response: &http.Response{ + Status: http.StatusText(http.StatusOK), + StatusCode: http.StatusOK, + Header: http.Header{ + "Date": []string{now.Format(time.RFC1123)}, + "Cache-Control": []string{"no-cache, stale-if-error=100"}, + }, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), + }, + err: nil, + } + tp := NewMemoryCacheTransport() + tp.Transport = &tmock + + // First time, response is cached on success + r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) + resp, err := tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + // On failure, response is returned from the cache + tmock.response = nil + tmock.err = errors.New("some error") + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + + // If failure last more than max stale, error is returned + clock = &fakeClock{elapsed: 200 * time.Second} + _, err = tp.RoundTrip(r) + if err != tmock.err { + t.Fatalf("got err %v, want %v", err, tmock.err) + } +} + +// This tests the fix for https://github.com/gregjones/httpcache/issues/74. +// Previously, after a stale response was used after encountering an error, +// its StatusCode was being incorrectly replaced. 
+func TestStaleIfErrorKeepsStatus(t *testing.T) { + resetTest() + now := time.Now() + tmock := transportMock{ + response: &http.Response{ + Status: http.StatusText(http.StatusNotFound), + StatusCode: http.StatusNotFound, + Header: http.Header{ + "Date": []string{now.Format(time.RFC1123)}, + "Cache-Control": []string{"no-cache"}, + }, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), + }, + err: nil, + } + tp := NewMemoryCacheTransport() + tp.Transport = &tmock + + // First time, response is cached on success + r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) + r.Header.Set("Cache-Control", "stale-if-error") + resp, err := tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + // On failure, response is returned from the cache + tmock.response = nil + tmock.err = errors.New("some error") + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("Status wasn't 404: %d", resp.StatusCode) + } +} + +// Test that http.Client.Timeout is respected when cache transport is used. +// That is so as long as request cancellation is propagated correctly. +// In the past, that required CancelRequest to be implemented correctly, +// but modern http.Client uses Request.Cancel (or request context) instead, +// so we don't have to do anything. +func TestClientTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping timeout test in short mode") // Because it takes at least 3 seconds to run. 
+ } + resetTest() + client := &http.Client{ + Transport: NewMemoryCacheTransport(), + Timeout: time.Second, + } + started := time.Now() + resp, err := client.Get(s.server.URL + "/3seconds") + taken := time.Since(started) + if err == nil { + t.Error("got nil error, want timeout error") + } + if resp != nil { + t.Error("got non-nil resp, want nil resp") + } + if taken >= 2*time.Second { + t.Error("client.Do took 2+ seconds, want < 2 seconds") + } +} diff --git a/libsq/core/ioz/httpcacheog/memcache/memcache.go b/libsq/core/ioz/httpcacheog/memcache/memcache.go new file mode 100644 index 000000000..bb9f79376 --- /dev/null +++ b/libsq/core/ioz/httpcacheog/memcache/memcache.go @@ -0,0 +1,62 @@ +//go:build !appengine + +// Package memcache provides an implementation of httpcache.Cache that uses +// gomemcache to store cached responses. +// +// When built for Google App Engine, this package will provide an +// implementation that uses App Engine's memcache service. See the +// appengine.go file in this package for details. +package memcache + +import ( + "context" + + "github.com/bradfitz/gomemcache/memcache" +) + +// Cache is an implementation of httpcache.Cache that caches responses in a +// memcache server. +type Cache struct { + *memcache.Client +} + +// cacheKey modifies an httpcache key for use in memcache. Specifically, it +// prefixes keys to avoid collision with other data stored in memcache. +func cacheKey(key string) string { + return "httpcache:" + key +} + +// Get returns the response corresponding to key if present. +func (c *Cache) Get(ctx context.Context, key string) (resp []byte, ok bool) { + item, err := c.Client.Get(cacheKey(key)) + if err != nil { + return nil, false + } + return item.Value, true +} + +// Set saves a response to the cache as key. 
+func (c *Cache) Set(ctx context.Context, key string, resp []byte) { + item := &memcache.Item{ + Key: cacheKey(key), + Value: resp, + } + _ = c.Client.Set(item) +} + +// Delete removes the response with key from the cache. +func (c *Cache) Delete(ctx context.Context, key string) { + _ = c.Client.Delete(cacheKey(key)) +} + +// New returns a new Cache using the provided memcache server(s) with equal +// weight. If a server is listed multiple times, it gets a proportional amount +// of weight. +func New(server ...string) *Cache { + return NewWithClient(memcache.New(server...)) +} + +// NewWithClient returns a new Cache with the given memcache client. +func NewWithClient(client *memcache.Client) *Cache { + return &Cache{client} +} diff --git a/libsq/core/ioz/httpcacheog/memcache/memcache_test.go b/libsq/core/ioz/httpcacheog/memcache/memcache_test.go new file mode 100644 index 000000000..596a7d141 --- /dev/null +++ b/libsq/core/ioz/httpcacheog/memcache/memcache_test.go @@ -0,0 +1,24 @@ +//go:build !appengine + +package memcache + +import ( + "net" + "testing" + + "github.com/neilotoole/sq/libsq/core/ioz/httpcacheog/test" +) + +const testServer = "localhost:11211" + +func TestMemCache(t *testing.T) { + conn, err := net.Dial("tcp", testServer) + if err != nil { + // TODO: rather than skip the test, fall back to a faked memcached server + t.Skipf("skipping test; no server running at %s", testServer) + } + _, _ = conn.Write([]byte("flush_all\r\n")) // flush memcache + conn.Close() + + test.Cache(t, New(testServer)) +} diff --git a/libsq/core/ioz/httpcacheog/test/test.go b/libsq/core/ioz/httpcacheog/test/test.go new file mode 100644 index 000000000..8c6ff1350 --- /dev/null +++ b/libsq/core/ioz/httpcacheog/test/test.go @@ -0,0 +1,36 @@ +package test + +import ( + "bytes" + "context" + "testing" + + "github.com/neilotoole/sq/libsq/core/ioz/httpcacheog" +) + +// Cache excercises a httpcache.Cache implementation. 
+func Cache(t *testing.T, cache httpcache.Cache) { + key := "testKey" + _, ok := cache.Get(context.Background(), key) + if ok { + t.Fatal("retrieved key before adding it") + } + + val := []byte("some bytes") + cache.Set(context.Background(), key, val) + + retVal, ok := cache.Get(context.Background(), key) + if !ok { + t.Fatal("could not retrieve an element we just added") + } + if !bytes.Equal(retVal, val) { + t.Fatal("retrieved a different value than what we put in") + } + + cache.Delete(context.Background(), key) + + _, ok = cache.Get(context.Background(), key) + if ok { + t.Fatal("deleted key still present") + } +} diff --git a/libsq/core/ioz/httpcacheog/test/test_test.go b/libsq/core/ioz/httpcacheog/test/test_test.go new file mode 100644 index 000000000..cc49e572e --- /dev/null +++ b/libsq/core/ioz/httpcacheog/test/test_test.go @@ -0,0 +1,12 @@ +package test_test + +import ( + "testing" + + "github.com/neilotoole/sq/libsq/core/ioz/httpcacheog" + "github.com/neilotoole/sq/libsq/core/ioz/httpcacheog/test" +) + +func TestMemoryCache(t *testing.T) { + test.Cache(t, httpcache.NewMemoryCache()) +} From 07e1e616ec2f7174149694b9ae266142a02eadce Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 12 Dec 2023 05:28:40 -0700 Subject: [PATCH 086/195] dl: configured for refactor --- libsq/core/ioz/httpcache/memcache/memcache.go | 62 ------------------- .../ioz/httpcache/memcache/memcache_test.go | 24 ------- .../core/ioz/httpcacheog/memcache/memcache.go | 62 ------------------- .../ioz/httpcacheog/memcache/memcache_test.go | 24 ------- 4 files changed, 172 deletions(-) delete mode 100644 libsq/core/ioz/httpcache/memcache/memcache.go delete mode 100644 libsq/core/ioz/httpcache/memcache/memcache_test.go delete mode 100644 libsq/core/ioz/httpcacheog/memcache/memcache.go delete mode 100644 libsq/core/ioz/httpcacheog/memcache/memcache_test.go diff --git a/libsq/core/ioz/httpcache/memcache/memcache.go b/libsq/core/ioz/httpcache/memcache/memcache.go deleted file mode 100644 
index bb9f79376..000000000 --- a/libsq/core/ioz/httpcache/memcache/memcache.go +++ /dev/null @@ -1,62 +0,0 @@ -//go:build !appengine - -// Package memcache provides an implementation of httpcache.Cache that uses -// gomemcache to store cached responses. -// -// When built for Google App Engine, this package will provide an -// implementation that uses App Engine's memcache service. See the -// appengine.go file in this package for details. -package memcache - -import ( - "context" - - "github.com/bradfitz/gomemcache/memcache" -) - -// Cache is an implementation of httpcache.Cache that caches responses in a -// memcache server. -type Cache struct { - *memcache.Client -} - -// cacheKey modifies an httpcache key for use in memcache. Specifically, it -// prefixes keys to avoid collision with other data stored in memcache. -func cacheKey(key string) string { - return "httpcache:" + key -} - -// Get returns the response corresponding to key if present. -func (c *Cache) Get(ctx context.Context, key string) (resp []byte, ok bool) { - item, err := c.Client.Get(cacheKey(key)) - if err != nil { - return nil, false - } - return item.Value, true -} - -// Set saves a response to the cache as key. -func (c *Cache) Set(ctx context.Context, key string, resp []byte) { - item := &memcache.Item{ - Key: cacheKey(key), - Value: resp, - } - _ = c.Client.Set(item) -} - -// Delete removes the response with key from the cache. -func (c *Cache) Delete(ctx context.Context, key string) { - _ = c.Client.Delete(cacheKey(key)) -} - -// New returns a new Cache using the provided memcache server(s) with equal -// weight. If a server is listed multiple times, it gets a proportional amount -// of weight. -func New(server ...string) *Cache { - return NewWithClient(memcache.New(server...)) -} - -// NewWithClient returns a new Cache with the given memcache client. 
-func NewWithClient(client *memcache.Client) *Cache { - return &Cache{client} -} diff --git a/libsq/core/ioz/httpcache/memcache/memcache_test.go b/libsq/core/ioz/httpcache/memcache/memcache_test.go deleted file mode 100644 index b33ae8e25..000000000 --- a/libsq/core/ioz/httpcache/memcache/memcache_test.go +++ /dev/null @@ -1,24 +0,0 @@ -//go:build !appengine - -package memcache - -import ( - "net" - "testing" - - "github.com/neilotoole/sq/libsq/core/ioz/httpcache/test" -) - -const testServer = "localhost:11211" - -func TestMemCache(t *testing.T) { - conn, err := net.Dial("tcp", testServer) - if err != nil { - // TODO: rather than skip the test, fall back to a faked memcached server - t.Skipf("skipping test; no server running at %s", testServer) - } - _, _ = conn.Write([]byte("flush_all\r\n")) // flush memcache - conn.Close() - - test.Cache(t, New(testServer)) -} diff --git a/libsq/core/ioz/httpcacheog/memcache/memcache.go b/libsq/core/ioz/httpcacheog/memcache/memcache.go deleted file mode 100644 index bb9f79376..000000000 --- a/libsq/core/ioz/httpcacheog/memcache/memcache.go +++ /dev/null @@ -1,62 +0,0 @@ -//go:build !appengine - -// Package memcache provides an implementation of httpcache.Cache that uses -// gomemcache to store cached responses. -// -// When built for Google App Engine, this package will provide an -// implementation that uses App Engine's memcache service. See the -// appengine.go file in this package for details. -package memcache - -import ( - "context" - - "github.com/bradfitz/gomemcache/memcache" -) - -// Cache is an implementation of httpcache.Cache that caches responses in a -// memcache server. -type Cache struct { - *memcache.Client -} - -// cacheKey modifies an httpcache key for use in memcache. Specifically, it -// prefixes keys to avoid collision with other data stored in memcache. -func cacheKey(key string) string { - return "httpcache:" + key -} - -// Get returns the response corresponding to key if present. 
-func (c *Cache) Get(ctx context.Context, key string) (resp []byte, ok bool) { - item, err := c.Client.Get(cacheKey(key)) - if err != nil { - return nil, false - } - return item.Value, true -} - -// Set saves a response to the cache as key. -func (c *Cache) Set(ctx context.Context, key string, resp []byte) { - item := &memcache.Item{ - Key: cacheKey(key), - Value: resp, - } - _ = c.Client.Set(item) -} - -// Delete removes the response with key from the cache. -func (c *Cache) Delete(ctx context.Context, key string) { - _ = c.Client.Delete(cacheKey(key)) -} - -// New returns a new Cache using the provided memcache server(s) with equal -// weight. If a server is listed multiple times, it gets a proportional amount -// of weight. -func New(server ...string) *Cache { - return NewWithClient(memcache.New(server...)) -} - -// NewWithClient returns a new Cache with the given memcache client. -func NewWithClient(client *memcache.Client) *Cache { - return &Cache{client} -} diff --git a/libsq/core/ioz/httpcacheog/memcache/memcache_test.go b/libsq/core/ioz/httpcacheog/memcache/memcache_test.go deleted file mode 100644 index 596a7d141..000000000 --- a/libsq/core/ioz/httpcacheog/memcache/memcache_test.go +++ /dev/null @@ -1,24 +0,0 @@ -//go:build !appengine - -package memcache - -import ( - "net" - "testing" - - "github.com/neilotoole/sq/libsq/core/ioz/httpcacheog/test" -) - -const testServer = "localhost:11211" - -func TestMemCache(t *testing.T) { - conn, err := net.Dial("tcp", testServer) - if err != nil { - // TODO: rather than skip the test, fall back to a faked memcached server - t.Skipf("skipping test; no server running at %s", testServer) - } - _, _ = conn.Write([]byte("flush_all\r\n")) // flush memcache - conn.Close() - - test.Cache(t, New(testServer)) -} From 8f9a0944dd9a85d2236d501fb2075e051806222c Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 12 Dec 2023 05:29:41 -0700 Subject: [PATCH 087/195] dl: configured for refactor --- 
libsq/core/ioz/httpcacheog/LICENSE.txt | 7 - libsq/core/ioz/httpcacheog/README.md | 42 - .../ioz/httpcacheog/diskcache/diskcache.go | 63 - .../httpcacheog/diskcache/diskcache_test.go | 18 - libsq/core/ioz/httpcacheog/httpcache.go | 584 ------- libsq/core/ioz/httpcacheog/httpcache_test.go | 1475 ----------------- libsq/core/ioz/httpcacheog/test/test.go | 36 - libsq/core/ioz/httpcacheog/test/test_test.go | 12 - 8 files changed, 2237 deletions(-) delete mode 100644 libsq/core/ioz/httpcacheog/LICENSE.txt delete mode 100644 libsq/core/ioz/httpcacheog/README.md delete mode 100644 libsq/core/ioz/httpcacheog/diskcache/diskcache.go delete mode 100644 libsq/core/ioz/httpcacheog/diskcache/diskcache_test.go delete mode 100644 libsq/core/ioz/httpcacheog/httpcache.go delete mode 100644 libsq/core/ioz/httpcacheog/httpcache_test.go delete mode 100644 libsq/core/ioz/httpcacheog/test/test.go delete mode 100644 libsq/core/ioz/httpcacheog/test/test_test.go diff --git a/libsq/core/ioz/httpcacheog/LICENSE.txt b/libsq/core/ioz/httpcacheog/LICENSE.txt deleted file mode 100644 index 81316beb0..000000000 --- a/libsq/core/ioz/httpcacheog/LICENSE.txt +++ /dev/null @@ -1,7 +0,0 @@ -Copyright © 2012 Greg Jones (greg.jones@gmail.com) - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/libsq/core/ioz/httpcacheog/README.md b/libsq/core/ioz/httpcacheog/README.md deleted file mode 100644 index 58cda222f..000000000 --- a/libsq/core/ioz/httpcacheog/README.md +++ /dev/null @@ -1,42 +0,0 @@ -# httpcache - -[![GoDoc](https://godoc.org/github.com/bitcomplete/httpcache?status.svg)](https://godoc.org/github.com/bitcomplete/httpcache) - -Package httpcache provides a http.RoundTripper implementation that works as a -mostly [RFC 7234](https://tools.ietf.org/html/rfc7234) compliant cache for http -responses. This incarnation of the library is an active fork of -[github.com/gregjones/httpcache](https://github.com/gregjones/httpcache) which -is unmaintained. - -It is only suitable for use as a 'private' cache (i.e. for a web-browser or an -API-client and not for a shared proxy). - -## Cache Backends - -- The built-in 'memory' cache stores responses in an in-memory map. - - [`github.com/bitcomplete/httpcache/diskcache`](https://github.com/bitcomplete/httpcache/tree/master/diskcache) - provides a filesystem-backed cache using the - [diskv](https://github.com/peterbourgon/diskv) library. - - [`github.com/bitcomplete/httpcache/memcache`](https://github.com/bitcomplete/httpcache/tree/master/memcache) - provides memcache implementations, for both App Engine and 'normal' memcache - servers. - - [`sourcegraph.com/sourcegraph/s3cache`](https://sourcegraph.com/github.com/sourcegraph/s3cache) - uses Amazon S3 for storage. - - [`github.com/bitcomplete/httpcache/leveldbcache`](https://github.com/bitcomplete/httpcache/tree/master/leveldbcache) - provides a filesystem-backed cache using - [leveldb](https://github.com/syndtr/goleveldb/leveldb). 
- - [`github.com/die-net/lrucache`](https://github.com/die-net/lrucache) provides an - in-memory cache that will evict least-recently used entries. - - [`github.com/die-net/lrucache/twotier`](https://github.com/die-net/lrucache/tree/master/twotier) - allows caches to be combined, for example to use lrucache above with a - persistent disk-cache. - - [`github.com/birkelund/boltdbcache`](https://github.com/birkelund/boltdbcache) - provides a BoltDB implementation (based on the - [bbolt](https://github.com/coreos/bbolt) fork). - -If you implement any other backend and wish it to be linked here, please send a -PR editing this file. - -## License - -- [MIT License](LICENSE.txt) diff --git a/libsq/core/ioz/httpcacheog/diskcache/diskcache.go b/libsq/core/ioz/httpcacheog/diskcache/diskcache.go deleted file mode 100644 index 4dd96e128..000000000 --- a/libsq/core/ioz/httpcacheog/diskcache/diskcache.go +++ /dev/null @@ -1,63 +0,0 @@ -// Package diskcache provides an implementation of httpcache.Cache that uses the diskv package -// to supplement an in-memory map with persistent storage -// -package diskcache - -import ( - "bytes" - "context" - "crypto/md5" - "encoding/hex" - "io" - - "github.com/peterbourgon/diskv" -) - -// Cache is an implementation of httpcache.Cache that supplements the in-memory map with persistent storage -type Cache struct { - d *diskv.Diskv -} - -// Get returns the response corresponding to key if present -func (c *Cache) Get(ctx context.Context, key string) (resp []byte, ok bool) { - key = keyToFilename(key) - resp, err := c.d.Read(key) - if err != nil { - return []byte{}, false - } - return resp, true -} - -// Set saves a response to the cache as key -func (c *Cache) Set(ctx context.Context, key string, resp []byte) { - key = keyToFilename(key) - _ = c.d.WriteStream(key, bytes.NewReader(resp), true) -} - -// Delete removes the response with key from the cache -func (c *Cache) Delete(ctx context.Context, key string) { - key = 
keyToFilename(key) - _ = c.d.Erase(key) -} - -func keyToFilename(key string) string { - h := md5.New() - _, _ = io.WriteString(h, key) - return hex.EncodeToString(h.Sum(nil)) -} - -// New returns a new Cache that will store files in basePath -func New(basePath string) *Cache { - return &Cache{ - d: diskv.New(diskv.Options{ - BasePath: basePath, - CacheSizeMax: 100 * 1024 * 1024, // 100MB - }), - } -} - -// NewWithDiskv returns a new Cache using the provided Diskv as underlying -// storage. -func NewWithDiskv(d *diskv.Diskv) *Cache { - return &Cache{d} -} diff --git a/libsq/core/ioz/httpcacheog/diskcache/diskcache_test.go b/libsq/core/ioz/httpcacheog/diskcache/diskcache_test.go deleted file mode 100644 index 8bf743697..000000000 --- a/libsq/core/ioz/httpcacheog/diskcache/diskcache_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package diskcache - -import ( - "github.com/neilotoole/sq/libsq/core/ioz/httpcacheog/test" - "io/ioutil" - "os" - "testing" -) - -func TestDiskCache(t *testing.T) { - tempDir, err := ioutil.TempDir("", "httpcache") - if err != nil { - t.Fatalf("TempDir: %v", err) - } - defer os.RemoveAll(tempDir) - - test.Cache(t, New(tempDir)) -} diff --git a/libsq/core/ioz/httpcacheog/httpcache.go b/libsq/core/ioz/httpcacheog/httpcache.go deleted file mode 100644 index 51d283740..000000000 --- a/libsq/core/ioz/httpcacheog/httpcache.go +++ /dev/null @@ -1,584 +0,0 @@ -// Package httpcache provides a http.RoundTripper implementation that works as a -// mostly RFC-compliant cache for http responses. -// -// It is only suitable for use as a 'private' cache (i.e. for a web-browser or an API-client -// and not for a shared proxy). 
-// -package httpcache - -import ( - "bufio" - "bytes" - "context" - "errors" - "io" - "io/ioutil" - "net/http" - "net/http/httputil" - "strings" - "sync" - "time" -) - -const ( - stale = iota - fresh - transparent - // XFromCache is the header added to responses that are returned from the cache - XFromCache = "X-From-Cache" -) - -// A Cache interface is used by the Transport to store and retrieve responses. -type Cache interface { - // Get returns the []byte representation of a cached response and a bool - // set to true if the value isn't empty - Get(ctx context.Context, key string) (responseBytes []byte, ok bool) - // Set stores the []byte representation of a response against a key - Set(ctx context.Context, key string, responseBytes []byte) - // Delete removes the value associated with the key - Delete(ctx context.Context, key string) -} - -type KeyFunc func(req *http.Request) string - -// DefaultKeyFunc returns the cache key for req -var DefaultKeyFunc = func(req *http.Request) string { - if req.Method == http.MethodGet { - return req.URL.String() - } else { - return req.Method + " " + req.URL.String() - } -} - -// CachedResponse returns the cached http.Response for req if present, and nil -// otherwise. -func CachedResponse(ctx context.Context, c Cache, key string, req *http.Request) (resp *http.Response, err error) { - cachedVal, ok := c.Get(ctx, key) - if !ok { - return - } - - b := bytes.NewBuffer(cachedVal) - return http.ReadResponse(bufio.NewReader(b), req) -} - -// MemoryCache is an implemtation of Cache that stores responses in an in-memory map. 
-type MemoryCache struct { - mu sync.RWMutex - items map[string][]byte -} - -// Get returns the []byte representation of the response and true if present, false if not -func (c *MemoryCache) Get(ctx context.Context, key string) (resp []byte, ok bool) { - c.mu.RLock() - resp, ok = c.items[key] - c.mu.RUnlock() - return resp, ok -} - -// Set saves response resp to the cache with key -func (c *MemoryCache) Set(ctx context.Context, key string, resp []byte) { - c.mu.Lock() - c.items[key] = resp - c.mu.Unlock() -} - -// Delete removes key from the cache -func (c *MemoryCache) Delete(ctx context.Context, key string) { - c.mu.Lock() - delete(c.items, key) - c.mu.Unlock() -} - -// NewMemoryCache returns a new Cache that will store items in an in-memory map -func NewMemoryCache() *MemoryCache { - c := &MemoryCache{items: map[string][]byte{}} - return c -} - -// TransportOpt is a configuration option for creating a new Transport -type TransportOpt func(t *Transport) - -// MarkCachedResponsesOpt configures a transport by setting MarkCachedResponses to true -func MarkCachedResponsesOpt(markCachedResponses bool) TransportOpt { - return func(t *Transport) { - t.MarkCachedResponses = markCachedResponses - } -} - -// KeyFuncOpt configures a transport by setting its KeyFunc to the one given -func KeyFuncOpt(keyFunc KeyFunc) TransportOpt { - return func(t *Transport) { - t.KeyFunc = keyFunc - } -} - -// Transport is an implementation of http.RoundTripper that will return values from a cache -// where possible (avoiding a network request) and will additionally add validators (etag/if-modified-since) -// to repeated requests allowing servers to return 304 / Not Modified -type Transport struct { - // The RoundTripper interface actually used to make requests - // If nil, http.DefaultTransport is used - Transport http.RoundTripper - Cache Cache - // If true, responses returned from the cache will be given an extra header, X-From-Cache - MarkCachedResponses bool - // A function to generate 
a cache key for the given request - KeyFunc KeyFunc -} - -// NewTransport returns a new Transport with the provided Cache and options. If -// KeyFunc is not specified in opts then DefaultKeyFunc is used. -func NewTransport(c Cache, opts ...TransportOpt) *Transport { - t := &Transport{ - Cache: c, - KeyFunc: DefaultKeyFunc, - MarkCachedResponses: true, - } - for _, opt := range opts { - opt(t) - } - return t -} - -// Client returns an *http.Client that caches responses. -func (t *Transport) Client() *http.Client { - return &http.Client{Transport: t} -} - -// varyMatches will return false unless all of the cached values for the headers listed in Vary -// match the new request -func varyMatches(cachedResp *http.Response, req *http.Request) bool { - for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") { - header = http.CanonicalHeaderKey(header) - if header != "" && req.Header.Get(header) != cachedResp.Header.Get("X-Varied-"+header) { - return false - } - } - return true -} - -// RoundTrip takes a Request and returns a Response -// -// If there is a fresh Response already in cache, then it will be returned without connecting to -// the server. -// -// If there is a stale Response, then any validators it contains will be set on the new request -// to give the server a chance to respond with NotModified. If this happens, then the cached Response -// will be returned. 
-func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { - cacheKey := t.KeyFunc(req) - cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" - var cachedResp *http.Response - if cacheable { - cachedResp, err = CachedResponse(req.Context(), t.Cache, cacheKey, req) - } else { - // Need to invalidate an existing value - t.Cache.Delete(req.Context(), cacheKey) - } - - transport := t.Transport - if transport == nil { - transport = http.DefaultTransport - } - - if cacheable && cachedResp != nil && err == nil { - if t.MarkCachedResponses { - cachedResp.Header.Set(XFromCache, "1") - } - - if varyMatches(cachedResp, req) { - // Can only use cached value if the new request doesn't Vary significantly - freshness := getFreshness(cachedResp.Header, req.Header) - if freshness == fresh { - return cachedResp, nil - } - - if freshness == stale { - var req2 *http.Request - // Add validators if caller hasn't already done so - etag := cachedResp.Header.Get("etag") - if etag != "" && req.Header.Get("etag") == "" { - req2 = cloneRequest(req) - req2.Header.Set("if-none-match", etag) - } - lastModified := cachedResp.Header.Get("last-modified") - if lastModified != "" && req.Header.Get("last-modified") == "" { - if req2 == nil { - req2 = cloneRequest(req) - } - req2.Header.Set("if-modified-since", lastModified) - } - if req2 != nil { - req = req2 - } - } - } - - resp, err = transport.RoundTrip(req) - if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified { - // Replace the 304 response with the one from cache, but update with some new headers - endToEndHeaders := getEndToEndHeaders(resp.Header) - for _, header := range endToEndHeaders { - cachedResp.Header[header] = resp.Header[header] - } - resp = cachedResp - } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) && - req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) { - // In case of transport failure and 
stale-if-error activated, returns cached content - // when available - return cachedResp, nil - } else { - if err != nil || resp.StatusCode != http.StatusOK { - t.Cache.Delete(req.Context(), cacheKey) - } - if err != nil { - return nil, err - } - } - } else { - reqCacheControl := parseCacheControl(req.Header) - if _, ok := reqCacheControl["only-if-cached"]; ok { - resp = newGatewayTimeoutResponse(req) - } else { - resp, err = transport.RoundTrip(req) - if err != nil { - return nil, err - } - } - } - - if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) { - for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") { - varyKey = http.CanonicalHeaderKey(varyKey) - fakeHeader := "X-Varied-" + varyKey - reqValue := req.Header.Get(varyKey) - if reqValue != "" { - resp.Header.Set(fakeHeader, reqValue) - } - } - switch req.Method { - case "GET": - // Delay caching until EOF is reached. - resp.Body = &cachingReadCloser{ - R: resp.Body, - OnEOF: func(r io.Reader) { - resp := *resp - resp.Body = ioutil.NopCloser(r) - respBytes, err := httputil.DumpResponse(&resp, true) - if err == nil { - t.Cache.Set(req.Context(), cacheKey, respBytes) - } - }, - } - default: - respBytes, err := httputil.DumpResponse(resp, true) - if err == nil { - t.Cache.Set(req.Context(), cacheKey, respBytes) - } - } - } else { - t.Cache.Delete(req.Context(), cacheKey) - } - return resp, nil -} - -// ErrNoDateHeader indicates that the HTTP headers contained no Date header. -var ErrNoDateHeader = errors.New("no Date header") - -// Date parses and returns the value of the Date header. 
-func Date(respHeaders http.Header) (date time.Time, err error) { - dateHeader := respHeaders.Get("date") - if dateHeader == "" { - err = ErrNoDateHeader - return - } - - return time.Parse(time.RFC1123, dateHeader) -} - -type realClock struct{} - -func (c *realClock) since(d time.Time) time.Duration { - return time.Since(d) -} - -type timer interface { - since(d time.Time) time.Duration -} - -var clock timer = &realClock{} - -// getFreshness will return one of fresh/stale/transparent based on the cache-control -// values of the request and the response -// -// fresh indicates the response can be returned -// stale indicates that the response needs validating before it is returned -// transparent indicates the response should not be used to fulfil the request -// -// Because this is only a private cache, 'public' and 'private' in cache-control aren't -// signficant. Similarly, smax-age isn't used. -func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) { - respCacheControl := parseCacheControl(respHeaders) - reqCacheControl := parseCacheControl(reqHeaders) - if _, ok := reqCacheControl["no-cache"]; ok { - return transparent - } - if _, ok := respCacheControl["no-cache"]; ok { - return stale - } - if _, ok := reqCacheControl["only-if-cached"]; ok { - return fresh - } - - date, err := Date(respHeaders) - if err != nil { - return stale - } - currentAge := clock.since(date) - - var lifetime time.Duration - var zeroDuration time.Duration - - // If a response includes both an Expires header and a max-age directive, - // the max-age directive overrides the Expires header, even if the Expires header is more restrictive. 
- if maxAge, ok := respCacheControl["max-age"]; ok { - lifetime, err = time.ParseDuration(maxAge + "s") - if err != nil { - lifetime = zeroDuration - } - } else { - expiresHeader := respHeaders.Get("Expires") - if expiresHeader != "" { - expires, err := time.Parse(time.RFC1123, expiresHeader) - if err != nil { - lifetime = zeroDuration - } else { - lifetime = expires.Sub(date) - } - } - } - - if maxAge, ok := reqCacheControl["max-age"]; ok { - // the client is willing to accept a response whose age is no greater than the specified time in seconds - lifetime, err = time.ParseDuration(maxAge + "s") - if err != nil { - lifetime = zeroDuration - } - } - if minfresh, ok := reqCacheControl["min-fresh"]; ok { - // the client wants a response that will still be fresh for at least the specified number of seconds. - minfreshDuration, err := time.ParseDuration(minfresh + "s") - if err == nil { - currentAge = time.Duration(currentAge + minfreshDuration) - } - } - - if maxstale, ok := reqCacheControl["max-stale"]; ok { - // Indicates that the client is willing to accept a response that has exceeded its expiration time. - // If max-stale is assigned a value, then the client is willing to accept a response that has exceeded - // its expiration time by no more than the specified number of seconds. - // If no value is assigned to max-stale, then the client is willing to accept a stale response of any age. - // - // Responses served only because of a max-stale value are supposed to have a Warning header added to them, - // but that seems like a hassle, and is it actually useful? If so, then there needs to be a different - // return-value available here. 
- if maxstale == "" { - return fresh - } - maxstaleDuration, err := time.ParseDuration(maxstale + "s") - if err == nil { - currentAge = time.Duration(currentAge - maxstaleDuration) - } - } - - if lifetime > currentAge { - return fresh - } - - return stale -} - -// Returns true if either the request or the response includes the stale-if-error -// cache control extension: https://tools.ietf.org/html/rfc5861 -func canStaleOnError(respHeaders, reqHeaders http.Header) bool { - respCacheControl := parseCacheControl(respHeaders) - reqCacheControl := parseCacheControl(reqHeaders) - - var err error - lifetime := time.Duration(-1) - - if staleMaxAge, ok := respCacheControl["stale-if-error"]; ok { - if staleMaxAge != "" { - lifetime, err = time.ParseDuration(staleMaxAge + "s") - if err != nil { - return false - } - } else { - return true - } - } - if staleMaxAge, ok := reqCacheControl["stale-if-error"]; ok { - if staleMaxAge != "" { - lifetime, err = time.ParseDuration(staleMaxAge + "s") - if err != nil { - return false - } - } else { - return true - } - } - - if lifetime >= 0 { - date, err := Date(respHeaders) - if err != nil { - return false - } - currentAge := clock.since(date) - if lifetime > currentAge { - return true - } - } - - return false -} - -func getEndToEndHeaders(respHeaders http.Header) []string { - // These headers are always hop-by-hop - hopByHopHeaders := map[string]struct{}{ - "Connection": {}, - "Keep-Alive": {}, - "Proxy-Authenticate": {}, - "Proxy-Authorization": {}, - "Te": {}, - "Trailers": {}, - "Transfer-Encoding": {}, - "Upgrade": {}, - } - - for _, extra := range strings.Split(respHeaders.Get("connection"), ",") { - // any header listed in connection, if present, is also considered hop-by-hop - if strings.Trim(extra, " ") != "" { - hopByHopHeaders[http.CanonicalHeaderKey(extra)] = struct{}{} - } - } - endToEndHeaders := []string{} - for respHeader := range respHeaders { - if _, ok := hopByHopHeaders[respHeader]; !ok { - endToEndHeaders = 
append(endToEndHeaders, respHeader) - } - } - return endToEndHeaders -} - -func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) { - if _, ok := respCacheControl["no-store"]; ok { - return false - } - if _, ok := reqCacheControl["no-store"]; ok { - return false - } - return true -} - -func newGatewayTimeoutResponse(req *http.Request) *http.Response { - var braw bytes.Buffer - braw.WriteString("HTTP/1.1 504 Gateway Timeout\r\n\r\n") - resp, err := http.ReadResponse(bufio.NewReader(&braw), req) - if err != nil { - panic(err) - } - return resp -} - -// cloneRequest returns a clone of the provided *http.Request. -// The clone is a shallow copy of the struct and its Header map. -// (This function copyright goauth2 authors: https://code.google.com/p/goauth2) -func cloneRequest(r *http.Request) *http.Request { - // shallow copy of the struct - r2 := new(http.Request) - *r2 = *r - if ctx := r.Context(); ctx != nil { - r2 = r2.WithContext(ctx) - } - // deep copy of the Header - r2.Header = make(http.Header) - for k, s := range r.Header { - r2.Header[k] = s - } - return r2 -} - -type cacheControl map[string]string - -func parseCacheControl(headers http.Header) cacheControl { - cc := cacheControl{} - ccHeader := headers.Get("Cache-Control") - for _, part := range strings.Split(ccHeader, ",") { - part = strings.Trim(part, " ") - if part == "" { - continue - } - if strings.ContainsRune(part, '=') { - keyval := strings.Split(part, "=") - cc[strings.Trim(keyval[0], " ")] = strings.Trim(keyval[1], ",") - } else { - cc[part] = "" - } - } - return cc -} - -// headerAllCommaSepValues returns all comma-separated values (each -// with whitespace trimmed) for header name in headers. According to -// Section 4.2 of the HTTP/1.1 spec -// (http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2), -// values from multiple occurrences of a header should be concatenated, if -// the header's value is a comma-separated list. 
-func headerAllCommaSepValues(headers http.Header, name string) []string { - var vals []string - for _, val := range headers[http.CanonicalHeaderKey(name)] { - fields := strings.Split(val, ",") - for i, f := range fields { - fields[i] = strings.TrimSpace(f) - } - vals = append(vals, fields...) - } - return vals -} - -// cachingReadCloser is a wrapper around ReadCloser R that calls OnEOF -// handler with a full copy of the content read from R when EOF is -// reached. -type cachingReadCloser struct { - // Underlying ReadCloser. - R io.ReadCloser - // OnEOF is called with a copy of the content of R when EOF is reached. - OnEOF func(io.Reader) - - buf bytes.Buffer // buf stores a copy of the content of R. -} - -// Read reads the next len(p) bytes from R or until R is drained. The -// return value n is the number of bytes read. If R has no data to -// return, err is io.EOF and OnEOF is called with a full copy of what -// has been read so far. -func (r *cachingReadCloser) Read(p []byte) (n int, err error) { - n, err = r.R.Read(p) - r.buf.Write(p[:n]) - if err == io.EOF { - r.OnEOF(bytes.NewReader(r.buf.Bytes())) - } - return n, err -} - -func (r *cachingReadCloser) Close() error { - return r.R.Close() -} - -// NewMemoryCacheTransport returns a new Transport using the in-memory cache implementation -func NewMemoryCacheTransport(opts ...TransportOpt) *Transport { - c := NewMemoryCache() - t := NewTransport(c, opts...) - return t -} diff --git a/libsq/core/ioz/httpcacheog/httpcache_test.go b/libsq/core/ioz/httpcacheog/httpcache_test.go deleted file mode 100644 index 9fb3b8f47..000000000 --- a/libsq/core/ioz/httpcacheog/httpcache_test.go +++ /dev/null @@ -1,1475 +0,0 @@ -package httpcache - -import ( - "bytes" - "errors" - "flag" - "io" - "io/ioutil" - "net/http" - "net/http/httptest" - "os" - "strconv" - "testing" - "time" -) - -var s struct { - server *httptest.Server - client http.Client - transport *Transport - done chan struct{} // Closed to unlock infinite handlers. 
-} - -type fakeClock struct { - elapsed time.Duration -} - -func (c *fakeClock) since(t time.Time) time.Duration { - return c.elapsed -} - -func TestMain(m *testing.M) { - flag.Parse() - setup() - code := m.Run() - teardown() - os.Exit(code) -} - -func setup() { - tp := NewMemoryCacheTransport() - client := http.Client{Transport: tp} - s.transport = tp - s.client = client - s.done = make(chan struct{}) - - mux := http.NewServeMux() - s.server = httptest.NewServer(mux) - - mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - })) - - mux.HandleFunc("/method", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - _, _ = w.Write([]byte(r.Method)) - })) - - mux.HandleFunc("/range", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - lm := "Fri, 14 Dec 2010 01:01:50 GMT" - if r.Header.Get("if-modified-since") == lm { - w.WriteHeader(http.StatusNotModified) - return - } - w.Header().Set("last-modified", lm) - if r.Header.Get("range") == "bytes=4-9" { - w.WriteHeader(http.StatusPartialContent) - _, _ = w.Write([]byte(" text ")) - return - } - _, _ = w.Write([]byte("Some text content")) - })) - - mux.HandleFunc("/nostore", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "no-store") - })) - - mux.HandleFunc("/etag", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - etag := "124567" - if r.Header.Get("if-none-match") == etag { - w.WriteHeader(http.StatusNotModified) - return - } - w.Header().Set("etag", etag) - })) - - mux.HandleFunc("/lastmodified", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - lm := "Fri, 14 Dec 2010 01:01:50 GMT" - if r.Header.Get("if-modified-since") == lm { - w.WriteHeader(http.StatusNotModified) - return - } - w.Header().Set("last-modified", lm) - })) - - mux.HandleFunc("/varyaccept", http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - w.Header().Set("Content-Type", "text/plain") - w.Header().Set("Vary", "Accept") - _, _ = w.Write([]byte("Some text content")) - })) - - mux.HandleFunc("/doublevary", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - w.Header().Set("Content-Type", "text/plain") - w.Header().Set("Vary", "Accept, Accept-Language") - _, _ = w.Write([]byte("Some text content")) - })) - mux.HandleFunc("/2varyheaders", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - w.Header().Set("Content-Type", "text/plain") - w.Header().Add("Vary", "Accept") - w.Header().Add("Vary", "Accept-Language") - _, _ = w.Write([]byte("Some text content")) - })) - mux.HandleFunc("/varyunused", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - w.Header().Set("Content-Type", "text/plain") - w.Header().Set("Vary", "X-Madeup-Header") - _, _ = w.Write([]byte("Some text content")) - })) - - mux.HandleFunc("/cachederror", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - etag := "abc" - if r.Header.Get("if-none-match") == etag { - w.WriteHeader(http.StatusNotModified) - return - } - w.Header().Set("etag", etag) - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte("Not found")) - })) - - updateFieldsCounter := 0 - mux.HandleFunc("/updatefields", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Counter", strconv.Itoa(updateFieldsCounter)) - w.Header().Set("Etag", `"e"`) - updateFieldsCounter++ - if r.Header.Get("if-none-match") != "" { - w.WriteHeader(http.StatusNotModified) - return - } - _, _ = w.Write([]byte("Some text content")) - })) - - // Take 3 seconds to return 200 OK (for testing client timeouts). 
- mux.HandleFunc("/3seconds", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(3 * time.Second) - })) - - mux.HandleFunc("/infinite", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - for { - select { - case <-s.done: - return - default: - _, _ = w.Write([]byte{0}) - } - } - })) -} - -func teardown() { - close(s.done) - s.server.Close() -} - -func resetTest() { - s.transport.Cache = NewMemoryCache() - clock = &realClock{} -} - -// TestCacheableMethod ensures that uncacheable method does not get stored -// in cache and get incorrectly used for a following cacheable method request. -func TestCacheableMethod(t *testing.T) { - resetTest() - { - req, err := http.NewRequest("POST", s.server.URL+"/method", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), "POST"; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) - } - } - { - req, err := http.NewRequest("GET", s.server.URL+"/method", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), "GET"; got != want { - t.Errorf("got wrong body %q, want %q", got, want) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) - } - if resp.Header.Get(XFromCache) != "" { - t.Errorf("XFromCache header isn't blank") - } - } -} - -func TestDontServeHeadResponseToGetRequest(t *testing.T) { - resetTest() - url := s.server.URL + "/" - 
req, err := http.NewRequest(http.MethodHead, url, nil) - if err != nil { - t.Fatal(err) - } - _, err = s.client.Do(req) - if err != nil { - t.Fatal(err) - } - req, err = http.NewRequest(http.MethodGet, url, nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - if resp.Header.Get(XFromCache) != "" { - t.Errorf("Cache should not match") - } -} - -func TestDontStorePartialRangeInCache(t *testing.T) { - resetTest() - { - req, err := http.NewRequest("GET", s.server.URL+"/range", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("range", "bytes=4-9") - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), " text "; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusPartialContent { - t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) - } - } - { - req, err := http.NewRequest("GET", s.server.URL+"/range", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), "Some text content"; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) - } - if resp.Header.Get(XFromCache) != "" { - t.Error("XFromCache header isn't blank") - } - } - { - req, err := http.NewRequest("GET", s.server.URL+"/range", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - 
t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), "Some text content"; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) - } - if resp.Header.Get(XFromCache) != "1" { - t.Errorf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } - { - req, err := http.NewRequest("GET", s.server.URL+"/range", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("range", "bytes=4-9") - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), " text "; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusPartialContent { - t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) - } - } -} - -func TestCacheOnlyIfBodyRead(t *testing.T) { - resetTest() - { - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - // We do not read the body - resp.Body.Close() - } - { - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatalf("XFromCache header isn't blank") - } - } -} - -func TestOnlyReadBodyOnDemand(t *testing.T) { - resetTest() - - req, err := http.NewRequest("GET", s.server.URL+"/infinite", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) // This shouldn't hang forever. 
- if err != nil { - t.Fatal(err) - } - buf := make([]byte, 10) // Only partially read the body. - _, err = resp.Body.Read(buf) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() -} - -func TestGetOnlyIfCachedHit(t *testing.T) { - resetTest() - { - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("cache-control", "only-if-cached") - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - if resp.StatusCode != http.StatusOK { - t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) - } - } -} - -func TestGetOnlyIfCachedMiss(t *testing.T) { - resetTest() - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("cache-control", "only-if-cached") - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - if resp.StatusCode != http.StatusGatewayTimeout { - t.Fatalf("response status code isn't 504 GatewayTimeout: %v", resp.StatusCode) - } -} - -func TestGetNoStoreRequest(t *testing.T) { - resetTest() - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("Cache-Control", "no-store") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't 
blank") - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } -} - -func TestGetNoStoreResponse(t *testing.T) { - resetTest() - req, err := http.NewRequest("GET", s.server.URL+"/nostore", nil) - if err != nil { - t.Fatal(err) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } -} - -func TestGetWithEtag(t *testing.T) { - resetTest() - req, err := http.NewRequest("GET", s.server.URL+"/etag", nil) - if err != nil { - t.Fatal(err) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - // additional assertions to verify that 304 response is converted properly - if resp.StatusCode != http.StatusOK { - t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) - } - if _, ok := resp.Header["Connection"]; ok { - t.Fatalf("Connection header isn't absent") - } - } -} - -func TestGetWithLastModified(t *testing.T) { - resetTest() - req, err := http.NewRequest("GET", s.server.URL+"/lastmodified", nil) - if err != nil { - t.Fatal(err) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != 
"" { - t.Fatal("XFromCache header isn't blank") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } -} - -func TestGetWithVary(t *testing.T) { - resetTest() - req, err := http.NewRequest("GET", s.server.URL+"/varyaccept", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Accept", "text/plain") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get("Vary") != "Accept" { - t.Fatalf(`Vary header isn't "Accept": %v`, resp.Header.Get("Vary")) - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } - req.Header.Set("Accept", "text/html") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - req.Header.Set("Accept", "") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } -} - -func TestGetWithDoubleVary(t *testing.T) { - resetTest() - req, err := http.NewRequest("GET", s.server.URL+"/doublevary", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Accept", "text/plain") - req.Header.Set("Accept-Language", "da, en-gb;q=0.8, en;q=0.7") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get("Vary") == "" { - t.Fatalf(`Vary header is blank`) - } - _, err = 
ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } - req.Header.Set("Accept-Language", "") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - req.Header.Set("Accept-Language", "da") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } -} - -func TestGetWith2VaryHeaders(t *testing.T) { - resetTest() - // Tests that multiple Vary headers' comma-separated lists are - // merged. See https://github.com/gregjones/httpcache/issues/27. - const ( - accept = "text/plain" - acceptLanguage = "da, en-gb;q=0.8, en;q=0.7" - ) - req, err := http.NewRequest("GET", s.server.URL+"/2varyheaders", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Accept", accept) - req.Header.Set("Accept-Language", acceptLanguage) - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get("Vary") == "" { - t.Fatalf(`Vary header is blank`) - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } - req.Header.Set("Accept-Language", "") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - req.Header.Set("Accept-Language", "da") - { - resp, 
err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - req.Header.Set("Accept-Language", acceptLanguage) - req.Header.Set("Accept", "") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - req.Header.Set("Accept", "image/png") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } -} - -func TestGetVaryUnused(t *testing.T) { - resetTest() - req, err := http.NewRequest("GET", s.server.URL+"/varyunused", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Accept", "text/plain") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get("Vary") == "" { - t.Fatalf(`Vary header is blank`) - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } -} - -func TestUpdateFields(t *testing.T) { - resetTest() - req, err := http.NewRequest("GET", s.server.URL+"/updatefields", nil) - if err != nil { - t.Fatal(err) - } - var counter, counter2 string - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - counter = 
resp.Header.Get("x-counter") - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - counter2 = resp.Header.Get("x-counter") - } - if counter == counter2 { - t.Fatalf(`both "x-counter" values are equal: %v %v`, counter, counter2) - } -} - -// This tests the fix for https://github.com/gregjones/httpcache/issues/74. -// Previously, after validating a cached response, its StatusCode -// was incorrectly being replaced. -func TestCachedErrorsKeepStatus(t *testing.T) { - resetTest() - req, err := http.NewRequest("GET", s.server.URL+"/cachederror", nil) - if err != nil { - t.Fatal(err) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - _, _ = io.Copy(ioutil.Discard, resp.Body) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusNotFound { - t.Fatalf("Status code isn't 404: %d", resp.StatusCode) - } - } -} - -func TestParseCacheControl(t *testing.T) { - resetTest() - h := http.Header{} - for range parseCacheControl(h) { - t.Fatal("cacheControl should be empty") - } - - h.Set("cache-control", "no-cache") - { - cc := parseCacheControl(h) - if _, ok := cc["foo"]; ok { - t.Error(`Value "foo" shouldn't exist`) - } - noCache, ok := cc["no-cache"] - if !ok { - t.Fatalf(`"no-cache" value isn't set`) - } - if noCache != "" { - t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) - } - } - h.Set("cache-control", "no-cache, max-age=3600") - { - cc := parseCacheControl(h) - noCache, ok := cc["no-cache"] - if !ok { - t.Fatalf(`"no-cache" value isn't set`) - } - if noCache != "" { - t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) - } - if cc["max-age"] != "3600" { - t.Fatalf(`"max-age" value isn't 
"3600": %v`, cc["max-age"]) - } - } -} - -func TestNoCacheRequestExpiration(t *testing.T) { - resetTest() - respHeaders := http.Header{} - respHeaders.Set("Cache-Control", "max-age=7200") - - reqHeaders := http.Header{} - reqHeaders.Set("Cache-Control", "no-cache") - if getFreshness(respHeaders, reqHeaders) != transparent { - t.Fatal("freshness isn't transparent") - } -} - -func TestNoCacheResponseExpiration(t *testing.T) { - resetTest() - respHeaders := http.Header{} - respHeaders.Set("Cache-Control", "no-cache") - respHeaders.Set("Expires", "Wed, 19 Apr 3000 11:43:00 GMT") - - reqHeaders := http.Header{} - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestReqMustRevalidate(t *testing.T) { - resetTest() - // not paying attention to request setting max-stale means never returning stale - // responses, so always acting as if must-revalidate is set - respHeaders := http.Header{} - - reqHeaders := http.Header{} - reqHeaders.Set("Cache-Control", "must-revalidate") - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestRespMustRevalidate(t *testing.T) { - resetTest() - respHeaders := http.Header{} - respHeaders.Set("Cache-Control", "must-revalidate") - - reqHeaders := http.Header{} - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestFreshExpiration(t *testing.T) { - resetTest() - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) - - reqHeaders := http.Header{} - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - clock = &fakeClock{elapsed: 3 * time.Second} - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestMaxAge(t *testing.T) { - resetTest() - now := time.Now() - 
respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("cache-control", "max-age=2") - - reqHeaders := http.Header{} - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - clock = &fakeClock{elapsed: 3 * time.Second} - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestMaxAgeZero(t *testing.T) { - resetTest() - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("cache-control", "max-age=0") - - reqHeaders := http.Header{} - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestBothMaxAge(t *testing.T) { - resetTest() - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("cache-control", "max-age=2") - - reqHeaders := http.Header{} - reqHeaders.Set("cache-control", "max-age=0") - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestMinFreshWithExpires(t *testing.T) { - resetTest() - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) - - reqHeaders := http.Header{} - reqHeaders.Set("cache-control", "min-fresh=1") - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - reqHeaders = http.Header{} - reqHeaders.Set("cache-control", "min-fresh=2") - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestEmptyMaxStale(t *testing.T) { - resetTest() - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("cache-control", "max-age=20") - - reqHeaders := http.Header{} - reqHeaders.Set("cache-control", 
"max-stale") - clock = &fakeClock{elapsed: 10 * time.Second} - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - clock = &fakeClock{elapsed: 60 * time.Second} - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } -} - -func TestMaxStaleValue(t *testing.T) { - resetTest() - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("cache-control", "max-age=10") - - reqHeaders := http.Header{} - reqHeaders.Set("cache-control", "max-stale=20") - clock = &fakeClock{elapsed: 5 * time.Second} - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - clock = &fakeClock{elapsed: 15 * time.Second} - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - clock = &fakeClock{elapsed: 30 * time.Second} - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func containsHeader(headers []string, header string) bool { - for _, v := range headers { - if http.CanonicalHeaderKey(v) == http.CanonicalHeaderKey(header) { - return true - } - } - return false -} - -func TestGetEndToEndHeaders(t *testing.T) { - resetTest() - var ( - headers http.Header - end2end []string - ) - - headers = http.Header{} - headers.Set("content-type", "text/html") - headers.Set("te", "deflate") - - end2end = getEndToEndHeaders(headers) - if !containsHeader(end2end, "content-type") { - t.Fatal(`doesn't contain "content-type" header`) - } - if containsHeader(end2end, "te") { - t.Fatal(`doesn't contain "te" header`) - } - - headers = http.Header{} - headers.Set("connection", "content-type") - headers.Set("content-type", "text/csv") - headers.Set("te", "deflate") - end2end = getEndToEndHeaders(headers) - if containsHeader(end2end, "connection") { - t.Fatal(`doesn't contain "connection" header`) - } - if containsHeader(end2end, "content-type") { - 
t.Fatal(`doesn't contain "content-type" header`) - } - if containsHeader(end2end, "te") { - t.Fatal(`doesn't contain "te" header`) - } - - headers = http.Header{} - end2end = getEndToEndHeaders(headers) - if len(end2end) != 0 { - t.Fatal(`non-zero end2end headers`) - } - - headers = http.Header{} - headers.Set("connection", "content-type") - end2end = getEndToEndHeaders(headers) - if len(end2end) != 0 { - t.Fatal(`non-zero end2end headers`) - } -} - -type transportMock struct { - response *http.Response - err error -} - -func (t transportMock) RoundTrip(req *http.Request) (resp *http.Response, err error) { - return t.response, t.err -} - -func TestStaleIfErrorRequest(t *testing.T) { - resetTest() - now := time.Now() - tmock := transportMock{ - response: &http.Response{ - Status: http.StatusText(http.StatusOK), - StatusCode: http.StatusOK, - Header: http.Header{ - "Date": []string{now.Format(time.RFC1123)}, - "Cache-Control": []string{"no-cache"}, - }, - Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), - }, - err: nil, - } - tp := NewMemoryCacheTransport() - tp.Transport = &tmock - - // First time, response is cached on success - r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) - r.Header.Set("Cache-Control", "stale-if-error") - resp, err := tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - // On failure, response is returned from the cache - tmock.response = nil - tmock.err = errors.New("some error") - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } -} - -func TestStaleIfErrorRequestLifetime(t *testing.T) { - resetTest() - now := time.Now() - tmock := transportMock{ - response: &http.Response{ - Status: http.StatusText(http.StatusOK), - StatusCode: http.StatusOK, - Header: http.Header{ - "Date": []string{now.Format(time.RFC1123)}, - "Cache-Control": 
[]string{"no-cache"}, - }, - Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), - }, - err: nil, - } - tp := NewMemoryCacheTransport() - tp.Transport = &tmock - - // First time, response is cached on success - r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) - r.Header.Set("Cache-Control", "stale-if-error=100") - resp, err := tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - // On failure, response is returned from the cache - tmock.response = nil - tmock.err = errors.New("some error") - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - - // Same for http errors - tmock.response = &http.Response{StatusCode: http.StatusInternalServerError} - tmock.err = nil - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - - // If failure last more than max stale, error is returned - clock = &fakeClock{elapsed: 200 * time.Second} - _, err = tp.RoundTrip(r) - if err != tmock.err { - t.Fatalf("got err %v, want %v", err, tmock.err) - } -} - -func TestStaleIfErrorResponse(t *testing.T) { - resetTest() - now := time.Now() - tmock := transportMock{ - response: &http.Response{ - Status: http.StatusText(http.StatusOK), - StatusCode: http.StatusOK, - Header: http.Header{ - "Date": []string{now.Format(time.RFC1123)}, - "Cache-Control": []string{"no-cache, stale-if-error"}, - }, - Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), - }, - err: nil, - } - tp := NewMemoryCacheTransport() - tp.Transport = &tmock - - // First time, response is cached on success - r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) - resp, err := tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } 
- - // On failure, response is returned from the cache - tmock.response = nil - tmock.err = errors.New("some error") - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } -} - -func TestStaleIfErrorResponseLifetime(t *testing.T) { - resetTest() - now := time.Now() - tmock := transportMock{ - response: &http.Response{ - Status: http.StatusText(http.StatusOK), - StatusCode: http.StatusOK, - Header: http.Header{ - "Date": []string{now.Format(time.RFC1123)}, - "Cache-Control": []string{"no-cache, stale-if-error=100"}, - }, - Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), - }, - err: nil, - } - tp := NewMemoryCacheTransport() - tp.Transport = &tmock - - // First time, response is cached on success - r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) - resp, err := tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - // On failure, response is returned from the cache - tmock.response = nil - tmock.err = errors.New("some error") - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - - // If failure last more than max stale, error is returned - clock = &fakeClock{elapsed: 200 * time.Second} - _, err = tp.RoundTrip(r) - if err != tmock.err { - t.Fatalf("got err %v, want %v", err, tmock.err) - } -} - -// This tests the fix for https://github.com/gregjones/httpcache/issues/74. -// Previously, after a stale response was used after encountering an error, -// its StatusCode was being incorrectly replaced. 
-func TestStaleIfErrorKeepsStatus(t *testing.T) { - resetTest() - now := time.Now() - tmock := transportMock{ - response: &http.Response{ - Status: http.StatusText(http.StatusNotFound), - StatusCode: http.StatusNotFound, - Header: http.Header{ - "Date": []string{now.Format(time.RFC1123)}, - "Cache-Control": []string{"no-cache"}, - }, - Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), - }, - err: nil, - } - tp := NewMemoryCacheTransport() - tp.Transport = &tmock - - // First time, response is cached on success - r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) - r.Header.Set("Cache-Control", "stale-if-error") - resp, err := tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - // On failure, response is returned from the cache - tmock.response = nil - tmock.err = errors.New("some error") - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - if resp.StatusCode != http.StatusNotFound { - t.Fatalf("Status wasn't 404: %d", resp.StatusCode) - } -} - -// Test that http.Client.Timeout is respected when cache transport is used. -// That is so as long as request cancellation is propagated correctly. -// In the past, that required CancelRequest to be implemented correctly, -// but modern http.Client uses Request.Cancel (or request context) instead, -// so we don't have to do anything. -func TestClientTimeout(t *testing.T) { - if testing.Short() { - t.Skip("skipping timeout test in short mode") // Because it takes at least 3 seconds to run. 
- } - resetTest() - client := &http.Client{ - Transport: NewMemoryCacheTransport(), - Timeout: time.Second, - } - started := time.Now() - resp, err := client.Get(s.server.URL + "/3seconds") - taken := time.Since(started) - if err == nil { - t.Error("got nil error, want timeout error") - } - if resp != nil { - t.Error("got non-nil resp, want nil resp") - } - if taken >= 2*time.Second { - t.Error("client.Do took 2+ seconds, want < 2 seconds") - } -} diff --git a/libsq/core/ioz/httpcacheog/test/test.go b/libsq/core/ioz/httpcacheog/test/test.go deleted file mode 100644 index 8c6ff1350..000000000 --- a/libsq/core/ioz/httpcacheog/test/test.go +++ /dev/null @@ -1,36 +0,0 @@ -package test - -import ( - "bytes" - "context" - "testing" - - "github.com/neilotoole/sq/libsq/core/ioz/httpcacheog" -) - -// Cache excercises a httpcache.Cache implementation. -func Cache(t *testing.T, cache httpcache.Cache) { - key := "testKey" - _, ok := cache.Get(context.Background(), key) - if ok { - t.Fatal("retrieved key before adding it") - } - - val := []byte("some bytes") - cache.Set(context.Background(), key, val) - - retVal, ok := cache.Get(context.Background(), key) - if !ok { - t.Fatal("could not retrieve an element we just added") - } - if !bytes.Equal(retVal, val) { - t.Fatal("retrieved a different value than what we put in") - } - - cache.Delete(context.Background(), key) - - _, ok = cache.Get(context.Background(), key) - if ok { - t.Fatal("deleted key still present") - } -} diff --git a/libsq/core/ioz/httpcacheog/test/test_test.go b/libsq/core/ioz/httpcacheog/test/test_test.go deleted file mode 100644 index cc49e572e..000000000 --- a/libsq/core/ioz/httpcacheog/test/test_test.go +++ /dev/null @@ -1,12 +0,0 @@ -package test_test - -import ( - "testing" - - "github.com/neilotoole/sq/libsq/core/ioz/httpcacheog" - "github.com/neilotoole/sq/libsq/core/ioz/httpcacheog/test" -) - -func TestMemoryCache(t *testing.T) { - test.Cache(t, httpcache.NewMemoryCache()) -} From 
25ad2e4c00d446e0c625b5e8151b98b38b39c53a Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 12 Dec 2023 05:29:56 -0700 Subject: [PATCH 088/195] dl: configured for refactor --- go.mod | 2 -- go.sum | 4 ---- 2 files changed, 6 deletions(-) diff --git a/go.mod b/go.mod index 789748bec..0f011723e 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,6 @@ require ( github.com/a8m/tree v0.0.0-20230208161321-36ae24ddad15 github.com/alessio/shellescape v1.4.2 github.com/antlr4-go/antlr/v4 v4.13.0 - github.com/bitcomplete/httpcache v0.0.0-20220528171057-1f4a71bbffc5 github.com/c2h5oh/datasize v0.0.0-20220606134207-859f65c6625b github.com/dustin/go-humanize v1.0.1 github.com/ecnepsnai/osquery v1.0.1 @@ -68,7 +67,6 @@ require ( github.com/Masterminds/semver/v3 v3.2.0 // indirect github.com/VividCortex/ewma v1.2.0 // indirect github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d // indirect - github.com/bradfitz/gomemcache v0.0.0-20230905024940-24af94b03874 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/djherbis/atime v1.1.0 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect diff --git a/go.sum b/go.sum index ec24d70b0..11ea035fe 100644 --- a/go.sum +++ b/go.sum @@ -28,10 +28,6 @@ github.com/alessio/shellescape v1.4.2 h1:MHPfaU+ddJ0/bYWpgIeUnQUqKrlJ1S7BfEYPM4u github.com/alessio/shellescape v1.4.2/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30= github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= -github.com/bitcomplete/httpcache v0.0.0-20220528171057-1f4a71bbffc5 h1:W0eIasTpyKf6lU9jf5WhOpA0GcmnusfoL68U4VjSiwE= -github.com/bitcomplete/httpcache v0.0.0-20220528171057-1f4a71bbffc5/go.mod h1:bV6DTY4iwX8E6H3G//Ug6G3GmbLoFteBgBgmM9HYZDw= -github.com/bradfitz/gomemcache v0.0.0-20230905024940-24af94b03874 
h1:N7oVaKyGp8bttX0bfZGmcGkjz7DLQXhAn3DNd3T0ous= -github.com/bradfitz/gomemcache v0.0.0-20230905024940-24af94b03874/go.mod h1:r5xuitiExdLAJ09PR7vBVENGvp4ZuTBeWTGtxuX3K+c= github.com/c2h5oh/datasize v0.0.0-20220606134207-859f65c6625b h1:6+ZFm0flnudZzdSE0JxlhR2hKnGPcNB35BjQf4RYQDY= github.com/c2h5oh/datasize v0.0.0-20220606134207-859f65c6625b/go.mod h1:S/7n9copUssQ56c7aAgHqftWO4LTf4xY6CGWt8Bc+3M= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= From 0f69bf0c20497da466522ef23c8eda302cff5b5e Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 12 Dec 2023 07:09:17 -0700 Subject: [PATCH 089/195] wip: RespCache --- libsq/core/ioz/httpcache/httpcache.go | 124 ++++++++++++++++++++++---- libsq/core/ioz/ioz.go | 24 +++++ libsq/core/ioz/ioz_test.go | 18 ++++ 3 files changed, 149 insertions(+), 17 deletions(-) diff --git a/libsq/core/ioz/httpcache/httpcache.go b/libsq/core/ioz/httpcache/httpcache.go index 51d283740..a3f38a681 100644 --- a/libsq/core/ioz/httpcache/httpcache.go +++ b/libsq/core/ioz/httpcache/httpcache.go @@ -3,7 +3,6 @@ // // It is only suitable for use as a 'private' cache (i.e. for a web-browser or an API-client // and not for a shared proxy). -// package httpcache import ( @@ -11,10 +10,16 @@ import ( "bytes" "context" "errors" + "github.com/neilotoole/sq/libsq/core/cleanup" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" "io" "io/ioutil" "net/http" "net/http/httputil" + "os" + "path/filepath" "strings" "sync" "time" @@ -50,9 +55,9 @@ var DefaultKeyFunc = func(req *http.Request) string { } } -// CachedResponse returns the cached http.Response for req if present, and nil +// CachedResponseOld returns the cached http.Response for req if present, and nil // otherwise. 
-func CachedResponse(ctx context.Context, c Cache, key string, req *http.Request) (resp *http.Response, err error) { +func CachedResponseOld(ctx context.Context, c Cache, key string, req *http.Request) (resp *http.Response, err error) { cachedVal, ok := c.Get(ctx, key) if !ok { return @@ -62,6 +67,57 @@ func CachedResponse(ctx context.Context, c Cache, key string, req *http.Request) return http.ReadResponse(bufio.NewReader(b), req) } +func NewRespCache(dir string) *RespCache { + c := &RespCache{ + Header: filepath.Join(dir, "header"), + Body: filepath.Join(dir, "body"), + clnup: cleanup.New(), + } + //c.clnup.AddE(func() error { + // return os.RemoveAll(dir) + //}) + return c +} + +type RespCache struct { + Header string + Body string + clnup *cleanup.Cleanup +} + +func (rc *RespCache) Cached(ctx context.Context, req *http.Request) (*http.Response, error) { + if !ioz.FileAccessible(rc.Header) { + return nil, nil + } + + b, err := os.ReadFile(rc.Header) + if err != nil { + return nil, err + } + + f, err := os.Open(rc.Body) + if err != nil { + lg.FromContext(ctx).Error("failed to open cached response body", + lga.File, rc.Body, lga.Err, err) + return nil, err + } + rc.clnup.AddC(f) + mr := io.MultiReader(bytes.NewReader(b), f) + return http.ReadResponse(bufio.NewReader(mr), req) +} + +func (t *Transport) CachedResponse(ctx context.Context, key string, req *http.Request) (resp *http.Response, err error) { + cachedVal, ok := t.Cache.Get(ctx, key) + if !ok { + return + } + + io.MultiReader() + + b := bytes.NewBuffer(cachedVal) + return http.ReadResponse(bufio.NewReader(b), req) +} + // MemoryCache is an implemtation of Cache that stores responses in an in-memory map. 
type MemoryCache struct { mu sync.RWMutex @@ -119,8 +175,9 @@ func KeyFuncOpt(keyFunc KeyFunc) TransportOpt { type Transport struct { // The RoundTripper interface actually used to make requests // If nil, http.DefaultTransport is used - Transport http.RoundTripper - Cache Cache + Transport http.RoundTripper + Cache Cache + BodyFilepath string // If true, responses returned from the cache will be given an extra header, X-From-Cache MarkCachedResponses bool // A function to generate a cache key for the given request @@ -171,10 +228,10 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" var cachedResp *http.Response if cacheable { - cachedResp, err = CachedResponse(req.Context(), t.Cache, cacheKey, req) + cachedResp, err = CachedResponseOld(req.Context(), t.Cache, cacheKey, req) } else { // Need to invalidate an existing value - t.Cache.Delete(req.Context(), cacheKey) + t.cacheDelete(req.Context(), cacheKey) } transport := t.Transport @@ -230,7 +287,7 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error return cachedResp, nil } else { if err != nil || resp.StatusCode != http.StatusOK { - t.Cache.Delete(req.Context(), cacheKey) + t.cacheDelete(req.Context(), cacheKey) } if err != nil { return nil, err @@ -265,24 +322,57 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error OnEOF: func(r io.Reader) { resp := *resp resp.Body = ioutil.NopCloser(r) - respBytes, err := httputil.DumpResponse(&resp, true) - if err == nil { - t.Cache.Set(req.Context(), cacheKey, respBytes) - } + _ = t.writeRespToCache(req.Context(), cacheKey, &resp) + //respBytes, err := httputil.DumpResponse(&resp, true) + //if err == nil { + // if _, err = ioz.WriteToFile(req.Context(), t.BodyFilepath, resp.Body); err != nil { + // lg.FromContext(req.Context()).Error("failed to write download cache body to file", + // 
lga.Err, err, lga.File, t.BodyFilepath) + // } else { + // t.Cache.Set(req.Context(), cacheKey, respBytes) + // } + //} }, } default: - respBytes, err := httputil.DumpResponse(resp, true) - if err == nil { - t.Cache.Set(req.Context(), cacheKey, respBytes) - } + _ = t.writeRespToCache(req.Context(), cacheKey, resp) + //respBytes, err := httputil.DumpResponse(resp, true) + //if err == nil { + // if _, err = ioz.WriteToFile(req.Context(), t.BodyFilepath, resp.Body); err != nil { + // lg.FromContext(req.Context()).Error("failed to write download cache body to file", + // lga.Err, err, lga.File, t.BodyFilepath) + // } else { + // t.Cache.Set(req.Context(), cacheKey, respBytes) + // } + //} } } else { - t.Cache.Delete(req.Context(), cacheKey) + t.cacheDelete(req.Context(), cacheKey) } return resp, nil } +func (t *Transport) cacheDelete(ctx context.Context, cacheKey string) { + if err := os.RemoveAll(t.BodyFilepath); err != nil { + lg.FromContext(ctx).Warn("failed to remove download cache body file", + lga.Err, err, lga.File, t.BodyFilepath) + } + t.Cache.Delete(ctx, cacheKey) +} + +func (t *Transport) writeRespToCache(ctx context.Context, cacheKey string, resp *http.Response) error { + respBytes, err := httputil.DumpResponse(resp, true) + if err == nil { + if _, err = ioz.WriteToFile(ctx, t.BodyFilepath, resp.Body); err != nil { + lg.FromContext(ctx).Error("failed to write download cache body to file", + lga.Err, err, lga.File, t.BodyFilepath) + } else { + t.Cache.Set(ctx, cacheKey, respBytes) + } + } + return err +} + // ErrNoDateHeader indicates that the HTTP headers contained no Date header. 
var ErrNoDateHeader = errors.New("no Date header") diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 61aafa5a1..118aa399a 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -6,6 +6,7 @@ import ( "context" crand "crypto/rand" "crypto/tls" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" "io" mrand "math/rand" "net/http" @@ -418,3 +419,26 @@ func NewHTTPClient(insecureSkipVerify bool) *http.Client { return &client } + +// WriteToFile writes the contents of r to fp. If fp doesn't exist, +// the file is created (including any parent dirs). If fp exists, it is +// truncated. The write operation is context-aware. +func WriteToFile(ctx context.Context, fp string, r io.Reader) (written int64, err error) { + if err = RequireDir(filepath.Dir(fp)); err != nil { + return 0, err + } + + f, err := os.Create(fp) + if err != nil { + return 0, err + } + + cr := contextio.NewReader(ctx, r) + written, err = io.Copy(f, cr) + closeErr := f.Close() + if err == nil { + return written, closeErr + } + + return written, err +} diff --git a/libsq/core/ioz/ioz_test.go b/libsq/core/ioz/ioz_test.go index 4f5ecc63e..5013c1e9f 100644 --- a/libsq/core/ioz/ioz_test.go +++ b/libsq/core/ioz/ioz_test.go @@ -2,8 +2,11 @@ package ioz_test import ( "bytes" + "context" "io" "os" + "path/filepath" + "strings" "sync" "testing" "time" @@ -81,3 +84,18 @@ func TestDelayReader(t *testing.T) { wg.Wait() } + +func TestWriteToFile(t *testing.T) { + const val = `In Zanadu did Kubla Khan a stately pleasure dome decree` + ctx := context.Background() + dir := t.TempDir() + + fp := filepath.Join(dir, "not_existing_intervening_dir", "test.txt") + written, err := ioz.WriteToFile(ctx, fp, strings.NewReader(val)) + require.NoError(t, err) + require.Equal(t, int64(len(val)), written) + + got, err := os.ReadFile(fp) + require.NoError(t, err) + require.Equal(t, val, string(got)) +} From d303338e81c83a20e8e326db451d6c5b1b77e2f2 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 12 Dec 
2023 08:09:47 -0700 Subject: [PATCH 090/195] wip: adding RespCache --- libsq/core/ioz/httpcache/httpcache.go | 106 +++++++++++++++++++++++++- 1 file changed, 104 insertions(+), 2 deletions(-) diff --git a/libsq/core/ioz/httpcache/httpcache.go b/libsq/core/ioz/httpcache/httpcache.go index a3f38a681..e3ece6773 100644 --- a/libsq/core/ioz/httpcache/httpcache.go +++ b/libsq/core/ioz/httpcache/httpcache.go @@ -11,6 +11,7 @@ import ( "context" "errors" "github.com/neilotoole/sq/libsq/core/cleanup" + "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" @@ -80,12 +81,16 @@ func NewRespCache(dir string) *RespCache { } type RespCache struct { + mu sync.Mutex Header string Body string clnup *cleanup.Cleanup } func (rc *RespCache) Cached(ctx context.Context, req *http.Request) (*http.Response, error) { + rc.mu.Lock() + defer rc.mu.Unlock() + if !ioz.FileAccessible(rc.Header) { return nil, nil } @@ -106,14 +111,111 @@ func (rc *RespCache) Cached(ctx context.Context, req *http.Request) (*http.Respo return http.ReadResponse(bufio.NewReader(mr), req) } +func (rc *RespCache) Close() error { + rc.mu.Lock() + defer rc.mu.Unlock() + err := rc.clnup.Run() + rc.clnup = cleanup.New() + return err +} + +func (rc *RespCache) Clear() error { + rc.mu.Lock() + defer rc.mu.Unlock() + return rc.doClear() +} + +func (rc *RespCache) doClear() error { + err1 := rc.clnup.Run() + rc.clnup = cleanup.New() + err2 := os.RemoveAll(rc.Header) + err3 := os.RemoveAll(rc.Body) + return errz.Combine(err1, err2, err3) +} + +//// drainBody reads all of b to memory and then returns two equivalent +//// ReadClosers yielding the same bytes. +//// +//// It returns an error if the initial slurp of all bytes fails. It does not attempt +//// to make the returned ReadClosers have identical error-matching behavior. 
+//func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { +// if b == nil || b == http.NoBody { +// // No copying needed. Preserve the magic sentinel meaning of NoBody. +// return http.NoBody, http.NoBody, nil +// } +// var buf bytes.Buffer +// if _, err = buf.ReadFrom(b); err != nil { +// return nil, b, err +// } +// if err = b.Close(); err != nil { +// return nil, b, err +// } +// return io.NopCloser(&buf), io.NopCloser(bytes.NewReader(buf.Bytes())), nil +//} + +const msgClearCache = "Clear HTTP response cache" + +func (rc *RespCache) Write(ctx context.Context, resp *http.Response) error { + rc.mu.Lock() + defer rc.mu.Unlock() + + err := rc.doWrite(ctx, resp) + if err != nil { + //lg.WarnIfFuncError(lg.FromContext(ctx), msgClearCache, rc.doClear) + } + return err +} + +func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response) error { + log := lg.FromContext(ctx) + + if err := ioz.RequireDir(filepath.Dir(rc.Header)); err != nil { + return err + } + + if err := ioz.RequireDir(filepath.Dir(rc.Body)); err != nil { + return err + } + + respBytes, err := httputil.DumpResponse(resp, false) + if err != nil { + return err + } + + if _, err = ioz.WriteToFile(ctx, rc.Header, bytes.NewReader(respBytes)); err != nil { + return err + } + + f, err := os.OpenFile(rc.Body, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) + if err != nil { + return err + } + + _, err = io.Copy(f, resp.Body) + if err != nil { + lg.WarnIfCloseError(log, "Close cache body file", f) + return err + } + + if err = f.Close(); err != nil { + return err + } + + f, err = os.Open(rc.Body) + if err != nil { + return err + } + + resp.Body = f + return nil +} + func (t *Transport) CachedResponse(ctx context.Context, key string, req *http.Request) (resp *http.Response, err error) { cachedVal, ok := t.Cache.Get(ctx, key) if !ok { return } - io.MultiReader() - b := bytes.NewBuffer(cachedVal) return http.ReadResponse(bufio.NewReader(b), req) } From 8a0699c3dc12bdf3d011f31fc7cff3d4e85b572d 
Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 12 Dec 2023 08:40:04 -0700 Subject: [PATCH 091/195] wip: RespCache seems to be working --- libsq/core/ioz/httpcache/httpcache.go | 90 +++++++++++++++------------ 1 file changed, 50 insertions(+), 40 deletions(-) diff --git a/libsq/core/ioz/httpcache/httpcache.go b/libsq/core/ioz/httpcache/httpcache.go index e3ece6773..990b884bb 100644 --- a/libsq/core/ioz/httpcache/httpcache.go +++ b/libsq/core/ioz/httpcache/httpcache.go @@ -13,6 +13,7 @@ import ( "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "io" @@ -87,6 +88,8 @@ type RespCache struct { clnup *cleanup.Cleanup } +// Cached returns the cached http.Response for req if present, and nil +// otherwise. func (rc *RespCache) Cached(ctx context.Context, req *http.Request) (*http.Response, error) { rc.mu.Lock() defer rc.mu.Unlock() @@ -95,37 +98,46 @@ func (rc *RespCache) Cached(ctx context.Context, req *http.Request) (*http.Respo return nil, nil } - b, err := os.ReadFile(rc.Header) + headerBytes, err := os.ReadFile(rc.Header) if err != nil { return nil, err } - f, err := os.Open(rc.Body) + bodyFile, err := os.Open(rc.Body) if err != nil { lg.FromContext(ctx).Error("failed to open cached response body", lga.File, rc.Body, lga.Err, err) return nil, err } - rc.clnup.AddC(f) - mr := io.MultiReader(bytes.NewReader(b), f) - return http.ReadResponse(bufio.NewReader(mr), req) + + // We need to explicitly close bodyFile at some later point. It won't be + // closed via a call to http.Response.Body.Close(). + rc.clnup.AddC(bodyFile) + + concatRdr := io.MultiReader(bytes.NewReader(headerBytes), bodyFile) + return http.ReadResponse(bufio.NewReader(concatRdr), req) } +// Close closes the cache, freeing any resources it holds. 
Note that +// it does not delete the cache: for that, see RespCache.Delete. func (rc *RespCache) Close() error { rc.mu.Lock() defer rc.mu.Unlock() + err := rc.clnup.Run() rc.clnup = cleanup.New() return err } -func (rc *RespCache) Clear() error { +// Delete deletes the cache. +func (rc *RespCache) Delete() error { rc.mu.Lock() defer rc.mu.Unlock() - return rc.doClear() + + return rc.doDelete() } -func (rc *RespCache) doClear() error { +func (rc *RespCache) doDelete() error { err1 := rc.clnup.Run() rc.clnup = cleanup.New() err2 := os.RemoveAll(rc.Header) @@ -133,35 +145,16 @@ func (rc *RespCache) doClear() error { return errz.Combine(err1, err2, err3) } -//// drainBody reads all of b to memory and then returns two equivalent -//// ReadClosers yielding the same bytes. -//// -//// It returns an error if the initial slurp of all bytes fails. It does not attempt -//// to make the returned ReadClosers have identical error-matching behavior. -//func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { -// if b == nil || b == http.NoBody { -// // No copying needed. Preserve the magic sentinel meaning of NoBody. -// return http.NoBody, http.NoBody, nil -// } -// var buf bytes.Buffer -// if _, err = buf.ReadFrom(b); err != nil { -// return nil, b, err -// } -// if err = b.Close(); err != nil { -// return nil, b, err -// } -// return io.NopCloser(&buf), io.NopCloser(bytes.NewReader(buf.Bytes())), nil -//} - -const msgClearCache = "Clear HTTP response cache" +const msgDeleteCache = "Delete HTTP response cache" +// Write writes resp to the cache. 
func (rc *RespCache) Write(ctx context.Context, resp *http.Response) error { rc.mu.Lock() defer rc.mu.Unlock() err := rc.doWrite(ctx, resp) if err != nil { - //lg.WarnIfFuncError(lg.FromContext(ctx), msgClearCache, rc.doClear) + lg.WarnIfFuncError(lg.FromContext(ctx), msgDeleteCache, rc.doDelete) } return err } @@ -191,7 +184,8 @@ func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response) error { return err } - _, err = io.Copy(f, resp.Body) + cr := contextio.NewReader(ctx, resp.Body) + _, err = io.Copy(f, cr) if err != nil { lg.WarnIfCloseError(log, "Close cache body file", f) return err @@ -277,8 +271,11 @@ func KeyFuncOpt(keyFunc KeyFunc) TransportOpt { type Transport struct { // The RoundTripper interface actually used to make requests // If nil, http.DefaultTransport is used - Transport http.RoundTripper - Cache Cache + Transport http.RoundTripper + Cache Cache + RespCache *RespCache + + // Deprecated: Use RespCache instead. BodyFilepath string // If true, responses returned from the cache will be given an extra header, X-From-Cache MarkCachedResponses bool @@ -326,14 +323,18 @@ func varyMatches(cachedResp *http.Response, req *http.Request) bool { // to give the server a chance to respond with NotModified. If this happens, then the cached Response // will be returned. 
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { - cacheKey := t.KeyFunc(req) + log := lg.FromContext(req.Context()) + //cacheKey := t.KeyFunc(req) cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" var cachedResp *http.Response if cacheable { - cachedResp, err = CachedResponseOld(req.Context(), t.Cache, cacheKey, req) + cachedResp, err = t.RespCache.Cached(req.Context(), req) + //cachedResp, err = t.CachedResponse(req.Context(), cacheKey, req) } else { // Need to invalidate an existing value - t.cacheDelete(req.Context(), cacheKey) + //err = t.RespCache.Delete() + lg.WarnIfFuncError(log, "Delete cached response", t.RespCache.Delete) + //t.cacheDelete(req.Context(), cacheKey) } transport := t.Transport @@ -389,7 +390,9 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error return cachedResp, nil } else { if err != nil || resp.StatusCode != http.StatusOK { - t.cacheDelete(req.Context(), cacheKey) + //t.cacheDelete(req.Context(), cacheKey) + t.RespCache.Delete() + //t.cacheDelete(req.Context(), cacheKey) } if err != nil { return nil, err @@ -424,7 +427,10 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error OnEOF: func(r io.Reader) { resp := *resp resp.Body = ioutil.NopCloser(r) - _ = t.writeRespToCache(req.Context(), cacheKey, &resp) + if err := t.RespCache.Write(req.Context(), &resp); err != nil { + log.Error("failed to write download cache", lga.Err, err) + } + //_ = t.writeRespToCache(req.Context(), cacheKey, &resp) //respBytes, err := httputil.DumpResponse(&resp, true) //if err == nil { // if _, err = ioz.WriteToFile(req.Context(), t.BodyFilepath, resp.Body); err != nil { @@ -437,7 +443,10 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error }, } default: - _ = t.writeRespToCache(req.Context(), cacheKey, resp) + if err := t.RespCache.Write(req.Context(), resp); err != nil { + log.Error("failed to 
write download cache", lga.Err, err) + } + //_ = t.writeRespToCache(req.Context(), cacheKey, resp) //respBytes, err := httputil.DumpResponse(resp, true) //if err == nil { // if _, err = ioz.WriteToFile(req.Context(), t.BodyFilepath, resp.Body); err != nil { @@ -449,7 +458,8 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error //} } } else { - t.cacheDelete(req.Context(), cacheKey) + lg.WarnIfFuncError(log, "Delete resp cache", t.RespCache.Delete) + //t.cacheDelete(req.Context(), cacheKey) } return resp, nil } From caf87443cabd616e5c9190de81c8cbdbaa1e8463 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 12 Dec 2023 09:24:33 -0700 Subject: [PATCH 092/195] wip: almost there --- libsq/core/ioz/httpcache/httpcache.go | 185 +++++++++------------ libsq/core/ioz/httpcache/httpcache_test.go | 94 ++++++----- 2 files changed, 127 insertions(+), 152 deletions(-) diff --git a/libsq/core/ioz/httpcache/httpcache.go b/libsq/core/ioz/httpcache/httpcache.go index 990b884bb..b49a48029 100644 --- a/libsq/core/ioz/httpcache/httpcache.go +++ b/libsq/core/ioz/httpcache/httpcache.go @@ -48,49 +48,34 @@ type Cache interface { type KeyFunc func(req *http.Request) string -// DefaultKeyFunc returns the cache key for req -var DefaultKeyFunc = func(req *http.Request) string { - if req.Method == http.MethodGet { - return req.URL.String() - } else { - return req.Method + " " + req.URL.String() - } -} - -// CachedResponseOld returns the cached http.Response for req if present, and nil -// otherwise. -func CachedResponseOld(ctx context.Context, c Cache, key string, req *http.Request) (resp *http.Response, err error) { - cachedVal, ok := c.Get(ctx, key) - if !ok { - return - } - - b := bytes.NewBuffer(cachedVal) - return http.ReadResponse(bufio.NewReader(b), req) -} - -func NewRespCache(dir string) *RespCache { +// NewRespCache returns a new instance that stores responses in cacheDir. +// The caller should call RespCache.Close when finished with the cache. 
+func NewRespCache(cacheDir string) *RespCache { c := &RespCache{ - Header: filepath.Join(dir, "header"), - Body: filepath.Join(dir, "body"), + Header: filepath.Join(cacheDir, "header"), + Body: filepath.Join(cacheDir, "body"), clnup: cleanup.New(), } - //c.clnup.AddE(func() error { - // return os.RemoveAll(dir) - //}) return c } +// RespCache is a cache for a single http.Response. The response is +// stored in two files, one for the header and one for the body. +// The caller should call RespCache.Close when finished with the cache. type RespCache struct { - mu sync.Mutex + mu sync.Mutex + clnup *cleanup.Cleanup + + // Header is the path to the file containing the http.Response header. Header string - Body string - clnup *cleanup.Cleanup + + // Body is the path to the file containing the http.Response body. + Body string } -// Cached returns the cached http.Response for req if present, and nil +// Get returns the cached http.Response for req if present, and nil // otherwise. -func (rc *RespCache) Cached(ctx context.Context, req *http.Request) (*http.Response, error) { +func (rc *RespCache) Get(ctx context.Context, req *http.Request) (*http.Response, error) { rc.mu.Lock() defer rc.mu.Unlock() @@ -129,8 +114,11 @@ func (rc *RespCache) Close() error { return err } -// Delete deletes the cache. +// Delete deletes the cache entries from disk. 
func (rc *RespCache) Delete() error { + if rc == nil { + return nil + } rc.mu.Lock() defer rc.mu.Unlock() @@ -204,15 +192,15 @@ func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response) error { return nil } -func (t *Transport) CachedResponse(ctx context.Context, key string, req *http.Request) (resp *http.Response, err error) { - cachedVal, ok := t.Cache.Get(ctx, key) - if !ok { - return - } - - b := bytes.NewBuffer(cachedVal) - return http.ReadResponse(bufio.NewReader(b), req) -} +//func (t *Transport) CachedResponse(ctx context.Context, key string, req *http.Request) (resp *http.Response, err error) { +// cachedVal, ok := t.Cache.Get(ctx, key) +// if !ok { +// return +// } +// +// b := bytes.NewBuffer(cachedVal) +// return http.ReadResponse(bufio.NewReader(b), req) +//} // MemoryCache is an implemtation of Cache that stores responses in an in-memory map. type MemoryCache struct { @@ -258,12 +246,12 @@ func MarkCachedResponsesOpt(markCachedResponses bool) TransportOpt { } } -// KeyFuncOpt configures a transport by setting its KeyFunc to the one given -func KeyFuncOpt(keyFunc KeyFunc) TransportOpt { - return func(t *Transport) { - t.KeyFunc = keyFunc - } -} +//// KeyFuncOpt configures a transport by setting its KeyFunc to the one given +//func KeyFuncOpt(keyFunc KeyFunc) TransportOpt { +// return func(t *Transport) { +// t.KeyFunc = keyFunc +// } +//} // Transport is an implementation of http.RoundTripper that will return values from a cache // where possible (avoiding a network request) and will additionally add validators (etag/if-modified-since) @@ -272,23 +260,27 @@ type Transport struct { // The RoundTripper interface actually used to make requests // If nil, http.DefaultTransport is used Transport http.RoundTripper - Cache Cache + //Cache Cache RespCache *RespCache // Deprecated: Use RespCache instead. 
- BodyFilepath string - // If true, responses returned from the cache will be given an extra header, X-From-Cache + //BodyFilepath string + + // MarkCachedResponses, if true, indicates that responses returned from the + // cache will be given an extra header, X-From-Cache MarkCachedResponses bool + // A function to generate a cache key for the given request - KeyFunc KeyFunc + //KeyFunc KeyFunc } // NewTransport returns a new Transport with the provided Cache and options. If // KeyFunc is not specified in opts then DefaultKeyFunc is used. -func NewTransport(c Cache, opts ...TransportOpt) *Transport { +func NewTransport(rc *RespCache, opts ...TransportOpt) *Transport { t := &Transport{ - Cache: c, - KeyFunc: DefaultKeyFunc, + //Cache: c, + //KeyFunc: DefaultKeyFunc, + RespCache: rc, MarkCachedResponses: true, } for _, opt := range opts { @@ -328,13 +320,12 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" var cachedResp *http.Response if cacheable { - cachedResp, err = t.RespCache.Cached(req.Context(), req) - //cachedResp, err = t.CachedResponse(req.Context(), cacheKey, req) + cachedResp, err = t.RespCache.Get(req.Context(), req) } else { // Need to invalidate an existing value - //err = t.RespCache.Delete() - lg.WarnIfFuncError(log, "Delete cached response", t.RespCache.Delete) - //t.cacheDelete(req.Context(), cacheKey) + if err = t.RespCache.Delete(); err != nil { + return nil, err + } } transport := t.Transport @@ -390,9 +381,9 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error return cachedResp, nil } else { if err != nil || resp.StatusCode != http.StatusOK { - //t.cacheDelete(req.Context(), cacheKey) - t.RespCache.Delete() - //t.cacheDelete(req.Context(), cacheKey) + if err = t.RespCache.Delete(); err != nil { + return nil, err + } } if err != nil { return nil, err @@ -430,32 +421,12 @@ func (t 
*Transport) RoundTrip(req *http.Request) (resp *http.Response, err error if err := t.RespCache.Write(req.Context(), &resp); err != nil { log.Error("failed to write download cache", lga.Err, err) } - //_ = t.writeRespToCache(req.Context(), cacheKey, &resp) - //respBytes, err := httputil.DumpResponse(&resp, true) - //if err == nil { - // if _, err = ioz.WriteToFile(req.Context(), t.BodyFilepath, resp.Body); err != nil { - // lg.FromContext(req.Context()).Error("failed to write download cache body to file", - // lga.Err, err, lga.File, t.BodyFilepath) - // } else { - // t.Cache.Set(req.Context(), cacheKey, respBytes) - // } - //} }, } default: - if err := t.RespCache.Write(req.Context(), resp); err != nil { + if err = t.RespCache.Write(req.Context(), resp); err != nil { log.Error("failed to write download cache", lga.Err, err) } - //_ = t.writeRespToCache(req.Context(), cacheKey, resp) - //respBytes, err := httputil.DumpResponse(resp, true) - //if err == nil { - // if _, err = ioz.WriteToFile(req.Context(), t.BodyFilepath, resp.Body); err != nil { - // lg.FromContext(req.Context()).Error("failed to write download cache body to file", - // lga.Err, err, lga.File, t.BodyFilepath) - // } else { - // t.Cache.Set(req.Context(), cacheKey, respBytes) - // } - //} } } else { lg.WarnIfFuncError(log, "Delete resp cache", t.RespCache.Delete) @@ -464,26 +435,26 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error return resp, nil } -func (t *Transport) cacheDelete(ctx context.Context, cacheKey string) { - if err := os.RemoveAll(t.BodyFilepath); err != nil { - lg.FromContext(ctx).Warn("failed to remove download cache body file", - lga.Err, err, lga.File, t.BodyFilepath) - } - t.Cache.Delete(ctx, cacheKey) -} - -func (t *Transport) writeRespToCache(ctx context.Context, cacheKey string, resp *http.Response) error { - respBytes, err := httputil.DumpResponse(resp, true) - if err == nil { - if _, err = ioz.WriteToFile(ctx, t.BodyFilepath, resp.Body); err 
!= nil { - lg.FromContext(ctx).Error("failed to write download cache body to file", - lga.Err, err, lga.File, t.BodyFilepath) - } else { - t.Cache.Set(ctx, cacheKey, respBytes) - } - } - return err -} +//func (t *Transport) cacheDelete(ctx context.Context, cacheKey string) { +// if err := os.RemoveAll(t.BodyFilepath); err != nil { +// lg.FromContext(ctx).Warn("failed to remove download cache body file", +// lga.Err, err, lga.File, t.BodyFilepath) +// } +// t.Cache.Delete(ctx, cacheKey) +//} +// +//func (t *Transport) writeRespToCache(ctx context.Context, cacheKey string, resp *http.Response) error { +// respBytes, err := httputil.DumpResponse(resp, true) +// if err == nil { +// if _, err = ioz.WriteToFile(ctx, t.BodyFilepath, resp.Body); err != nil { +// lg.FromContext(ctx).Error("failed to write download cache body to file", +// lga.Err, err, lga.File, t.BodyFilepath) +// } else { +// t.Cache.Set(ctx, cacheKey, respBytes) +// } +// } +// return err +//} // ErrNoDateHeader indicates that the HTTP headers contained no Date header. var ErrNoDateHeader = errors.New("no Date header") @@ -779,8 +750,8 @@ func (r *cachingReadCloser) Close() error { } // NewMemoryCacheTransport returns a new Transport using the in-memory cache implementation -func NewMemoryCacheTransport(opts ...TransportOpt) *Transport { - c := NewMemoryCache() - t := NewTransport(c, opts...) +func NewMemoryCacheTransport(cacheDir string, opts ...TransportOpt) *Transport { + rc := NewRespCache(cacheDir) + t := NewTransport(rc, opts...) 
return t } diff --git a/libsq/core/ioz/httpcache/httpcache_test.go b/libsq/core/ioz/httpcache/httpcache_test.go index 9fb3b8f47..6a30ca263 100644 --- a/libsq/core/ioz/httpcache/httpcache_test.go +++ b/libsq/core/ioz/httpcache/httpcache_test.go @@ -4,11 +4,13 @@ import ( "bytes" "errors" "flag" + "github.com/neilotoole/sq/libsq/core/stringz" "io" "io/ioutil" "net/http" "net/http/httptest" "os" + "path/filepath" "strconv" "testing" "time" @@ -38,7 +40,7 @@ func TestMain(m *testing.M) { } func setup() { - tp := NewMemoryCacheTransport() + tp := NewMemoryCacheTransport(filepath.Join(os.TempDir(), stringz.Uniq8())) client := http.Client{Transport: tp} s.transport = tp s.client = client @@ -165,15 +167,16 @@ func teardown() { s.server.Close() } -func resetTest() { - s.transport.Cache = NewMemoryCache() +func resetTest(t testing.TB) { + s.transport.RespCache = NewRespCache(t.TempDir()) + //s.transport.RespCache.Delete() clock = &realClock{} } // TestCacheableMethod ensures that uncacheable method does not get stored // in cache and get incorrectly used for a following cacheable method request. 
func TestCacheableMethod(t *testing.T) { - resetTest() + resetTest(t) { req, err := http.NewRequest("POST", s.server.URL+"/method", nil) if err != nil { @@ -230,7 +233,7 @@ func TestCacheableMethod(t *testing.T) { } func TestDontServeHeadResponseToGetRequest(t *testing.T) { - resetTest() + resetTest(t) url := s.server.URL + "/" req, err := http.NewRequest(http.MethodHead, url, nil) if err != nil { @@ -254,7 +257,7 @@ func TestDontServeHeadResponseToGetRequest(t *testing.T) { } func TestDontStorePartialRangeInCache(t *testing.T) { - resetTest() + resetTest(t) { req, err := http.NewRequest("GET", s.server.URL+"/range", nil) if err != nil { @@ -366,7 +369,7 @@ func TestDontStorePartialRangeInCache(t *testing.T) { } func TestCacheOnlyIfBodyRead(t *testing.T) { - resetTest() + resetTest(t) { req, err := http.NewRequest("GET", s.server.URL, nil) if err != nil { @@ -399,7 +402,7 @@ func TestCacheOnlyIfBodyRead(t *testing.T) { } func TestOnlyReadBodyOnDemand(t *testing.T) { - resetTest() + resetTest(t) req, err := http.NewRequest("GET", s.server.URL+"/infinite", nil) if err != nil { @@ -418,7 +421,7 @@ func TestOnlyReadBodyOnDemand(t *testing.T) { } func TestGetOnlyIfCachedHit(t *testing.T) { - resetTest() + resetTest(t) { req, err := http.NewRequest("GET", s.server.URL, nil) if err != nil { @@ -458,7 +461,7 @@ func TestGetOnlyIfCachedHit(t *testing.T) { } func TestGetOnlyIfCachedMiss(t *testing.T) { - resetTest() + resetTest(t) req, err := http.NewRequest("GET", s.server.URL, nil) if err != nil { t.Fatal(err) @@ -478,7 +481,7 @@ func TestGetOnlyIfCachedMiss(t *testing.T) { } func TestGetNoStoreRequest(t *testing.T) { - resetTest() + resetTest(t) req, err := http.NewRequest("GET", s.server.URL, nil) if err != nil { t.Fatal(err) @@ -507,7 +510,7 @@ func TestGetNoStoreRequest(t *testing.T) { } func TestGetNoStoreResponse(t *testing.T) { - resetTest() + resetTest(t) req, err := http.NewRequest("GET", s.server.URL+"/nostore", nil) if err != nil { t.Fatal(err) @@ -535,7 +538,7 
@@ func TestGetNoStoreResponse(t *testing.T) { } func TestGetWithEtag(t *testing.T) { - resetTest() + resetTest(t) req, err := http.NewRequest("GET", s.server.URL+"/etag", nil) if err != nil { t.Fatal(err) @@ -575,7 +578,7 @@ func TestGetWithEtag(t *testing.T) { } func TestGetWithLastModified(t *testing.T) { - resetTest() + resetTest(t) req, err := http.NewRequest("GET", s.server.URL+"/lastmodified", nil) if err != nil { t.Fatal(err) @@ -607,7 +610,7 @@ func TestGetWithLastModified(t *testing.T) { } func TestGetWithVary(t *testing.T) { - resetTest() + resetTest(t) req, err := http.NewRequest("GET", s.server.URL+"/varyaccept", nil) if err != nil { t.Fatal(err) @@ -662,7 +665,7 @@ func TestGetWithVary(t *testing.T) { } func TestGetWithDoubleVary(t *testing.T) { - resetTest() + resetTest(t) req, err := http.NewRequest("GET", s.server.URL+"/doublevary", nil) if err != nil { t.Fatal(err) @@ -718,7 +721,7 @@ func TestGetWithDoubleVary(t *testing.T) { } func TestGetWith2VaryHeaders(t *testing.T) { - resetTest() + resetTest(t) // Tests that multiple Vary headers' comma-separated lists are // merged. See https://github.com/gregjones/httpcache/issues/27. const ( @@ -817,7 +820,7 @@ func TestGetWith2VaryHeaders(t *testing.T) { } func TestGetVaryUnused(t *testing.T) { - resetTest() + resetTest(t) req, err := http.NewRequest("GET", s.server.URL+"/varyunused", nil) if err != nil { t.Fatal(err) @@ -850,7 +853,7 @@ func TestGetVaryUnused(t *testing.T) { } func TestUpdateFields(t *testing.T) { - resetTest() + resetTest(t) req, err := http.NewRequest("GET", s.server.URL+"/updatefields", nil) if err != nil { t.Fatal(err) @@ -888,7 +891,7 @@ func TestUpdateFields(t *testing.T) { // Previously, after validating a cached response, its StatusCode // was incorrectly being replaced. 
func TestCachedErrorsKeepStatus(t *testing.T) { - resetTest() + resetTest(t) req, err := http.NewRequest("GET", s.server.URL+"/cachederror", nil) if err != nil { t.Fatal(err) @@ -914,7 +917,7 @@ func TestCachedErrorsKeepStatus(t *testing.T) { } func TestParseCacheControl(t *testing.T) { - resetTest() + resetTest(t) h := http.Header{} for range parseCacheControl(h) { t.Fatal("cacheControl should be empty") @@ -951,7 +954,7 @@ func TestParseCacheControl(t *testing.T) { } func TestNoCacheRequestExpiration(t *testing.T) { - resetTest() + resetTest(t) respHeaders := http.Header{} respHeaders.Set("Cache-Control", "max-age=7200") @@ -963,7 +966,7 @@ func TestNoCacheRequestExpiration(t *testing.T) { } func TestNoCacheResponseExpiration(t *testing.T) { - resetTest() + resetTest(t) respHeaders := http.Header{} respHeaders.Set("Cache-Control", "no-cache") respHeaders.Set("Expires", "Wed, 19 Apr 3000 11:43:00 GMT") @@ -975,7 +978,7 @@ func TestNoCacheResponseExpiration(t *testing.T) { } func TestReqMustRevalidate(t *testing.T) { - resetTest() + resetTest(t) // not paying attention to request setting max-stale means never returning stale // responses, so always acting as if must-revalidate is set respHeaders := http.Header{} @@ -988,7 +991,7 @@ func TestReqMustRevalidate(t *testing.T) { } func TestRespMustRevalidate(t *testing.T) { - resetTest() + resetTest(t) respHeaders := http.Header{} respHeaders.Set("Cache-Control", "must-revalidate") @@ -999,7 +1002,7 @@ func TestRespMustRevalidate(t *testing.T) { } func TestFreshExpiration(t *testing.T) { - resetTest() + resetTest(t) now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -1017,7 +1020,7 @@ func TestFreshExpiration(t *testing.T) { } func TestMaxAge(t *testing.T) { - resetTest() + resetTest(t) now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -1035,7 +1038,7 @@ func TestMaxAge(t *testing.T) { } func TestMaxAgeZero(t *testing.T) { 
- resetTest() + resetTest(t) now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -1048,7 +1051,7 @@ func TestMaxAgeZero(t *testing.T) { } func TestBothMaxAge(t *testing.T) { - resetTest() + resetTest(t) now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -1062,7 +1065,7 @@ func TestBothMaxAge(t *testing.T) { } func TestMinFreshWithExpires(t *testing.T) { - resetTest() + resetTest(t) now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -1082,7 +1085,7 @@ func TestMinFreshWithExpires(t *testing.T) { } func TestEmptyMaxStale(t *testing.T) { - resetTest() + resetTest(t) now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -1102,7 +1105,7 @@ func TestEmptyMaxStale(t *testing.T) { } func TestMaxStaleValue(t *testing.T) { - resetTest() + resetTest(t) now := time.Now() respHeaders := http.Header{} respHeaders.Set("date", now.Format(time.RFC1123)) @@ -1136,7 +1139,7 @@ func containsHeader(headers []string, header string) bool { } func TestGetEndToEndHeaders(t *testing.T) { - resetTest() + resetTest(t) var ( headers http.Header end2end []string @@ -1193,7 +1196,7 @@ func (t transportMock) RoundTrip(req *http.Request) (resp *http.Response, err er } func TestStaleIfErrorRequest(t *testing.T) { - resetTest() + resetTest(t) now := time.Now() tmock := transportMock{ response: &http.Response{ @@ -1207,7 +1210,7 @@ func TestStaleIfErrorRequest(t *testing.T) { }, err: nil, } - tp := NewMemoryCacheTransport() + tp := NewMemoryCacheTransport(t.TempDir()) tp.Transport = &tmock // First time, response is cached on success @@ -1238,7 +1241,7 @@ func TestStaleIfErrorRequest(t *testing.T) { } func TestStaleIfErrorRequestLifetime(t *testing.T) { - resetTest() + resetTest(t) now := time.Now() tmock := transportMock{ response: &http.Response{ @@ -1252,7 +1255,7 @@ func TestStaleIfErrorRequestLifetime(t 
*testing.T) { }, err: nil, } - tp := NewMemoryCacheTransport() + tp := NewMemoryCacheTransport(t.TempDir()) tp.Transport = &tmock // First time, response is cached on success @@ -1301,7 +1304,7 @@ func TestStaleIfErrorRequestLifetime(t *testing.T) { } func TestStaleIfErrorResponse(t *testing.T) { - resetTest() + resetTest(t) now := time.Now() tmock := transportMock{ response: &http.Response{ @@ -1315,7 +1318,7 @@ func TestStaleIfErrorResponse(t *testing.T) { }, err: nil, } - tp := NewMemoryCacheTransport() + tp := NewMemoryCacheTransport(t.TempDir()) tp.Transport = &tmock // First time, response is cached on success @@ -1345,7 +1348,7 @@ func TestStaleIfErrorResponse(t *testing.T) { } func TestStaleIfErrorResponseLifetime(t *testing.T) { - resetTest() + resetTest(t) now := time.Now() tmock := transportMock{ response: &http.Response{ @@ -1359,7 +1362,7 @@ func TestStaleIfErrorResponseLifetime(t *testing.T) { }, err: nil, } - tp := NewMemoryCacheTransport() + tp := NewMemoryCacheTransport(t.TempDir()) tp.Transport = &tmock // First time, response is cached on success @@ -1399,7 +1402,7 @@ func TestStaleIfErrorResponseLifetime(t *testing.T) { // Previously, after a stale response was used after encountering an error, // its StatusCode was being incorrectly replaced. func TestStaleIfErrorKeepsStatus(t *testing.T) { - resetTest() + resetTest(t) now := time.Now() tmock := transportMock{ response: &http.Response{ @@ -1413,7 +1416,7 @@ func TestStaleIfErrorKeepsStatus(t *testing.T) { }, err: nil, } - tp := NewMemoryCacheTransport() + tp := NewMemoryCacheTransport(t.TempDir()) tp.Transport = &tmock // First time, response is cached on success @@ -1455,9 +1458,10 @@ func TestClientTimeout(t *testing.T) { if testing.Short() { t.Skip("skipping timeout test in short mode") // Because it takes at least 3 seconds to run. 
} - resetTest() + resetTest(t) + client := &http.Client{ - Transport: NewMemoryCacheTransport(), + Transport: NewMemoryCacheTransport(t.TempDir()), Timeout: time.Second, } started := time.Now() From 8d099e5613d9d4f2adbbf1a2a137a914eb9cf381 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 12 Dec 2023 09:27:14 -0700 Subject: [PATCH 093/195] wip: almost there --- .../core/ioz/httpcache/diskcache/diskcache.go | 63 ------------------- .../ioz/httpcache/diskcache/diskcache_test.go | 18 ------ 2 files changed, 81 deletions(-) delete mode 100644 libsq/core/ioz/httpcache/diskcache/diskcache.go delete mode 100644 libsq/core/ioz/httpcache/diskcache/diskcache_test.go diff --git a/libsq/core/ioz/httpcache/diskcache/diskcache.go b/libsq/core/ioz/httpcache/diskcache/diskcache.go deleted file mode 100644 index 4dd96e128..000000000 --- a/libsq/core/ioz/httpcache/diskcache/diskcache.go +++ /dev/null @@ -1,63 +0,0 @@ -// Package diskcache provides an implementation of httpcache.Cache that uses the diskv package -// to supplement an in-memory map with persistent storage -// -package diskcache - -import ( - "bytes" - "context" - "crypto/md5" - "encoding/hex" - "io" - - "github.com/peterbourgon/diskv" -) - -// Cache is an implementation of httpcache.Cache that supplements the in-memory map with persistent storage -type Cache struct { - d *diskv.Diskv -} - -// Get returns the response corresponding to key if present -func (c *Cache) Get(ctx context.Context, key string) (resp []byte, ok bool) { - key = keyToFilename(key) - resp, err := c.d.Read(key) - if err != nil { - return []byte{}, false - } - return resp, true -} - -// Set saves a response to the cache as key -func (c *Cache) Set(ctx context.Context, key string, resp []byte) { - key = keyToFilename(key) - _ = c.d.WriteStream(key, bytes.NewReader(resp), true) -} - -// Delete removes the response with key from the cache -func (c *Cache) Delete(ctx context.Context, key string) { - key = keyToFilename(key) - _ = c.d.Erase(key) -} 
- -func keyToFilename(key string) string { - h := md5.New() - _, _ = io.WriteString(h, key) - return hex.EncodeToString(h.Sum(nil)) -} - -// New returns a new Cache that will store files in basePath -func New(basePath string) *Cache { - return &Cache{ - d: diskv.New(diskv.Options{ - BasePath: basePath, - CacheSizeMax: 100 * 1024 * 1024, // 100MB - }), - } -} - -// NewWithDiskv returns a new Cache using the provided Diskv as underlying -// storage. -func NewWithDiskv(d *diskv.Diskv) *Cache { - return &Cache{d} -} diff --git a/libsq/core/ioz/httpcache/diskcache/diskcache_test.go b/libsq/core/ioz/httpcache/diskcache/diskcache_test.go deleted file mode 100644 index 3fe82273d..000000000 --- a/libsq/core/ioz/httpcache/diskcache/diskcache_test.go +++ /dev/null @@ -1,18 +0,0 @@ -package diskcache - -import ( - "github.com/neilotoole/sq/libsq/core/ioz/httpcache/test" - "io/ioutil" - "os" - "testing" -) - -func TestDiskCache(t *testing.T) { - tempDir, err := ioutil.TempDir("", "httpcache") - if err != nil { - t.Fatalf("TempDir: %v", err) - } - defer os.RemoveAll(tempDir) - - test.Cache(t, New(tempDir)) -} From db62fc31f6d5e4912c769a54bed82ebf26c6c85f Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 12 Dec 2023 10:31:20 -0700 Subject: [PATCH 094/195] refactoring to restore key --- libsq/core/ioz/httpcache/httpcache.go | 270 +++------------------ libsq/core/ioz/httpcache/httpcache_test.go | 21 +- libsq/core/ioz/httpcache/respcache.go | 175 +++++++++++++ libsq/core/ioz/httpcache/test/test.go | 36 --- libsq/core/ioz/httpcache/test/test_test.go | 12 - 5 files changed, 224 insertions(+), 290 deletions(-) create mode 100644 libsq/core/ioz/httpcache/respcache.go delete mode 100644 libsq/core/ioz/httpcache/test/test.go delete mode 100644 libsq/core/ioz/httpcache/test/test_test.go diff --git a/libsq/core/ioz/httpcache/httpcache.go b/libsq/core/ioz/httpcache/httpcache.go index b49a48029..008a911c6 100644 --- a/libsq/core/ioz/httpcache/httpcache.go +++ 
b/libsq/core/ioz/httpcache/httpcache.go @@ -10,20 +10,12 @@ import ( "bytes" "context" "errors" - "github.com/neilotoole/sq/libsq/core/cleanup" - "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "io" "io/ioutil" "net/http" - "net/http/httputil" - "os" - "path/filepath" "strings" - "sync" "time" ) @@ -48,193 +40,40 @@ type Cache interface { type KeyFunc func(req *http.Request) string -// NewRespCache returns a new instance that stores responses in cacheDir. -// The caller should call RespCache.Close when finished with the cache. -func NewRespCache(cacheDir string) *RespCache { - c := &RespCache{ - Header: filepath.Join(cacheDir, "header"), - Body: filepath.Join(cacheDir, "body"), - clnup: cleanup.New(), - } - return c -} - -// RespCache is a cache for a single http.Response. The response is -// stored in two files, one for the header and one for the body. -// The caller should call RespCache.Close when finished with the cache. -type RespCache struct { - mu sync.Mutex - clnup *cleanup.Cleanup - - // Header is the path to the file containing the http.Response header. - Header string - - // Body is the path to the file containing the http.Response body. - Body string -} - -// Get returns the cached http.Response for req if present, and nil -// otherwise. -func (rc *RespCache) Get(ctx context.Context, req *http.Request) (*http.Response, error) { - rc.mu.Lock() - defer rc.mu.Unlock() - - if !ioz.FileAccessible(rc.Header) { - return nil, nil - } - - headerBytes, err := os.ReadFile(rc.Header) - if err != nil { - return nil, err - } - - bodyFile, err := os.Open(rc.Body) - if err != nil { - lg.FromContext(ctx).Error("failed to open cached response body", - lga.File, rc.Body, lga.Err, err) - return nil, err - } - - // We need to explicitly close bodyFile at some later point. 
It won't be - // closed via a call to http.Response.Body.Close(). - rc.clnup.AddC(bodyFile) - - concatRdr := io.MultiReader(bytes.NewReader(headerBytes), bodyFile) - return http.ReadResponse(bufio.NewReader(concatRdr), req) -} - -// Close closes the cache, freeing any resources it holds. Note that -// it does not delete the cache: for that, see RespCache.Delete. -func (rc *RespCache) Close() error { - rc.mu.Lock() - defer rc.mu.Unlock() - - err := rc.clnup.Run() - rc.clnup = cleanup.New() - return err -} - -// Delete deletes the cache entries from disk. -func (rc *RespCache) Delete() error { - if rc == nil { - return nil - } - rc.mu.Lock() - defer rc.mu.Unlock() - - return rc.doDelete() -} - -func (rc *RespCache) doDelete() error { - err1 := rc.clnup.Run() - rc.clnup = cleanup.New() - err2 := os.RemoveAll(rc.Header) - err3 := os.RemoveAll(rc.Body) - return errz.Combine(err1, err2, err3) -} - -const msgDeleteCache = "Delete HTTP response cache" - -// Write writes resp to the cache. -func (rc *RespCache) Write(ctx context.Context, resp *http.Response) error { - rc.mu.Lock() - defer rc.mu.Unlock() - - err := rc.doWrite(ctx, resp) - if err != nil { - lg.WarnIfFuncError(lg.FromContext(ctx), msgDeleteCache, rc.doDelete) - } - return err -} - -func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response) error { - log := lg.FromContext(ctx) - - if err := ioz.RequireDir(filepath.Dir(rc.Header)); err != nil { - return err - } - - if err := ioz.RequireDir(filepath.Dir(rc.Body)); err != nil { - return err - } - - respBytes, err := httputil.DumpResponse(resp, false) - if err != nil { - return err - } - - if _, err = ioz.WriteToFile(ctx, rc.Header, bytes.NewReader(respBytes)); err != nil { - return err - } - - f, err := os.OpenFile(rc.Body, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) - if err != nil { - return err - } - - cr := contextio.NewReader(ctx, resp.Body) - _, err = io.Copy(f, cr) - if err != nil { - lg.WarnIfCloseError(log, "Close cache body file", f) - 
return err - } - - if err = f.Close(); err != nil { - return err - } - - f, err = os.Open(rc.Body) - if err != nil { - return err - } - - resp.Body = f - return nil -} - -//func (t *Transport) CachedResponse(ctx context.Context, key string, req *http.Request) (resp *http.Response, err error) { -// cachedVal, ok := t.Cache.Get(ctx, key) -// if !ok { -// return -// } // -// b := bytes.NewBuffer(cachedVal) -// return http.ReadResponse(bufio.NewReader(b), req) +//// MemoryCache is an implemtation of Cache that stores responses in an in-memory map. +//type MemoryCache struct { +// mu sync.RWMutex +// items map[string][]byte +//} +// +//// Get returns the []byte representation of the response and true if present, false if not +//func (c *MemoryCache) Get(ctx context.Context, key string) (resp []byte, ok bool) { +// c.mu.RLock() +// resp, ok = c.items[key] +// c.mu.RUnlock() +// return resp, ok +//} +// +//// Set saves response resp to the cache with key +//func (c *MemoryCache) Set(ctx context.Context, key string, resp []byte) { +// c.mu.Lock() +// c.items[key] = resp +// c.mu.Unlock() +//} +// +//// Delete removes key from the cache +//func (c *MemoryCache) Delete(ctx context.Context, key string) { +// c.mu.Lock() +// delete(c.items, key) +// c.mu.Unlock() +//} +// +//// NewMemoryCache returns a new Cache that will store items in an in-memory map +//func NewMemoryCache() *MemoryCache { +// c := &MemoryCache{items: map[string][]byte{}} +// return c //} - -// MemoryCache is an implemtation of Cache that stores responses in an in-memory map. 
-type MemoryCache struct { - mu sync.RWMutex - items map[string][]byte -} - -// Get returns the []byte representation of the response and true if present, false if not -func (c *MemoryCache) Get(ctx context.Context, key string) (resp []byte, ok bool) { - c.mu.RLock() - resp, ok = c.items[key] - c.mu.RUnlock() - return resp, ok -} - -// Set saves response resp to the cache with key -func (c *MemoryCache) Set(ctx context.Context, key string, resp []byte) { - c.mu.Lock() - c.items[key] = resp - c.mu.Unlock() -} - -// Delete removes key from the cache -func (c *MemoryCache) Delete(ctx context.Context, key string) { - c.mu.Lock() - delete(c.items, key) - c.mu.Unlock() -} - -// NewMemoryCache returns a new Cache that will store items in an in-memory map -func NewMemoryCache() *MemoryCache { - c := &MemoryCache{items: map[string][]byte{}} - return c -} // TransportOpt is a configuration option for creating a new Transport type TransportOpt func(t *Transport) @@ -260,26 +99,18 @@ type Transport struct { // The RoundTripper interface actually used to make requests // If nil, http.DefaultTransport is used Transport http.RoundTripper - //Cache Cache - RespCache *RespCache - // Deprecated: Use RespCache instead. - //BodyFilepath string + RespCache *RespCache // MarkCachedResponses, if true, indicates that responses returned from the // cache will be given an extra header, X-From-Cache MarkCachedResponses bool - - // A function to generate a cache key for the given request - //KeyFunc KeyFunc } // NewTransport returns a new Transport with the provided Cache and options. If // KeyFunc is not specified in opts then DefaultKeyFunc is used. func NewTransport(rc *RespCache, opts ...TransportOpt) *Transport { t := &Transport{ - //Cache: c, - //KeyFunc: DefaultKeyFunc, RespCache: rc, MarkCachedResponses: true, } @@ -316,7 +147,7 @@ func varyMatches(cachedResp *http.Response, req *http.Request) bool { // will be returned. 
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { log := lg.FromContext(req.Context()) - //cacheKey := t.KeyFunc(req) + cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" var cachedResp *http.Response if cacheable { @@ -381,9 +212,7 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error return cachedResp, nil } else { if err != nil || resp.StatusCode != http.StatusOK { - if err = t.RespCache.Delete(); err != nil { - return nil, err - } + lg.WarnIfFuncError(log, msgDeleteCache, t.RespCache.Delete) } if err != nil { return nil, err @@ -430,32 +259,10 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error } } else { lg.WarnIfFuncError(log, "Delete resp cache", t.RespCache.Delete) - //t.cacheDelete(req.Context(), cacheKey) } return resp, nil } -//func (t *Transport) cacheDelete(ctx context.Context, cacheKey string) { -// if err := os.RemoveAll(t.BodyFilepath); err != nil { -// lg.FromContext(ctx).Warn("failed to remove download cache body file", -// lga.Err, err, lga.File, t.BodyFilepath) -// } -// t.Cache.Delete(ctx, cacheKey) -//} -// -//func (t *Transport) writeRespToCache(ctx context.Context, cacheKey string, resp *http.Response) error { -// respBytes, err := httputil.DumpResponse(resp, true) -// if err == nil { -// if _, err = ioz.WriteToFile(ctx, t.BodyFilepath, resp.Body); err != nil { -// lg.FromContext(ctx).Error("failed to write download cache body to file", -// lga.Err, err, lga.File, t.BodyFilepath) -// } else { -// t.Cache.Set(ctx, cacheKey, respBytes) -// } -// } -// return err -//} - // ErrNoDateHeader indicates that the HTTP headers contained no Date header. 
var ErrNoDateHeader = errors.New("no Date header") @@ -748,10 +555,3 @@ func (r *cachingReadCloser) Read(p []byte) (n int, err error) { func (r *cachingReadCloser) Close() error { return r.R.Close() } - -// NewMemoryCacheTransport returns a new Transport using the in-memory cache implementation -func NewMemoryCacheTransport(cacheDir string, opts ...TransportOpt) *Transport { - rc := NewRespCache(cacheDir) - t := NewTransport(rc, opts...) - return t -} diff --git a/libsq/core/ioz/httpcache/httpcache_test.go b/libsq/core/ioz/httpcache/httpcache_test.go index 6a30ca263..48b9a73fc 100644 --- a/libsq/core/ioz/httpcache/httpcache_test.go +++ b/libsq/core/ioz/httpcache/httpcache_test.go @@ -16,6 +16,13 @@ import ( "time" ) +// newTestTransport returns a new Transport using the in-memory cache implementation +func newTestTransport(cacheDir string, opts ...TransportOpt) *Transport { + rc := NewRespCache(cacheDir) + t := NewTransport(rc, opts...) + return t +} + var s struct { server *httptest.Server client http.Client @@ -40,7 +47,7 @@ func TestMain(m *testing.M) { } func setup() { - tp := NewMemoryCacheTransport(filepath.Join(os.TempDir(), stringz.Uniq8())) + tp := newTestTransport(filepath.Join(os.TempDir(), stringz.Uniq8())) client := http.Client{Transport: tp} s.transport = tp s.client = client @@ -1210,7 +1217,7 @@ func TestStaleIfErrorRequest(t *testing.T) { }, err: nil, } - tp := NewMemoryCacheTransport(t.TempDir()) + tp := newTestTransport(t.TempDir()) tp.Transport = &tmock // First time, response is cached on success @@ -1255,7 +1262,7 @@ func TestStaleIfErrorRequestLifetime(t *testing.T) { }, err: nil, } - tp := NewMemoryCacheTransport(t.TempDir()) + tp := newTestTransport(t.TempDir()) tp.Transport = &tmock // First time, response is cached on success @@ -1318,7 +1325,7 @@ func TestStaleIfErrorResponse(t *testing.T) { }, err: nil, } - tp := NewMemoryCacheTransport(t.TempDir()) + tp := newTestTransport(t.TempDir()) tp.Transport = &tmock // First time, response is 
cached on success @@ -1362,7 +1369,7 @@ func TestStaleIfErrorResponseLifetime(t *testing.T) { }, err: nil, } - tp := NewMemoryCacheTransport(t.TempDir()) + tp := newTestTransport(t.TempDir()) tp.Transport = &tmock // First time, response is cached on success @@ -1416,7 +1423,7 @@ func TestStaleIfErrorKeepsStatus(t *testing.T) { }, err: nil, } - tp := NewMemoryCacheTransport(t.TempDir()) + tp := newTestTransport(t.TempDir()) tp.Transport = &tmock // First time, response is cached on success @@ -1461,7 +1468,7 @@ func TestClientTimeout(t *testing.T) { resetTest(t) client := &http.Client{ - Transport: NewMemoryCacheTransport(t.TempDir()), + Transport: newTestTransport(t.TempDir()), Timeout: time.Second, } started := time.Now() diff --git a/libsq/core/ioz/httpcache/respcache.go b/libsq/core/ioz/httpcache/respcache.go new file mode 100644 index 000000000..54a5666a6 --- /dev/null +++ b/libsq/core/ioz/httpcache/respcache.go @@ -0,0 +1,175 @@ +package httpcache + +import ( + "bufio" + "bytes" + "context" + "github.com/neilotoole/sq/libsq/core/cleanup" + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "io" + "net/http" + "net/http/httputil" + "os" + "path/filepath" + "sync" +) + +// NewRespCache returns a new instance that stores responses in cacheDir. +// The caller should call RespCache.Close when finished with the cache. +func NewRespCache(cacheDir string) *RespCache { + c := &RespCache{ + Dir: cacheDir, + Header: filepath.Join(cacheDir, "header"), + Body: filepath.Join(cacheDir, "body"), + clnup: cleanup.New(), + } + return c +} + +// RespCache is a cache for a single http.Response. The response is +// stored in two files, one for the header and one for the body. +// The caller should call RespCache.Close when finished with the cache. 
+type RespCache struct { + mu sync.Mutex + clnup *cleanup.Cleanup + + Dir string + + // Header is the path to the file containing the http.Response header. + Header string + + // Body is the path to the file containing the http.Response body. + Body string +} + +func (rc *RespCache) getPaths(req *http.Request) (header, body string) { + if req.Method == http.MethodGet { + return filepath.Join(rc.Dir, "header"), filepath.Join(rc.Dir, "body") + } + + return filepath.Join(rc.Dir, req.Method+"_header"), + filepath.Join(rc.Dir, req.Method+"_body") +} + +// Get returns the cached http.Response for req if present, and nil +// otherwise. +func (rc *RespCache) Get(ctx context.Context, req *http.Request) (*http.Response, error) { + rc.mu.Lock() + defer rc.mu.Unlock() + + if !ioz.FileAccessible(rc.Header) { + return nil, nil + } + + headerBytes, err := os.ReadFile(rc.Header) + if err != nil { + return nil, err + } + + bodyFile, err := os.Open(rc.Body) + if err != nil { + lg.FromContext(ctx).Error("failed to open cached response body", + lga.File, rc.Body, lga.Err, err) + return nil, err + } + + // We need to explicitly close bodyFile at some later point. It won't be + // closed via a call to http.Response.Body.Close(). + rc.clnup.AddC(bodyFile) + + concatRdr := io.MultiReader(bytes.NewReader(headerBytes), bodyFile) + return http.ReadResponse(bufio.NewReader(concatRdr), req) +} + +// Close closes the cache, freeing any resources it holds. Note that +// it does not delete the cache: for that, see RespCache.Delete. +func (rc *RespCache) Close() error { + rc.mu.Lock() + defer rc.mu.Unlock() + + err := rc.clnup.Run() + rc.clnup = cleanup.New() + return err +} + +// Delete deletes the cache entries from disk. 
+func (rc *RespCache) Delete() error { + if rc == nil { + return nil + } + rc.mu.Lock() + defer rc.mu.Unlock() + + return rc.doDelete() +} + +func (rc *RespCache) doDelete() error { + err1 := rc.clnup.Run() + rc.clnup = cleanup.New() + err2 := os.RemoveAll(rc.Header) + err3 := os.RemoveAll(rc.Body) + return errz.Combine(err1, err2, err3) +} + +const msgDeleteCache = "Delete HTTP response cache" + +// Write writes resp to the cache. +func (rc *RespCache) Write(ctx context.Context, resp *http.Response) error { + rc.mu.Lock() + defer rc.mu.Unlock() + + err := rc.doWrite(ctx, resp) + if err != nil { + lg.WarnIfFuncError(lg.FromContext(ctx), msgDeleteCache, rc.doDelete) + } + return err +} + +func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response) error { + log := lg.FromContext(ctx) + + if err := ioz.RequireDir(filepath.Dir(rc.Header)); err != nil { + return err + } + + if err := ioz.RequireDir(filepath.Dir(rc.Body)); err != nil { + return err + } + + respBytes, err := httputil.DumpResponse(resp, false) + if err != nil { + return err + } + + if _, err = ioz.WriteToFile(ctx, rc.Header, bytes.NewReader(respBytes)); err != nil { + return err + } + + f, err := os.OpenFile(rc.Body, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) + if err != nil { + return err + } + + cr := contextio.NewReader(ctx, resp.Body) + _, err = io.Copy(f, cr) + if err != nil { + lg.WarnIfCloseError(log, "Close cache body file", f) + return err + } + + if err = f.Close(); err != nil { + return err + } + + f, err = os.Open(rc.Body) + if err != nil { + return err + } + + resp.Body = f + return nil +} diff --git a/libsq/core/ioz/httpcache/test/test.go b/libsq/core/ioz/httpcache/test/test.go deleted file mode 100644 index 533c60aaf..000000000 --- a/libsq/core/ioz/httpcache/test/test.go +++ /dev/null @@ -1,36 +0,0 @@ -package test - -import ( - "bytes" - "context" - "testing" - - "github.com/neilotoole/sq/libsq/core/ioz/httpcache" -) - -// Cache excercises a httpcache.Cache 
implementation. -func Cache(t *testing.T, cache httpcache.Cache) { - key := "testKey" - _, ok := cache.Get(context.Background(), key) - if ok { - t.Fatal("retrieved key before adding it") - } - - val := []byte("some bytes") - cache.Set(context.Background(), key, val) - - retVal, ok := cache.Get(context.Background(), key) - if !ok { - t.Fatal("could not retrieve an element we just added") - } - if !bytes.Equal(retVal, val) { - t.Fatal("retrieved a different value than what we put in") - } - - cache.Delete(context.Background(), key) - - _, ok = cache.Get(context.Background(), key) - if ok { - t.Fatal("deleted key still present") - } -} diff --git a/libsq/core/ioz/httpcache/test/test_test.go b/libsq/core/ioz/httpcache/test/test_test.go deleted file mode 100644 index 4f02f62a1..000000000 --- a/libsq/core/ioz/httpcache/test/test_test.go +++ /dev/null @@ -1,12 +0,0 @@ -package test_test - -import ( - "testing" - - "github.com/neilotoole/sq/libsq/core/ioz/httpcache" - "github.com/neilotoole/sq/libsq/core/ioz/httpcache/test" -) - -func TestMemoryCache(t *testing.T) { - test.Cache(t, httpcache.NewMemoryCache()) -} From c81b6928cd23d29240c98ea8446fc5461509e378 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 12 Dec 2023 10:45:27 -0700 Subject: [PATCH 095/195] httpcache: all tests passing --- libsq/core/ioz/httpcache/respcache.go | 45 +++++++++++---------------- libsq/core/lg/lga/lga.go | 1 + 2 files changed, 20 insertions(+), 26 deletions(-) diff --git a/libsq/core/ioz/httpcache/respcache.go b/libsq/core/ioz/httpcache/respcache.go index 54a5666a6..464ebad18 100644 --- a/libsq/core/ioz/httpcache/respcache.go +++ b/libsq/core/ioz/httpcache/respcache.go @@ -22,10 +22,10 @@ import ( // The caller should call RespCache.Close when finished with the cache. 
func NewRespCache(cacheDir string) *RespCache { c := &RespCache{ - Dir: cacheDir, - Header: filepath.Join(cacheDir, "header"), - Body: filepath.Join(cacheDir, "body"), - clnup: cleanup.New(), + Dir: cacheDir, + //Header: filepath.Join(cacheDir, "header"), + //Body: filepath.Join(cacheDir, "body"), + clnup: cleanup.New(), } return c } @@ -38,16 +38,10 @@ type RespCache struct { clnup *cleanup.Cleanup Dir string - - // Header is the path to the file containing the http.Response header. - Header string - - // Body is the path to the file containing the http.Response body. - Body string } func (rc *RespCache) getPaths(req *http.Request) (header, body string) { - if req.Method == http.MethodGet { + if req == nil || req.Method == http.MethodGet { return filepath.Join(rc.Dir, "header"), filepath.Join(rc.Dir, "body") } @@ -61,19 +55,21 @@ func (rc *RespCache) Get(ctx context.Context, req *http.Request) (*http.Response rc.mu.Lock() defer rc.mu.Unlock() - if !ioz.FileAccessible(rc.Header) { + fpHeader, fpBody := rc.getPaths(req) + + if !ioz.FileAccessible(fpHeader) { return nil, nil } - headerBytes, err := os.ReadFile(rc.Header) + headerBytes, err := os.ReadFile(fpHeader) if err != nil { return nil, err } - bodyFile, err := os.Open(rc.Body) + bodyFile, err := os.Open(fpBody) if err != nil { lg.FromContext(ctx).Error("failed to open cached response body", - lga.File, rc.Body, lga.Err, err) + lga.File, fpBody, lga.Err, err) return nil, err } @@ -108,11 +104,10 @@ func (rc *RespCache) Delete() error { } func (rc *RespCache) doDelete() error { - err1 := rc.clnup.Run() + cleanErr := rc.clnup.Run() rc.clnup = cleanup.New() - err2 := os.RemoveAll(rc.Header) - err3 := os.RemoveAll(rc.Body) - return errz.Combine(err1, err2, err3) + deleteErr := os.RemoveAll(rc.Dir) + return errz.Combine(cleanErr, deleteErr) } const msgDeleteCache = "Delete HTTP response cache" @@ -132,24 +127,22 @@ func (rc *RespCache) Write(ctx context.Context, resp *http.Response) error { func (rc *RespCache) 
doWrite(ctx context.Context, resp *http.Response) error { log := lg.FromContext(ctx) - if err := ioz.RequireDir(filepath.Dir(rc.Header)); err != nil { + if err := ioz.RequireDir(rc.Dir); err != nil { return err } - if err := ioz.RequireDir(filepath.Dir(rc.Body)); err != nil { - return err - } + fpHeader, fpBody := rc.getPaths(resp.Request) respBytes, err := httputil.DumpResponse(resp, false) if err != nil { return err } - if _, err = ioz.WriteToFile(ctx, rc.Header, bytes.NewReader(respBytes)); err != nil { + if _, err = ioz.WriteToFile(ctx, fpHeader, bytes.NewReader(respBytes)); err != nil { return err } - f, err := os.OpenFile(rc.Body, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) + f, err := os.OpenFile(fpBody, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) if err != nil { return err } @@ -165,7 +158,7 @@ func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response) error { return err } - f, err = os.Open(rc.Body) + f, err = os.Open(fpBody) if err != nil { return err } diff --git a/libsq/core/lg/lga/lga.go b/libsq/core/lg/lga/lga.go index 72eb96d81..03d9ef511 100644 --- a/libsq/core/lg/lga/lga.go +++ b/libsq/core/lg/lga/lga.go @@ -35,6 +35,7 @@ const ( Kind = "kind" Loc = "loc" Lock = "lock" + Method = "method" Name = "name" New = "new" Old = "old" From 4c0f4f5398ba7af19e0809a37253024f3aaa2e8a Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 12 Dec 2023 10:52:37 -0700 Subject: [PATCH 096/195] httpcache: all tests passing --- libsq/core/ioz/httpcache/httpcache.go | 71 ++++----------------------- 1 file changed, 9 insertions(+), 62 deletions(-) diff --git a/libsq/core/ioz/httpcache/httpcache.go b/libsq/core/ioz/httpcache/httpcache.go index 008a911c6..99d12faef 100644 --- a/libsq/core/ioz/httpcache/httpcache.go +++ b/libsq/core/ioz/httpcache/httpcache.go @@ -1,14 +1,16 @@ -// Package httpcache provides a http.RoundTripper implementation that works as a -// mostly RFC-compliant cache for http responses. 
+// Package httpcache provides a http.RoundTripper implementation that +// works as a mostly RFC-compliant cache for http responses. // -// It is only suitable for use as a 'private' cache (i.e. for a web-browser or an API-client -// and not for a shared proxy). +// FIXME: move httpcache to internal/httpcache, because its use +// is so specialized? +// +// Acknowledgement: This package is a heavily customized fork +// of https://github.com/gregjones/httpcache, via bitcomplete/httpcache. package httpcache import ( "bufio" "bytes" - "context" "errors" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" @@ -27,54 +29,6 @@ const ( XFromCache = "X-From-Cache" ) -// A Cache interface is used by the Transport to store and retrieve responses. -type Cache interface { - // Get returns the []byte representation of a cached response and a bool - // set to true if the value isn't empty - Get(ctx context.Context, key string) (responseBytes []byte, ok bool) - // Set stores the []byte representation of a response against a key - Set(ctx context.Context, key string, responseBytes []byte) - // Delete removes the value associated with the key - Delete(ctx context.Context, key string) -} - -type KeyFunc func(req *http.Request) string - -// -//// MemoryCache is an implemtation of Cache that stores responses in an in-memory map. 
-//type MemoryCache struct { -// mu sync.RWMutex -// items map[string][]byte -//} -// -//// Get returns the []byte representation of the response and true if present, false if not -//func (c *MemoryCache) Get(ctx context.Context, key string) (resp []byte, ok bool) { -// c.mu.RLock() -// resp, ok = c.items[key] -// c.mu.RUnlock() -// return resp, ok -//} -// -//// Set saves response resp to the cache with key -//func (c *MemoryCache) Set(ctx context.Context, key string, resp []byte) { -// c.mu.Lock() -// c.items[key] = resp -// c.mu.Unlock() -//} -// -//// Delete removes key from the cache -//func (c *MemoryCache) Delete(ctx context.Context, key string) { -// c.mu.Lock() -// delete(c.items, key) -// c.mu.Unlock() -//} -// -//// NewMemoryCache returns a new Cache that will store items in an in-memory map -//func NewMemoryCache() *MemoryCache { -// c := &MemoryCache{items: map[string][]byte{}} -// return c -//} - // TransportOpt is a configuration option for creating a new Transport type TransportOpt func(t *Transport) @@ -85,13 +39,6 @@ func MarkCachedResponsesOpt(markCachedResponses bool) TransportOpt { } } -//// KeyFuncOpt configures a transport by setting its KeyFunc to the one given -//func KeyFuncOpt(keyFunc KeyFunc) TransportOpt { -// return func(t *Transport) { -// t.KeyFunc = keyFunc -// } -//} - // Transport is an implementation of http.RoundTripper that will return values from a cache // where possible (avoiding a network request) and will additionally add validators (etag/if-modified-since) // to repeated requests allowing servers to return 304 / Not Modified @@ -125,8 +72,8 @@ func (t *Transport) Client() *http.Client { return &http.Client{Transport: t} } -// varyMatches will return false unless all of the cached values for the headers listed in Vary -// match the new request +// varyMatches will return false unless all the cached values for the +// headers listed in Vary match the new request func varyMatches(cachedResp *http.Response, req *http.Request) 
bool { for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") { header = http.CanonicalHeaderKey(header) From ce19d4e9f841c53f76317a9e4ccc1e294c04f3f9 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 12 Dec 2023 14:23:46 -0700 Subject: [PATCH 097/195] saving progress: httpcacheworking --- libsq/core/ioz/httpcache/httpcache.go | 223 ++- libsq/core/ioz/httpcache/httpz.go | 88 + libsq/core/ioz/httpcache/respcache.go | 88 +- libsq/core/ioz/httpcacheworking/LICENSE.txt | 7 + libsq/core/ioz/httpcacheworking/README.md | 42 + libsq/core/ioz/httpcacheworking/httpcache.go | 717 ++++++++ .../ioz/httpcacheworking/httpcache_test.go | 1486 +++++++++++++++++ libsq/core/ioz/httpcacheworking/httpz.go | 88 + libsq/core/ioz/httpcacheworking/respcache.go | 210 +++ 9 files changed, 2921 insertions(+), 28 deletions(-) create mode 100644 libsq/core/ioz/httpcache/httpz.go create mode 100644 libsq/core/ioz/httpcacheworking/LICENSE.txt create mode 100644 libsq/core/ioz/httpcacheworking/README.md create mode 100644 libsq/core/ioz/httpcacheworking/httpcache.go create mode 100644 libsq/core/ioz/httpcacheworking/httpcache_test.go create mode 100644 libsq/core/ioz/httpcacheworking/httpz.go create mode 100644 libsq/core/ioz/httpcacheworking/respcache.go diff --git a/libsq/core/ioz/httpcache/httpcache.go b/libsq/core/ioz/httpcache/httpcache.go index 99d12faef..41f63706d 100644 --- a/libsq/core/ioz/httpcache/httpcache.go +++ b/libsq/core/ioz/httpcache/httpcache.go @@ -12,11 +12,13 @@ import ( "bufio" "bytes" "errors" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "io" "io/ioutil" "net/http" + "os" "strings" "time" ) @@ -84,6 +86,217 @@ func varyMatches(cachedResp *http.Response, req *http.Request) bool { return true } +// IsCached returns true if there is a cache entry for req. This does not +// guarantee that the cache entry is fresh. 
+func (t *Transport) IsCached(req *http.Request) bool { + return t.RespCache.Exists(req) +} + +// IsFresh returns true if there is a fresh cache entry for req. +func (t *Transport) IsFresh(req *http.Request) bool { + ctx := req.Context() + log := lg.FromContext(ctx) + + if !isCacheable(req) { + return false + } + + if !t.RespCache.Exists(req) { + return false + } + + fpHeader, _ := t.RespCache.Paths(req) + f, err := os.Open(fpHeader) + if err != nil { + log.Error("Failed to open cached response header file", lga.File, fpHeader, lga.Err, err) + return false + } + + defer lg.WarnIfCloseError(log, "Close cached response header", f) + + cachedResp, err := ReadResponse(bufio.NewReader(f), nil, true) + if err != nil { + log.Error("Failed to read cached response", lga.Err, err) + return false + } + + freshness := getFreshness(cachedResp.Header, req.Header) + return freshness == fresh +} + +func isCacheable(req *http.Request) bool { + return (req.Method == http.MethodGet || req.Method == http.MethodHead) && req.Header.Get("range") == "" +} + +type CallbackHandler struct { + HandleCached func(cachedFilepath string) error + HandleUncached func() (wc io.WriteCloser, errFn func(error), err error) + HandleError func(err error) +} + +func (t *Transport) Fetch(req *http.Request, cb CallbackHandler) { + ctx := req.Context() + log := lg.FromContext(ctx) + log.Info("Fetching download", lga.URL, req.URL.String()) + _ = log + _, fpBody := t.RespCache.Paths(req) + + if t.IsFresh(req) { + _ = cb.HandleCached(fpBody) + return + } + + var err error + cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" + var cachedResp *http.Response + if cacheable { + cachedResp, err = t.RespCache.Get(req.Context(), req) + } else { + // Need to invalidate an existing value + if err = t.RespCache.Delete(req.Context()); err != nil { + cb.HandleError(err) + return + } + } + + transport := t.Transport + if transport == nil { + transport = http.DefaultTransport + } + var 
resp *http.Response + if cacheable && cachedResp != nil && err == nil { + if t.MarkCachedResponses { + cachedResp.Header.Set(XFromCache, "1") + } + + if varyMatches(cachedResp, req) { + // Can only use cached value if the new request doesn't Vary significantly + freshness := getFreshness(cachedResp.Header, req.Header) + if freshness == fresh { + _ = cb.HandleCached(fpBody) + return + } + + if freshness == stale { + var req2 *http.Request + // Add validators if caller hasn't already done so + etag := cachedResp.Header.Get("etag") + if etag != "" && req.Header.Get("etag") == "" { + req2 = cloneRequest(req) + req2.Header.Set("if-none-match", etag) + } + lastModified := cachedResp.Header.Get("last-modified") + if lastModified != "" && req.Header.Get("last-modified") == "" { + if req2 == nil { + req2 = cloneRequest(req) + } + req2.Header.Set("if-modified-since", lastModified) + } + if req2 != nil { + req = req2 + } + } + } + + // FIXME: Use an http client here + resp, err = transport.RoundTrip(req) + if err == nil && req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified { + // Replace the 304 response with the one from cache, but update with some new headers + endToEndHeaders := getEndToEndHeaders(resp.Header) + for _, header := range endToEndHeaders { + cachedResp.Header[header] = resp.Header[header] + } + resp = cachedResp + } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) && + req.Method == http.MethodGet && canStaleOnError(cachedResp.Header, req.Header) { + // In case of transport failure and stale-if-error activated, returns cached content + // when available + log.Warn("Returning cached response due to transport failure", lga.Err, err) + cb.HandleCached(fpBody) + return + } else { + if err != nil || resp.StatusCode != http.StatusOK { + lg.WarnIfError(log, msgDeleteCache, t.RespCache.Delete(req.Context())) + } + if err != nil { + cb.HandleError(err) + return + } + } + } else { + reqCacheControl := 
parseCacheControl(req.Header) + if _, ok := reqCacheControl["only-if-cached"]; ok { + resp = newGatewayTimeoutResponse(req) + } else { + resp, err = transport.RoundTrip(req) + if err != nil { + cb.HandleError(err) + return + } + } + } + + if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) { + for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") { + varyKey = http.CanonicalHeaderKey(varyKey) + fakeHeader := "X-Varied-" + varyKey + reqValue := req.Header.Get(varyKey) + if reqValue != "" { + resp.Header.Set(fakeHeader, reqValue) + } + } + switch req.Method { + //case "GET": + // // Delay caching until EOF is reached. + // resp.Body = &cachingReadCloser{ + // R: resp.Body, + // OnEOF: func(r io.Reader) { + // resp := *resp + // resp.Body = ioutil.NopCloser(r) + // if err := t.RespCache.Write(req.Context(), &resp, nil); err != nil { + // log.Error("failed to write download cache", lga.Err, err) + // } + // }, + // } + default: + copyWrtr, errFn, err := cb.HandleUncached() + if err != nil { + cb.HandleError(err) + return + } + + if err = t.RespCache.Write(req.Context(), resp, copyWrtr); err != nil { + log.Error("failed to write download cache", lga.Err, err) + errFn(err) + cb.HandleError(err) + } + return + } + } else { + lg.WarnIfError(log, "Delete resp cache", t.RespCache.Delete(req.Context())) + } + + // It's not cacheable, so we need to write it to the copyWrtr + copyWrtr, errFn, err := cb.HandleUncached() + if err != nil { + cb.HandleError(err) + return + } + cr := contextio.NewReader(ctx, resp.Body) + _, err = io.Copy(copyWrtr, cr) + if err != nil { + errFn(err) + cb.HandleError(err) + return + } + if err = copyWrtr.Close(); err != nil { + cb.HandleError(err) + return + } + + return +} + // RoundTrip takes a Request and returns a Response // // If there is a fresh Response already in cache, then it will be returned without connecting to @@ -101,7 +314,7 @@ func (t *Transport) RoundTrip(req *http.Request) (resp 
*http.Response, err error cachedResp, err = t.RespCache.Get(req.Context(), req) } else { // Need to invalidate an existing value - if err = t.RespCache.Delete(); err != nil { + if err = t.RespCache.Delete(req.Context()); err != nil { return nil, err } } @@ -159,7 +372,7 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error return cachedResp, nil } else { if err != nil || resp.StatusCode != http.StatusOK { - lg.WarnIfFuncError(log, msgDeleteCache, t.RespCache.Delete) + lg.WarnIfError(log, msgDeleteCache, t.RespCache.Delete(req.Context())) } if err != nil { return nil, err @@ -194,18 +407,18 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error OnEOF: func(r io.Reader) { resp := *resp resp.Body = ioutil.NopCloser(r) - if err := t.RespCache.Write(req.Context(), &resp); err != nil { + if err := t.RespCache.Write(req.Context(), &resp, nil); err != nil { log.Error("failed to write download cache", lga.Err, err) } }, } default: - if err = t.RespCache.Write(req.Context(), resp); err != nil { + if err = t.RespCache.Write(req.Context(), resp, nil); err != nil { log.Error("failed to write download cache", lga.Err, err) } } } else { - lg.WarnIfFuncError(log, "Delete resp cache", t.RespCache.Delete) + lg.WarnIfError(log, "Delete resp cache", t.RespCache.Delete(req.Context())) } return resp, nil } diff --git a/libsq/core/ioz/httpcache/httpz.go b/libsq/core/ioz/httpcache/httpz.go new file mode 100644 index 000000000..8e5337684 --- /dev/null +++ b/libsq/core/ioz/httpcache/httpz.go @@ -0,0 +1,88 @@ +package httpcache + +import ( + "bufio" + "fmt" + "io" + "net/http" + "net/textproto" + "strconv" + "strings" +) + +// ReadResponse is a copy of http.ReadResponse, but with the option +// to read only the response header, and not the body. When only reading +// the header, note that resp.Body will be nil, and that the resp is +// generally not functional. 
+func ReadResponse(r *bufio.Reader, req *http.Request, headerOnly bool) (*http.Response, error) { + if !headerOnly { + return http.ReadResponse(r, req) + } + tp := textproto.NewReader(r) + resp := &http.Response{ + Request: req, + } + + // Parse the first line of the response. + line, err := tp.ReadLine() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + proto, status, ok := strings.Cut(line, " ") + if !ok { + return nil, badStringError("malformed HTTP response", line) + } + resp.Proto = proto + resp.Status = strings.TrimLeft(status, " ") + + statusCode, _, _ := strings.Cut(resp.Status, " ") + if len(statusCode) != 3 { + return nil, badStringError("malformed HTTP status code", statusCode) + } + resp.StatusCode, err = strconv.Atoi(statusCode) + if err != nil || resp.StatusCode < 0 { + return nil, badStringError("malformed HTTP status code", statusCode) + } + if resp.ProtoMajor, resp.ProtoMinor, ok = http.ParseHTTPVersion(resp.Proto); !ok { + return nil, badStringError("malformed HTTP version", resp.Proto) + } + + // Parse the response headers. 
+ mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + resp.Header = http.Header(mimeHeader) + + fixPragmaCacheControl(resp.Header) + + //err = readTransfer(resp, r) + //if err != nil { + // return nil, err + //} + + return resp, nil +} + +// RFC 7234, section 5.4: Should treat +// +// Pragma: no-cache +// +// like +// +// Cache-Control: no-cache +func fixPragmaCacheControl(header http.Header) { + if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" { + if _, presentcc := header["Cache-Control"]; !presentcc { + header["Cache-Control"] = []string{"no-cache"} + } + } +} + +func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } diff --git a/libsq/core/ioz/httpcache/respcache.go b/libsq/core/ioz/httpcache/respcache.go index 464ebad18..6b964c2ed 100644 --- a/libsq/core/ioz/httpcache/respcache.go +++ b/libsq/core/ioz/httpcache/respcache.go @@ -40,7 +40,9 @@ type RespCache struct { Dir string } -func (rc *RespCache) getPaths(req *http.Request) (header, body string) { +// Paths returns the paths to the header and body files for req. +// It is not guaranteed that they exist. +func (rc *RespCache) Paths(req *http.Request) (header, body string) { if req == nil || req.Method == http.MethodGet { return filepath.Join(rc.Dir, "header"), filepath.Join(rc.Dir, "body") } @@ -49,13 +51,26 @@ func (rc *RespCache) getPaths(req *http.Request) (header, body string) { filepath.Join(rc.Dir, req.Method+"_body") } +// Exists returns true if the cache contains a response for req. +func (rc *RespCache) Exists(req *http.Request) bool { + rc.mu.Lock() + defer rc.mu.Unlock() + + fpHeader, _ := rc.Paths(req) + fi, err := os.Stat(fpHeader) + if err != nil { + return false + } + return fi.Size() > 0 +} + // Get returns the cached http.Response for req if present, and nil // otherwise. 
func (rc *RespCache) Get(ctx context.Context, req *http.Request) (*http.Response, error) { rc.mu.Lock() defer rc.mu.Unlock() - fpHeader, fpBody := rc.getPaths(req) + fpHeader, fpBody := rc.Paths(req) if !ioz.FileAccessible(fpHeader) { return nil, nil @@ -76,7 +91,7 @@ func (rc *RespCache) Get(ctx context.Context, req *http.Request) (*http.Response // We need to explicitly close bodyFile at some later point. It won't be // closed via a call to http.Response.Body.Close(). rc.clnup.AddC(bodyFile) - + // TODO: consider adding contextio.NewReader? concatRdr := io.MultiReader(bytes.NewReader(headerBytes), bodyFile) return http.ReadResponse(bufio.NewReader(concatRdr), req) } @@ -93,76 +108,103 @@ func (rc *RespCache) Close() error { } // Delete deletes the cache entries from disk. -func (rc *RespCache) Delete() error { +func (rc *RespCache) Delete(ctx context.Context) error { if rc == nil { return nil } rc.mu.Lock() defer rc.mu.Unlock() - return rc.doDelete() + return rc.doDelete(ctx) } -func (rc *RespCache) doDelete() error { +func (rc *RespCache) doDelete(ctx context.Context) error { cleanErr := rc.clnup.Run() rc.clnup = cleanup.New() - deleteErr := os.RemoveAll(rc.Dir) - return errz.Combine(cleanErr, deleteErr) + deleteErr := errz.Wrap(os.RemoveAll(rc.Dir), "delete cache dir") + err := errz.Combine(cleanErr, deleteErr) + if err != nil { + lg.FromContext(ctx).Error("Delete cache dir", + lga.Dir, rc.Dir, lga.Err, err) + return err + } + + lg.FromContext(ctx).Info("Deleted cache dir", lga.Dir, rc.Dir) + return nil } const msgDeleteCache = "Delete HTTP response cache" -// Write writes resp to the cache. -func (rc *RespCache) Write(ctx context.Context, resp *http.Response) error { +// Write writes resp to the cache. If copyWrtr is non-nil, the response +// bytes are copied to that destination also. 
+func (rc *RespCache) Write(ctx context.Context, resp *http.Response, copyWrtr io.WriteCloser) error { rc.mu.Lock() defer rc.mu.Unlock() - err := rc.doWrite(ctx, resp) + log := lg.FromContext(ctx) + log.Debug("huzzah in write") + + err := rc.doWrite(ctx, resp, copyWrtr) if err != nil { - lg.WarnIfFuncError(lg.FromContext(ctx), msgDeleteCache, rc.doDelete) + lg.WarnIfError(lg.FromContext(ctx), msgDeleteCache, rc.doDelete(ctx)) } return err } -func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response) error { +func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response, copyWrtr io.WriteCloser) error { log := lg.FromContext(ctx) if err := ioz.RequireDir(rc.Dir); err != nil { return err } - fpHeader, fpBody := rc.getPaths(resp.Request) + fpHeader, fpBody := rc.Paths(resp.Request) - respBytes, err := httputil.DumpResponse(resp, false) + headerBytes, err := httputil.DumpResponse(resp, false) if err != nil { return err } - if _, err = ioz.WriteToFile(ctx, fpHeader, bytes.NewReader(respBytes)); err != nil { + if _, err = ioz.WriteToFile(ctx, fpHeader, bytes.NewReader(headerBytes)); err != nil { return err } - f, err := os.OpenFile(fpBody, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) + cacheFile, err := os.OpenFile(fpBody, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) if err != nil { return err } - cr := contextio.NewReader(ctx, resp.Body) - _, err = io.Copy(f, cr) + var cr io.Reader + if copyWrtr == nil { + cr = contextio.NewReader(ctx, resp.Body) + } else { + tr := io.TeeReader(resp.Body, copyWrtr) + cr = contextio.NewReader(ctx, tr) + } + + //if copyWrtr != nil { + // cr = io.TeeReader(cr, copyWrtr) + //} + var written int64 + written, err = io.Copy(cacheFile, cr) if err != nil { - lg.WarnIfCloseError(log, "Close cache body file", f) + lg.WarnIfCloseError(log, "Close cache body file", cacheFile) return err } + if copyWrtr != nil { + lg.WarnIfCloseError(log, "Close copy writer", copyWrtr) + } - if err = f.Close(); err != nil { + if err = 
cacheFile.Close(); err != nil { return err } - f, err = os.Open(fpBody) + log.Info("Wrote HTTP response to cache", lga.File, fpBody, lga.Size, written) + cacheFile, err = os.Open(fpBody) if err != nil { return err } - resp.Body = f + resp.Body = cacheFile return nil } diff --git a/libsq/core/ioz/httpcacheworking/LICENSE.txt b/libsq/core/ioz/httpcacheworking/LICENSE.txt new file mode 100644 index 000000000..81316beb0 --- /dev/null +++ b/libsq/core/ioz/httpcacheworking/LICENSE.txt @@ -0,0 +1,7 @@ +Copyright © 2012 Greg Jones (greg.jones@gmail.com) + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
\ No newline at end of file diff --git a/libsq/core/ioz/httpcacheworking/README.md b/libsq/core/ioz/httpcacheworking/README.md new file mode 100644 index 000000000..58cda222f --- /dev/null +++ b/libsq/core/ioz/httpcacheworking/README.md @@ -0,0 +1,42 @@ +# httpcache + +[![GoDoc](https://godoc.org/github.com/bitcomplete/httpcache?status.svg)](https://godoc.org/github.com/bitcomplete/httpcache) + +Package httpcache provides a http.RoundTripper implementation that works as a +mostly [RFC 7234](https://tools.ietf.org/html/rfc7234) compliant cache for http +responses. This incarnation of the library is an active fork of +[github.com/gregjones/httpcache](https://github.com/gregjones/httpcache) which +is unmaintained. + +It is only suitable for use as a 'private' cache (i.e. for a web-browser or an +API-client and not for a shared proxy). + +## Cache Backends + +- The built-in 'memory' cache stores responses in an in-memory map. - + [`github.com/bitcomplete/httpcache/diskcache`](https://github.com/bitcomplete/httpcache/tree/master/diskcache) + provides a filesystem-backed cache using the + [diskv](https://github.com/peterbourgon/diskv) library. - + [`github.com/bitcomplete/httpcache/memcache`](https://github.com/bitcomplete/httpcache/tree/master/memcache) + provides memcache implementations, for both App Engine and 'normal' memcache + servers. - + [`sourcegraph.com/sourcegraph/s3cache`](https://sourcegraph.com/github.com/sourcegraph/s3cache) + uses Amazon S3 for storage. - + [`github.com/bitcomplete/httpcache/leveldbcache`](https://github.com/bitcomplete/httpcache/tree/master/leveldbcache) + provides a filesystem-backed cache using + [leveldb](https://github.com/syndtr/goleveldb/leveldb). - + [`github.com/die-net/lrucache`](https://github.com/die-net/lrucache) provides an + in-memory cache that will evict least-recently used entries. 
- + [`github.com/die-net/lrucache/twotier`](https://github.com/die-net/lrucache/tree/master/twotier) + allows caches to be combined, for example to use lrucache above with a + persistent disk-cache. - + [`github.com/birkelund/boltdbcache`](https://github.com/birkelund/boltdbcache) + provides a BoltDB implementation (based on the + [bbolt](https://github.com/coreos/bbolt) fork). + +If you implement any other backend and wish it to be linked here, please send a +PR editing this file. + +## License + +- [MIT License](LICENSE.txt) diff --git a/libsq/core/ioz/httpcacheworking/httpcache.go b/libsq/core/ioz/httpcacheworking/httpcache.go new file mode 100644 index 000000000..b8bbb9ef9 --- /dev/null +++ b/libsq/core/ioz/httpcacheworking/httpcache.go @@ -0,0 +1,717 @@ +// Package httpcacheworking provides a http.RoundTripper implementation that +// works as a mostly RFC-compliant cache for http responses. +// +// FIXME: move httpcache to internal/httpcache, because its use +// is so specialized? +// +// Acknowledgement: This package is a heavily customized fork +// of https://github.com/gregjones/httpcache, via bitcomplete/httpcache. 
+package httpcacheworking + +import ( + "bufio" + "bytes" + "errors" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "io" + "io/ioutil" + "net/http" + "os" + "strings" + "time" +) + +const ( + stale = iota + fresh + transparent + // XFromCache is the header added to responses that are returned from the cache + XFromCache = "X-From-Cache" +) + +// TransportOpt is a configuration option for creating a new Transport +type TransportOpt func(t *Transport) + +// MarkCachedResponsesOpt configures a transport by setting MarkCachedResponses to true +func MarkCachedResponsesOpt(markCachedResponses bool) TransportOpt { + return func(t *Transport) { + t.MarkCachedResponses = markCachedResponses + } +} + +// Transport is an implementation of http.RoundTripper that will return values from a cache +// where possible (avoiding a network request) and will additionally add validators (etag/if-modified-since) +// to repeated requests allowing servers to return 304 / Not Modified +type Transport struct { + // The RoundTripper interface actually used to make requests + // If nil, http.DefaultTransport is used + Transport http.RoundTripper + + RespCache *RespCache + + // MarkCachedResponses, if true, indicates that responses returned from the + // cache will be given an extra header, X-From-Cache + MarkCachedResponses bool +} + +// NewTransport returns a new Transport with the provided Cache and options. If +// KeyFunc is not specified in opts then DefaultKeyFunc is used. +func NewTransport(rc *RespCache, opts ...TransportOpt) *Transport { + t := &Transport{ + RespCache: rc, + MarkCachedResponses: true, + } + for _, opt := range opts { + opt(t) + } + return t +} + +// Client returns an *http.Client that caches responses. 
func (t *Transport) Client() *http.Client {
	return &http.Client{Transport: t}
}

// varyMatches will return false unless all the cached values for the
// headers listed in Vary match the new request.
func varyMatches(cachedResp *http.Response, req *http.Request) bool {
	for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") {
		header = http.CanonicalHeaderKey(header)
		// The request value must equal the value recorded at cache time
		// under the synthetic "X-Varied-<Header>" key.
		if header != "" && req.Header.Get(header) != cachedResp.Header.Get("X-Varied-"+header) {
			return false
		}
	}
	return true
}

// IsCached returns true if there is a cache entry for req. This does not
// guarantee that the cache entry is fresh.
func (t *Transport) IsCached(req *http.Request) bool {
	return t.RespCache.Exists(req)
}

// IsFresh returns true if there is a fresh cache entry for req.
func (t *Transport) IsFresh(req *http.Request) bool {
	ctx := req.Context()
	log := lg.FromContext(ctx)

	// Only GET/HEAD without a Range header are cacheable at all.
	if !isCacheable(req) {
		return false
	}

	if !t.RespCache.Exists(req) {
		return false
	}

	// Read only the on-disk header file; the body is not needed to
	// determine freshness.
	fpHeader, _ := t.RespCache.Paths(req)
	f, err := os.Open(fpHeader)
	if err != nil {
		log.Error("Failed to open cached response header file", lga.File, fpHeader, lga.Err, err)
		return false
	}

	defer lg.WarnIfCloseError(log, "Close cached response header", f)

	// ReadResponse is a package-local helper (defined elsewhere in this
	// package) that reconstructs the cached response from the header file.
	cachedResp, err := ReadResponse(bufio.NewReader(f), nil, true)
	if err != nil {
		log.Error("Failed to read cached response", lga.Err, err)
		return false
	}

	freshness := getFreshness(cachedResp.Header, req.Header)
	return freshness == fresh
}

// isCacheable reports whether req is eligible for caching: only GET and HEAD
// requests without a Range header qualify.
func isCacheable(req *http.Request) bool {
	return (req.Method == http.MethodGet || req.Method == http.MethodHead) && req.Header.Get("range") == ""
}

// CallbackHandler bundles the callbacks used by Transport.Fetch to deliver
// results to the caller instead of returning an *http.Response.
type CallbackHandler struct {
	// HandleCached is invoked with the filepath of the cached response body
	// when the cache satisfies the request.
	HandleCached func(cachedFilepath string) error
	// HandleUncached is invoked when the body must be fetched from the
	// network; it returns a writer that receives a copy of the body, an
	// errFn used to report a mid-stream failure, and an error if no writer
	// can be provided.
	HandleUncached func() (wc io.WriteCloser, errFn func(error), err error)
	// HandleError is invoked when the fetch fails.
	HandleError func(err error)
}

// Fetch executes req, delivering the result via the callbacks in cb. If a
// fresh cached response exists, cb.HandleCached is invoked with the cached
// body's filepath. A stale cached entry is revalidated with
// etag/if-modified-since before use. Otherwise the request goes to the
// network and the body is streamed to the writer obtained from
// cb.HandleUncached; cb.HandleError is invoked on failure.
func (t *Transport) Fetch(req *http.Request, cb CallbackHandler) {
	ctx := req.Context()
	log := lg.FromContext(ctx)
	log.Info("Fetching download", lga.URL, req.URL.String())
	// NOTE(review): dead assignment left over from refactoring; log is
	// used below, so this line can be deleted.
	_ = log
	_, fpBody := t.RespCache.Paths(req)

	// Fast path: a fresh cache entry short-circuits everything else.
	// NOTE(review): HandleCached's error is deliberately discarded here.
	if t.IsFresh(req) {
		_ = cb.HandleCached(fpBody)
		return
	}

	var err error
	// NOTE(review): duplicates isCacheable(req); consider calling the helper.
	cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == ""
	var cachedResp *http.Response
	if cacheable {
		cachedResp, err = t.RespCache.Get(req.Context(), req)
	} else {
		// Need to invalidate an existing value
		if err = t.RespCache.Delete(req.Context()); err != nil {
			cb.HandleError(err)
			return
		}
	}

	transport := t.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}
	var resp *http.Response
	if cacheable && cachedResp != nil && err == nil {
		if t.MarkCachedResponses {
			cachedResp.Header.Set(XFromCache, "1")
		}

		if varyMatches(cachedResp, req) {
			// Can only use cached value if the new request doesn't Vary significantly
			freshness := getFreshness(cachedResp.Header, req.Header)
			if freshness == fresh {
				_ = cb.HandleCached(fpBody)
				return
			}

			if freshness == stale {
				var req2 *http.Request
				// Add validators if caller hasn't already done so
				etag := cachedResp.Header.Get("etag")
				if etag != "" && req.Header.Get("etag") == "" {
					req2 = cloneRequest(req)
					req2.Header.Set("if-none-match", etag)
				}
				lastModified := cachedResp.Header.Get("last-modified")
				if lastModified != "" && req.Header.Get("last-modified") == "" {
					if req2 == nil {
						req2 = cloneRequest(req)
					}
					req2.Header.Set("if-modified-since", lastModified)
				}
				if req2 != nil {
					req = req2
				}
			}
		}

		// FIXME: Use an http client here
		resp, err = transport.RoundTrip(req)
		if err == nil && req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified {
			// Replace the 304 response with the one from cache, but update with some new headers
			endToEndHeaders := getEndToEndHeaders(resp.Header)
			for _, header := range endToEndHeaders {
				cachedResp.Header[header] = resp.Header[header]
			}
			resp = cachedResp
		} else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) &&
			req.Method == http.MethodGet && canStaleOnError(cachedResp.Header, req.Header) {
			// In case of transport failure and stale-if-error activated, returns cached content
			// when available
			// NOTE(review): HandleCached's error is ignored on this path too.
			log.Warn("Returning cached response due to transport failure", lga.Err, err)
			cb.HandleCached(fpBody)
			return
		} else {
			if err != nil || resp.StatusCode != http.StatusOK {
				lg.WarnIfError(log, msgDeleteCache, t.RespCache.Delete(req.Context()))
			}
			if err != nil {
				cb.HandleError(err)
				return
			}
		}
	} else {
		reqCacheControl := parseCacheControl(req.Header)
		if _, ok := reqCacheControl["only-if-cached"]; ok {
			// Caller demanded cache-only and we have no entry: synthesize
			// a 504, per RFC 7234 only-if-cached semantics.
			resp = newGatewayTimeoutResponse(req)
		} else {
			resp, err = transport.RoundTrip(req)
			if err != nil {
				cb.HandleError(err)
				return
			}
		}
	}

	if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) {
		// Record the request values for each Vary-listed header under a
		// synthetic X-Varied-* key, so varyMatches can compare them later.
		for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") {
			varyKey = http.CanonicalHeaderKey(varyKey)
			fakeHeader := "X-Varied-" + varyKey
			reqValue := req.Header.Get(varyKey)
			if reqValue != "" {
				resp.Header.Set(fakeHeader, reqValue)
			}
		}
		// NOTE(review): only the default case is live; the GET case
		// (delayed caching via cachingReadCloser) is currently disabled,
		// so every method takes the default branch.
		switch req.Method {
		//case "GET":
		//	// Delay caching until EOF is reached.
		//	resp.Body = &cachingReadCloser{
		//		R: resp.Body,
		//		OnEOF: func(r io.Reader) {
		//			resp := *resp
		//			resp.Body = ioutil.NopCloser(r)
		//			if err := t.RespCache.Write(req.Context(), &resp, nil); err != nil {
		//				log.Error("failed to write download cache", lga.Err, err)
		//			}
		//		},
		//	}
		default:
			// RespCache.Write both persists the response and tees the
			// body into copyWrtr for the caller.
			copyWrtr, errFn, err := cb.HandleUncached()
			if err != nil {
				cb.HandleError(err)
				return
			}

			if err = t.RespCache.Write(req.Context(), resp, copyWrtr); err != nil {
				log.Error("failed to write download cache", lga.Err, err)
				errFn(err)
				cb.HandleError(err)
			}
			return
		}
	} else {
		lg.WarnIfError(log, "Delete resp cache", t.RespCache.Delete(req.Context()))
	}

	// It's not cacheable, so we need to write it to the copyWrtr
	copyWrtr, errFn, err := cb.HandleUncached()
	if err != nil {
		cb.HandleError(err)
		return
	}
	// contextio.NewReader aborts the copy if ctx is canceled.
	cr := contextio.NewReader(ctx, resp.Body)
	_, err = io.Copy(copyWrtr, cr)
	if err != nil {
		errFn(err)
		cb.HandleError(err)
		return
	}
	if err = copyWrtr.Close(); err != nil {
		cb.HandleError(err)
		return
	}

	return
}

// RoundTrip takes a Request and returns a Response
//
// If there is a fresh Response already in cache, then it will be returned without connecting to
// the server.
//
// If there is a stale Response, then any validators it contains will be set on the new request
// to give the server a chance to respond with NotModified. If this happens, then the cached Response
// will be returned.
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
	log := lg.FromContext(req.Context())

	// Only GET/HEAD without a Range header are candidates for caching.
	cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == ""
	var cachedResp *http.Response
	if cacheable {
		cachedResp, err = t.RespCache.Get(req.Context(), req)
	} else {
		// Need to invalidate an existing value
		if err = t.RespCache.Delete(req.Context()); err != nil {
			return nil, err
		}
	}

	transport := t.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}

	if cacheable && cachedResp != nil && err == nil {
		if t.MarkCachedResponses {
			cachedResp.Header.Set(XFromCache, "1")
		}

		if varyMatches(cachedResp, req) {
			// Can only use cached value if the new request doesn't Vary significantly
			freshness := getFreshness(cachedResp.Header, req.Header)
			if freshness == fresh {
				return cachedResp, nil
			}

			if freshness == stale {
				var req2 *http.Request
				// Add validators if caller hasn't already done so
				etag := cachedResp.Header.Get("etag")
				if etag != "" && req.Header.Get("etag") == "" {
					req2 = cloneRequest(req)
					req2.Header.Set("if-none-match", etag)
				}
				lastModified := cachedResp.Header.Get("last-modified")
				if lastModified != "" && req.Header.Get("last-modified") == "" {
					if req2 == nil {
						req2 = cloneRequest(req)
					}
					req2.Header.Set("if-modified-since", lastModified)
				}
				if req2 != nil {
					req = req2
				}
			}
		}

		resp, err = transport.RoundTrip(req)
		if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified {
			// Replace the 304 response with the one from cache, but update with some new headers
			endToEndHeaders := getEndToEndHeaders(resp.Header)
			for _, header := range endToEndHeaders {
				cachedResp.Header[header] = resp.Header[header]
			}
			resp = cachedResp
		} else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) &&
			req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) {
			// In case of transport failure and stale-if-error activated, returns cached content
			// when available
			return cachedResp, nil
		} else {
			// Transport error or non-200: drop the (now suspect) cache entry.
			if err != nil || resp.StatusCode != http.StatusOK {
				lg.WarnIfError(log, msgDeleteCache, t.RespCache.Delete(req.Context()))
			}
			if err != nil {
				return nil, err
			}
		}
	} else {
		reqCacheControl := parseCacheControl(req.Header)
		if _, ok := reqCacheControl["only-if-cached"]; ok {
			// only-if-cached with no cache entry: synthesize a 504 per
			// RFC 7234 rather than hitting the network.
			resp = newGatewayTimeoutResponse(req)
		} else {
			resp, err = transport.RoundTrip(req)
			if err != nil {
				return nil, err
			}
		}
	}

	if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) {
		// Record the request values of Vary-listed headers under synthetic
		// X-Varied-* keys so varyMatches can compare them on later hits.
		for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") {
			varyKey = http.CanonicalHeaderKey(varyKey)
			fakeHeader := "X-Varied-" + varyKey
			reqValue := req.Header.Get(varyKey)
			if reqValue != "" {
				resp.Header.Set(fakeHeader, reqValue)
			}
		}
		switch req.Method {
		case "GET":
			// Delay caching until EOF is reached.
			// NOTE(review): cachingReadCloser buffers the entire body in
			// memory before OnEOF fires — a concern for large downloads.
			resp.Body = &cachingReadCloser{
				R: resp.Body,
				OnEOF: func(r io.Reader) {
					// Shallow copy so the caller's resp.Body is untouched.
					resp := *resp
					// NOTE(review): ioutil.NopCloser is deprecated since
					// Go 1.16; io.NopCloser is the replacement.
					resp.Body = ioutil.NopCloser(r)
					if err := t.RespCache.Write(req.Context(), &resp, nil); err != nil {
						log.Error("failed to write download cache", lga.Err, err)
					}
				},
			}
		default:
			if err = t.RespCache.Write(req.Context(), resp, nil); err != nil {
				log.Error("failed to write download cache", lga.Err, err)
			}
		}
	} else {
		lg.WarnIfError(log, "Delete resp cache", t.RespCache.Delete(req.Context()))
	}
	return resp, nil
}

// ErrNoDateHeader indicates that the HTTP headers contained no Date header.
var ErrNoDateHeader = errors.New("no Date header")

// Date parses and returns the value of the Date header.
+func Date(respHeaders http.Header) (date time.Time, err error) { + dateHeader := respHeaders.Get("date") + if dateHeader == "" { + err = ErrNoDateHeader + return + } + + return time.Parse(time.RFC1123, dateHeader) +} + +type realClock struct{} + +func (c *realClock) since(d time.Time) time.Duration { + return time.Since(d) +} + +type timer interface { + since(d time.Time) time.Duration +} + +var clock timer = &realClock{} + +// getFreshness will return one of fresh/stale/transparent based on the cache-control +// values of the request and the response +// +// fresh indicates the response can be returned +// stale indicates that the response needs validating before it is returned +// transparent indicates the response should not be used to fulfil the request +// +// Because this is only a private cache, 'public' and 'private' in cache-control aren't +// signficant. Similarly, smax-age isn't used. +func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) { + respCacheControl := parseCacheControl(respHeaders) + reqCacheControl := parseCacheControl(reqHeaders) + if _, ok := reqCacheControl["no-cache"]; ok { + return transparent + } + if _, ok := respCacheControl["no-cache"]; ok { + return stale + } + if _, ok := reqCacheControl["only-if-cached"]; ok { + return fresh + } + + date, err := Date(respHeaders) + if err != nil { + return stale + } + currentAge := clock.since(date) + + var lifetime time.Duration + var zeroDuration time.Duration + + // If a response includes both an Expires header and a max-age directive, + // the max-age directive overrides the Expires header, even if the Expires header is more restrictive. 
+ if maxAge, ok := respCacheControl["max-age"]; ok { + lifetime, err = time.ParseDuration(maxAge + "s") + if err != nil { + lifetime = zeroDuration + } + } else { + expiresHeader := respHeaders.Get("Expires") + if expiresHeader != "" { + expires, err := time.Parse(time.RFC1123, expiresHeader) + if err != nil { + lifetime = zeroDuration + } else { + lifetime = expires.Sub(date) + } + } + } + + if maxAge, ok := reqCacheControl["max-age"]; ok { + // the client is willing to accept a response whose age is no greater than the specified time in seconds + lifetime, err = time.ParseDuration(maxAge + "s") + if err != nil { + lifetime = zeroDuration + } + } + if minfresh, ok := reqCacheControl["min-fresh"]; ok { + // the client wants a response that will still be fresh for at least the specified number of seconds. + minfreshDuration, err := time.ParseDuration(minfresh + "s") + if err == nil { + currentAge = time.Duration(currentAge + minfreshDuration) + } + } + + if maxstale, ok := reqCacheControl["max-stale"]; ok { + // Indicates that the client is willing to accept a response that has exceeded its expiration time. + // If max-stale is assigned a value, then the client is willing to accept a response that has exceeded + // its expiration time by no more than the specified number of seconds. + // If no value is assigned to max-stale, then the client is willing to accept a stale response of any age. + // + // Responses served only because of a max-stale value are supposed to have a Warning header added to them, + // but that seems like a hassle, and is it actually useful? If so, then there needs to be a different + // return-value available here. 
+ if maxstale == "" { + return fresh + } + maxstaleDuration, err := time.ParseDuration(maxstale + "s") + if err == nil { + currentAge = time.Duration(currentAge - maxstaleDuration) + } + } + + if lifetime > currentAge { + return fresh + } + + return stale +} + +// Returns true if either the request or the response includes the stale-if-error +// cache control extension: https://tools.ietf.org/html/rfc5861 +func canStaleOnError(respHeaders, reqHeaders http.Header) bool { + respCacheControl := parseCacheControl(respHeaders) + reqCacheControl := parseCacheControl(reqHeaders) + + var err error + lifetime := time.Duration(-1) + + if staleMaxAge, ok := respCacheControl["stale-if-error"]; ok { + if staleMaxAge != "" { + lifetime, err = time.ParseDuration(staleMaxAge + "s") + if err != nil { + return false + } + } else { + return true + } + } + if staleMaxAge, ok := reqCacheControl["stale-if-error"]; ok { + if staleMaxAge != "" { + lifetime, err = time.ParseDuration(staleMaxAge + "s") + if err != nil { + return false + } + } else { + return true + } + } + + if lifetime >= 0 { + date, err := Date(respHeaders) + if err != nil { + return false + } + currentAge := clock.since(date) + if lifetime > currentAge { + return true + } + } + + return false +} + +func getEndToEndHeaders(respHeaders http.Header) []string { + // These headers are always hop-by-hop + hopByHopHeaders := map[string]struct{}{ + "Connection": {}, + "Keep-Alive": {}, + "Proxy-Authenticate": {}, + "Proxy-Authorization": {}, + "Te": {}, + "Trailers": {}, + "Transfer-Encoding": {}, + "Upgrade": {}, + } + + for _, extra := range strings.Split(respHeaders.Get("connection"), ",") { + // any header listed in connection, if present, is also considered hop-by-hop + if strings.Trim(extra, " ") != "" { + hopByHopHeaders[http.CanonicalHeaderKey(extra)] = struct{}{} + } + } + endToEndHeaders := []string{} + for respHeader := range respHeaders { + if _, ok := hopByHopHeaders[respHeader]; !ok { + endToEndHeaders = 
append(endToEndHeaders, respHeader) + } + } + return endToEndHeaders +} + +func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) { + if _, ok := respCacheControl["no-store"]; ok { + return false + } + if _, ok := reqCacheControl["no-store"]; ok { + return false + } + return true +} + +func newGatewayTimeoutResponse(req *http.Request) *http.Response { + var braw bytes.Buffer + braw.WriteString("HTTP/1.1 504 Gateway Timeout\r\n\r\n") + resp, err := http.ReadResponse(bufio.NewReader(&braw), req) + if err != nil { + panic(err) + } + return resp +} + +// cloneRequest returns a clone of the provided *http.Request. +// The clone is a shallow copy of the struct and its Header map. +// (This function copyright goauth2 authors: https://code.google.com/p/goauth2) +func cloneRequest(r *http.Request) *http.Request { + // shallow copy of the struct + r2 := new(http.Request) + *r2 = *r + if ctx := r.Context(); ctx != nil { + r2 = r2.WithContext(ctx) + } + // deep copy of the Header + r2.Header = make(http.Header) + for k, s := range r.Header { + r2.Header[k] = s + } + return r2 +} + +type cacheControl map[string]string + +func parseCacheControl(headers http.Header) cacheControl { + cc := cacheControl{} + ccHeader := headers.Get("Cache-Control") + for _, part := range strings.Split(ccHeader, ",") { + part = strings.Trim(part, " ") + if part == "" { + continue + } + if strings.ContainsRune(part, '=') { + keyval := strings.Split(part, "=") + cc[strings.Trim(keyval[0], " ")] = strings.Trim(keyval[1], ",") + } else { + cc[part] = "" + } + } + return cc +} + +// headerAllCommaSepValues returns all comma-separated values (each +// with whitespace trimmed) for header name in headers. According to +// Section 4.2 of the HTTP/1.1 spec +// (http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2), +// values from multiple occurrences of a header should be concatenated, if +// the header's value is a comma-separated list. 
+func headerAllCommaSepValues(headers http.Header, name string) []string { + var vals []string + for _, val := range headers[http.CanonicalHeaderKey(name)] { + fields := strings.Split(val, ",") + for i, f := range fields { + fields[i] = strings.TrimSpace(f) + } + vals = append(vals, fields...) + } + return vals +} + +// cachingReadCloser is a wrapper around ReadCloser R that calls OnEOF +// handler with a full copy of the content read from R when EOF is +// reached. +type cachingReadCloser struct { + // Underlying ReadCloser. + R io.ReadCloser + // OnEOF is called with a copy of the content of R when EOF is reached. + OnEOF func(io.Reader) + + buf bytes.Buffer // buf stores a copy of the content of R. +} + +// Read reads the next len(p) bytes from R or until R is drained. The +// return value n is the number of bytes read. If R has no data to +// return, err is io.EOF and OnEOF is called with a full copy of what +// has been read so far. +func (r *cachingReadCloser) Read(p []byte) (n int, err error) { + n, err = r.R.Read(p) + r.buf.Write(p[:n]) + if err == io.EOF { + r.OnEOF(bytes.NewReader(r.buf.Bytes())) + } + return n, err +} + +func (r *cachingReadCloser) Close() error { + return r.R.Close() +} diff --git a/libsq/core/ioz/httpcacheworking/httpcache_test.go b/libsq/core/ioz/httpcacheworking/httpcache_test.go new file mode 100644 index 000000000..51a245237 --- /dev/null +++ b/libsq/core/ioz/httpcacheworking/httpcache_test.go @@ -0,0 +1,1486 @@ +package httpcacheworking + +import ( + "bytes" + "errors" + "flag" + "github.com/neilotoole/sq/libsq/core/stringz" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strconv" + "testing" + "time" +) + +// newTestTransport returns a new Transport using the in-memory cache implementation +func newTestTransport(cacheDir string, opts ...TransportOpt) *Transport { + rc := NewRespCache(cacheDir) + t := NewTransport(rc, opts...) 
+ return t +} + +var s struct { + server *httptest.Server + client http.Client + transport *Transport + done chan struct{} // Closed to unlock infinite handlers. +} + +type fakeClock struct { + elapsed time.Duration +} + +func (c *fakeClock) since(t time.Time) time.Duration { + return c.elapsed +} + +func TestMain(m *testing.M) { + flag.Parse() + setup() + code := m.Run() + teardown() + os.Exit(code) +} + +func setup() { + tp := newTestTransport(filepath.Join(os.TempDir(), stringz.Uniq8())) + client := http.Client{Transport: tp} + s.transport = tp + s.client = client + s.done = make(chan struct{}) + + mux := http.NewServeMux() + s.server = httptest.NewServer(mux) + + mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + })) + + mux.HandleFunc("/method", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + _, _ = w.Write([]byte(r.Method)) + })) + + mux.HandleFunc("/range", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + lm := "Fri, 14 Dec 2010 01:01:50 GMT" + if r.Header.Get("if-modified-since") == lm { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("last-modified", lm) + if r.Header.Get("range") == "bytes=4-9" { + w.WriteHeader(http.StatusPartialContent) + _, _ = w.Write([]byte(" text ")) + return + } + _, _ = w.Write([]byte("Some text content")) + })) + + mux.HandleFunc("/nostore", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "no-store") + })) + + mux.HandleFunc("/etag", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + etag := "124567" + if r.Header.Get("if-none-match") == etag { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("etag", etag) + })) + + mux.HandleFunc("/lastmodified", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + lm := "Fri, 14 Dec 2010 01:01:50 GMT" + if 
r.Header.Get("if-modified-since") == lm { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("last-modified", lm) + })) + + mux.HandleFunc("/varyaccept", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("Vary", "Accept") + _, _ = w.Write([]byte("Some text content")) + })) + + mux.HandleFunc("/doublevary", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("Vary", "Accept, Accept-Language") + _, _ = w.Write([]byte("Some text content")) + })) + mux.HandleFunc("/2varyheaders", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "text/plain") + w.Header().Add("Vary", "Accept") + w.Header().Add("Vary", "Accept-Language") + _, _ = w.Write([]byte("Some text content")) + })) + mux.HandleFunc("/varyunused", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Cache-Control", "max-age=3600") + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("Vary", "X-Madeup-Header") + _, _ = w.Write([]byte("Some text content")) + })) + + mux.HandleFunc("/cachederror", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + etag := "abc" + if r.Header.Get("if-none-match") == etag { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("etag", etag) + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("Not found")) + })) + + updateFieldsCounter := 0 + mux.HandleFunc("/updatefields", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Counter", strconv.Itoa(updateFieldsCounter)) + w.Header().Set("Etag", `"e"`) + updateFieldsCounter++ + if r.Header.Get("if-none-match") != "" { + w.WriteHeader(http.StatusNotModified) + return 
+ } + _, _ = w.Write([]byte("Some text content")) + })) + + // Take 3 seconds to return 200 OK (for testing client timeouts). + mux.HandleFunc("/3seconds", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(3 * time.Second) + })) + + mux.HandleFunc("/infinite", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for { + select { + case <-s.done: + return + default: + _, _ = w.Write([]byte{0}) + } + } + })) +} + +func teardown() { + close(s.done) + s.server.Close() +} + +func resetTest(t testing.TB) { + s.transport.RespCache = NewRespCache(t.TempDir()) + //s.transport.RespCache.Delete() + clock = &realClock{} +} + +// TestCacheableMethod ensures that uncacheable method does not get stored +// in cache and get incorrectly used for a following cacheable method request. +func TestCacheableMethod(t *testing.T) { + resetTest(t) + { + req, err := http.NewRequest("POST", s.server.URL+"/method", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), "POST"; got != want { + t.Errorf("got %q, want %q", got, want) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) + } + } + { + req, err := http.NewRequest("GET", s.server.URL+"/method", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), "GET"; got != want { + t.Errorf("got wrong body %q, want %q", got, want) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) + } + if 
resp.Header.Get(XFromCache) != "" { + t.Errorf("XFromCache header isn't blank") + } + } +} + +func TestDontServeHeadResponseToGetRequest(t *testing.T) { + resetTest(t) + url := s.server.URL + "/" + req, err := http.NewRequest(http.MethodHead, url, nil) + if err != nil { + t.Fatal(err) + } + _, err = s.client.Do(req) + if err != nil { + t.Fatal(err) + } + req, err = http.NewRequest(http.MethodGet, url, nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.Header.Get(XFromCache) != "" { + t.Errorf("Cache should not match") + } +} + +func TestDontStorePartialRangeInCache(t *testing.T) { + resetTest(t) + { + req, err := http.NewRequest("GET", s.server.URL+"/range", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("range", "bytes=4-9") + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), " text "; got != want { + t.Errorf("got %q, want %q", got, want) + } + if resp.StatusCode != http.StatusPartialContent { + t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) + } + } + { + req, err := http.NewRequest("GET", s.server.URL+"/range", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), "Some text content"; got != want { + t.Errorf("got %q, want %q", got, want) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) + } + if resp.Header.Get(XFromCache) != "" { + t.Error("XFromCache header isn't blank") + } + } + { + req, err := http.NewRequest("GET", 
s.server.URL+"/range", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), "Some text content"; got != want { + t.Errorf("got %q, want %q", got, want) + } + if resp.StatusCode != http.StatusOK { + t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) + } + if resp.Header.Get(XFromCache) != "1" { + t.Errorf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } + { + req, err := http.NewRequest("GET", s.server.URL+"/range", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("range", "bytes=4-9") + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatal(err) + } + err = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if got, want := buf.String(), " text "; got != want { + t.Errorf("got %q, want %q", got, want) + } + if resp.StatusCode != http.StatusPartialContent { + t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) + } + } +} + +func TestCacheOnlyIfBodyRead(t *testing.T) { + resetTest(t) + { + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + // We do not read the body + resp.Body.Close() + } + { + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatalf("XFromCache header isn't blank") + } + } +} + +func TestOnlyReadBodyOnDemand(t *testing.T) { + resetTest(t) + + req, err := 
http.NewRequest("GET", s.server.URL+"/infinite", nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) // This shouldn't hang forever. + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 10) // Only partially read the body. + _, err = resp.Body.Read(buf) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() +} + +func TestGetOnlyIfCachedHit(t *testing.T) { + resetTest(t) + { + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("cache-control", "only-if-cached") + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) + } + } +} + +func TestGetOnlyIfCachedMiss(t *testing.T) { + resetTest(t) + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("cache-control", "only-if-cached") + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + if resp.StatusCode != http.StatusGatewayTimeout { + t.Fatalf("response status code isn't 504 GatewayTimeout: %v", resp.StatusCode) + } +} + +func TestGetNoStoreRequest(t *testing.T) { + resetTest(t) + req, err := http.NewRequest("GET", s.server.URL, nil) + if err != nil { + t.Fatal(err) + } + req.Header.Add("Cache-Control", "no-store") + { + resp, err := 
s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } +} + +func TestGetNoStoreResponse(t *testing.T) { + resetTest(t) + req, err := http.NewRequest("GET", s.server.URL+"/nostore", nil) + if err != nil { + t.Fatal(err) + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } +} + +func TestGetWithEtag(t *testing.T) { + resetTest(t) + req, err := http.NewRequest("GET", s.server.URL+"/etag", nil) + if err != nil { + t.Fatal(err) + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + // additional assertions to verify that 304 response is converted properly + if resp.StatusCode != http.StatusOK { + t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) + } + if _, ok := resp.Header["Connection"]; ok { + t.Fatalf("Connection header isn't absent") + } + } +} + +func TestGetWithLastModified(t *testing.T) { + resetTest(t) + req, err := http.NewRequest("GET", s.server.URL+"/lastmodified", nil) + if err != 
nil { + t.Fatal(err) + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } +} + +func TestGetWithVary(t *testing.T) { + resetTest(t) + req, err := http.NewRequest("GET", s.server.URL+"/varyaccept", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", "text/plain") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get("Vary") != "Accept" { + t.Fatalf(`Vary header isn't "Accept": %v`, resp.Header.Get("Vary")) + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } + req.Header.Set("Accept", "text/html") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + req.Header.Set("Accept", "") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } +} + +func TestGetWithDoubleVary(t *testing.T) { + resetTest(t) + req, err := http.NewRequest("GET", s.server.URL+"/doublevary", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", "text/plain") + req.Header.Set("Accept-Language", "da, en-gb;q=0.8, en;q=0.7") + { + resp, err := 
s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get("Vary") == "" { + t.Fatalf(`Vary header is blank`) + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } + req.Header.Set("Accept-Language", "") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + req.Header.Set("Accept-Language", "da") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } +} + +func TestGetWith2VaryHeaders(t *testing.T) { + resetTest(t) + // Tests that multiple Vary headers' comma-separated lists are + // merged. See https://github.com/gregjones/httpcache/issues/27. 
+ const ( + accept = "text/plain" + acceptLanguage = "da, en-gb;q=0.8, en;q=0.7" + ) + req, err := http.NewRequest("GET", s.server.URL+"/2varyheaders", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", accept) + req.Header.Set("Accept-Language", acceptLanguage) + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get("Vary") == "" { + t.Fatalf(`Vary header is blank`) + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } + req.Header.Set("Accept-Language", "") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + req.Header.Set("Accept-Language", "da") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + req.Header.Set("Accept-Language", acceptLanguage) + req.Header.Set("Accept", "") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + } + req.Header.Set("Accept", "image/png") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "" { + t.Fatal("XFromCache header isn't blank") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, 
resp.Header.Get(XFromCache)) + } + } +} + +func TestGetVaryUnused(t *testing.T) { + resetTest(t) + req, err := http.NewRequest("GET", s.server.URL+"/varyunused", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", "text/plain") + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get("Vary") == "" { + t.Fatalf(`Vary header is blank`) + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + } +} + +func TestUpdateFields(t *testing.T) { + resetTest(t) + req, err := http.NewRequest("GET", s.server.URL+"/updatefields", nil) + if err != nil { + t.Fatal(err) + } + var counter, counter2 string + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + counter = resp.Header.Get("x-counter") + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.Header.Get(XFromCache) != "1" { + t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) + } + counter2 = resp.Header.Get("x-counter") + } + if counter == counter2 { + t.Fatalf(`both "x-counter" values are equal: %v %v`, counter, counter2) + } +} + +// This tests the fix for https://github.com/gregjones/httpcache/issues/74. +// Previously, after validating a cached response, its StatusCode +// was incorrectly being replaced. 
+func TestCachedErrorsKeepStatus(t *testing.T) { + resetTest(t) + req, err := http.NewRequest("GET", s.server.URL+"/cachederror", nil) + if err != nil { + t.Fatal(err) + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + _, _ = io.Copy(ioutil.Discard, resp.Body) + } + { + resp, err := s.client.Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("Status code isn't 404: %d", resp.StatusCode) + } + } +} + +func TestParseCacheControl(t *testing.T) { + resetTest(t) + h := http.Header{} + for range parseCacheControl(h) { + t.Fatal("cacheControl should be empty") + } + + h.Set("cache-control", "no-cache") + { + cc := parseCacheControl(h) + if _, ok := cc["foo"]; ok { + t.Error(`Value "foo" shouldn't exist`) + } + noCache, ok := cc["no-cache"] + if !ok { + t.Fatalf(`"no-cache" value isn't set`) + } + if noCache != "" { + t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) + } + } + h.Set("cache-control", "no-cache, max-age=3600") + { + cc := parseCacheControl(h) + noCache, ok := cc["no-cache"] + if !ok { + t.Fatalf(`"no-cache" value isn't set`) + } + if noCache != "" { + t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) + } + if cc["max-age"] != "3600" { + t.Fatalf(`"max-age" value isn't "3600": %v`, cc["max-age"]) + } + } +} + +func TestNoCacheRequestExpiration(t *testing.T) { + resetTest(t) + respHeaders := http.Header{} + respHeaders.Set("Cache-Control", "max-age=7200") + + reqHeaders := http.Header{} + reqHeaders.Set("Cache-Control", "no-cache") + if getFreshness(respHeaders, reqHeaders) != transparent { + t.Fatal("freshness isn't transparent") + } +} + +func TestNoCacheResponseExpiration(t *testing.T) { + resetTest(t) + respHeaders := http.Header{} + respHeaders.Set("Cache-Control", "no-cache") + respHeaders.Set("Expires", "Wed, 19 Apr 3000 11:43:00 GMT") + + reqHeaders := http.Header{} + if getFreshness(respHeaders, reqHeaders) != 
stale { + t.Fatal("freshness isn't stale") + } +} + +func TestReqMustRevalidate(t *testing.T) { + resetTest(t) + // not paying attention to request setting max-stale means never returning stale + // responses, so always acting as if must-revalidate is set + respHeaders := http.Header{} + + reqHeaders := http.Header{} + reqHeaders.Set("Cache-Control", "must-revalidate") + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestRespMustRevalidate(t *testing.T) { + resetTest(t) + respHeaders := http.Header{} + respHeaders.Set("Cache-Control", "must-revalidate") + + reqHeaders := http.Header{} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestFreshExpiration(t *testing.T) { + resetTest(t) + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) + + reqHeaders := http.Header{} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + clock = &fakeClock{elapsed: 3 * time.Second} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestMaxAge(t *testing.T) { + resetTest(t) + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("cache-control", "max-age=2") + + reqHeaders := http.Header{} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + clock = &fakeClock{elapsed: 3 * time.Second} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestMaxAgeZero(t *testing.T) { + resetTest(t) + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("cache-control", "max-age=0") + + reqHeaders := http.Header{} + if 
getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestBothMaxAge(t *testing.T) { + resetTest(t) + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("cache-control", "max-age=2") + + reqHeaders := http.Header{} + reqHeaders.Set("cache-control", "max-age=0") + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestMinFreshWithExpires(t *testing.T) { + resetTest(t) + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) + + reqHeaders := http.Header{} + reqHeaders.Set("cache-control", "min-fresh=1") + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + reqHeaders = http.Header{} + reqHeaders.Set("cache-control", "min-fresh=2") + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func TestEmptyMaxStale(t *testing.T) { + resetTest(t) + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("cache-control", "max-age=20") + + reqHeaders := http.Header{} + reqHeaders.Set("cache-control", "max-stale") + clock = &fakeClock{elapsed: 10 * time.Second} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + clock = &fakeClock{elapsed: 60 * time.Second} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } +} + +func TestMaxStaleValue(t *testing.T) { + resetTest(t) + now := time.Now() + respHeaders := http.Header{} + respHeaders.Set("date", now.Format(time.RFC1123)) + respHeaders.Set("cache-control", "max-age=10") + + reqHeaders := http.Header{} + reqHeaders.Set("cache-control", "max-stale=20") + clock = &fakeClock{elapsed: 5 * time.Second} + if 
getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + clock = &fakeClock{elapsed: 15 * time.Second} + if getFreshness(respHeaders, reqHeaders) != fresh { + t.Fatal("freshness isn't fresh") + } + + clock = &fakeClock{elapsed: 30 * time.Second} + if getFreshness(respHeaders, reqHeaders) != stale { + t.Fatal("freshness isn't stale") + } +} + +func containsHeader(headers []string, header string) bool { + for _, v := range headers { + if http.CanonicalHeaderKey(v) == http.CanonicalHeaderKey(header) { + return true + } + } + return false +} + +func TestGetEndToEndHeaders(t *testing.T) { + resetTest(t) + var ( + headers http.Header + end2end []string + ) + + headers = http.Header{} + headers.Set("content-type", "text/html") + headers.Set("te", "deflate") + + end2end = getEndToEndHeaders(headers) + if !containsHeader(end2end, "content-type") { + t.Fatal(`doesn't contain "content-type" header`) + } + if containsHeader(end2end, "te") { + t.Fatal(`doesn't contain "te" header`) + } + + headers = http.Header{} + headers.Set("connection", "content-type") + headers.Set("content-type", "text/csv") + headers.Set("te", "deflate") + end2end = getEndToEndHeaders(headers) + if containsHeader(end2end, "connection") { + t.Fatal(`doesn't contain "connection" header`) + } + if containsHeader(end2end, "content-type") { + t.Fatal(`doesn't contain "content-type" header`) + } + if containsHeader(end2end, "te") { + t.Fatal(`doesn't contain "te" header`) + } + + headers = http.Header{} + end2end = getEndToEndHeaders(headers) + if len(end2end) != 0 { + t.Fatal(`non-zero end2end headers`) + } + + headers = http.Header{} + headers.Set("connection", "content-type") + end2end = getEndToEndHeaders(headers) + if len(end2end) != 0 { + t.Fatal(`non-zero end2end headers`) + } +} + +type transportMock struct { + response *http.Response + err error +} + +func (t transportMock) RoundTrip(req *http.Request) (resp *http.Response, err error) { + return t.response, t.err +} 
+ +func TestStaleIfErrorRequest(t *testing.T) { + resetTest(t) + now := time.Now() + tmock := transportMock{ + response: &http.Response{ + Status: http.StatusText(http.StatusOK), + StatusCode: http.StatusOK, + Header: http.Header{ + "Date": []string{now.Format(time.RFC1123)}, + "Cache-Control": []string{"no-cache"}, + }, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), + }, + err: nil, + } + tp := newTestTransport(t.TempDir()) + tp.Transport = &tmock + + // First time, response is cached on success + r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) + r.Header.Set("Cache-Control", "stale-if-error") + resp, err := tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + // On failure, response is returned from the cache + tmock.response = nil + tmock.err = errors.New("some error") + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } +} + +func TestStaleIfErrorRequestLifetime(t *testing.T) { + resetTest(t) + now := time.Now() + tmock := transportMock{ + response: &http.Response{ + Status: http.StatusText(http.StatusOK), + StatusCode: http.StatusOK, + Header: http.Header{ + "Date": []string{now.Format(time.RFC1123)}, + "Cache-Control": []string{"no-cache"}, + }, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), + }, + err: nil, + } + tp := newTestTransport(t.TempDir()) + tp.Transport = &tmock + + // First time, response is cached on success + r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) + r.Header.Set("Cache-Control", "stale-if-error=100") + resp, err := tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + // On failure, response is returned from the cache + tmock.response = nil + tmock.err = errors.New("some 
error") + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + + // Same for http errors + tmock.response = &http.Response{StatusCode: http.StatusInternalServerError} + tmock.err = nil + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + + // If failure last more than max stale, error is returned + clock = &fakeClock{elapsed: 200 * time.Second} + _, err = tp.RoundTrip(r) + if err != tmock.err { + t.Fatalf("got err %v, want %v", err, tmock.err) + } +} + +func TestStaleIfErrorResponse(t *testing.T) { + resetTest(t) + now := time.Now() + tmock := transportMock{ + response: &http.Response{ + Status: http.StatusText(http.StatusOK), + StatusCode: http.StatusOK, + Header: http.Header{ + "Date": []string{now.Format(time.RFC1123)}, + "Cache-Control": []string{"no-cache, stale-if-error"}, + }, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), + }, + err: nil, + } + tp := newTestTransport(t.TempDir()) + tp.Transport = &tmock + + // First time, response is cached on success + r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) + resp, err := tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + // On failure, response is returned from the cache + tmock.response = nil + tmock.err = errors.New("some error") + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } +} + +func TestStaleIfErrorResponseLifetime(t *testing.T) { + resetTest(t) + now := time.Now() + tmock := transportMock{ + response: &http.Response{ + Status: http.StatusText(http.StatusOK), + StatusCode: http.StatusOK, + Header: http.Header{ + "Date": []string{now.Format(time.RFC1123)}, + "Cache-Control": []string{"no-cache, stale-if-error=100"}, + }, + Body: 
ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), + }, + err: nil, + } + tp := newTestTransport(t.TempDir()) + tp.Transport = &tmock + + // First time, response is cached on success + r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) + resp, err := tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + // On failure, response is returned from the cache + tmock.response = nil + tmock.err = errors.New("some error") + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + + // If failure last more than max stale, error is returned + clock = &fakeClock{elapsed: 200 * time.Second} + _, err = tp.RoundTrip(r) + if err != tmock.err { + t.Fatalf("got err %v, want %v", err, tmock.err) + } +} + +// This tests the fix for https://github.com/gregjones/httpcache/issues/74. +// Previously, after a stale response was used after encountering an error, +// its StatusCode was being incorrectly replaced. 
+func TestStaleIfErrorKeepsStatus(t *testing.T) { + resetTest(t) + now := time.Now() + tmock := transportMock{ + response: &http.Response{ + Status: http.StatusText(http.StatusNotFound), + StatusCode: http.StatusNotFound, + Header: http.Header{ + "Date": []string{now.Format(time.RFC1123)}, + "Cache-Control": []string{"no-cache"}, + }, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), + }, + err: nil, + } + tp := newTestTransport(t.TempDir()) + tp.Transport = &tmock + + // First time, response is cached on success + r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) + r.Header.Set("Cache-Control", "stale-if-error") + resp, err := tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + // On failure, response is returned from the cache + tmock.response = nil + tmock.err = errors.New("some error") + resp, err = tp.RoundTrip(r) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("resp is nil") + } + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("Status wasn't 404: %d", resp.StatusCode) + } +} + +// Test that http.Client.Timeout is respected when cache transport is used. +// That is so as long as request cancellation is propagated correctly. +// In the past, that required CancelRequest to be implemented correctly, +// but modern http.Client uses Request.Cancel (or request context) instead, +// so we don't have to do anything. +func TestClientTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping timeout test in short mode") // Because it takes at least 3 seconds to run. 
+ } + resetTest(t) + + client := &http.Client{ + Transport: newTestTransport(t.TempDir()), + Timeout: time.Second, + } + started := time.Now() + resp, err := client.Get(s.server.URL + "/3seconds") + taken := time.Since(started) + if err == nil { + t.Error("got nil error, want timeout error") + } + if resp != nil { + t.Error("got non-nil resp, want nil resp") + } + if taken >= 2*time.Second { + t.Error("client.Do took 2+ seconds, want < 2 seconds") + } +} diff --git a/libsq/core/ioz/httpcacheworking/httpz.go b/libsq/core/ioz/httpcacheworking/httpz.go new file mode 100644 index 000000000..526d2565c --- /dev/null +++ b/libsq/core/ioz/httpcacheworking/httpz.go @@ -0,0 +1,88 @@ +package httpcacheworking + +import ( + "bufio" + "fmt" + "io" + "net/http" + "net/textproto" + "strconv" + "strings" +) + +// ReadResponse is a copy of http.ReadResponse, but with the option +// to read only the response header, and not the body. When only reading +// the header, note that resp.Body will be nil, and that the resp is +// generally not functional. +func ReadResponse(r *bufio.Reader, req *http.Request, headerOnly bool) (*http.Response, error) { + if !headerOnly { + return http.ReadResponse(r, req) + } + tp := textproto.NewReader(r) + resp := &http.Response{ + Request: req, + } + + // Parse the first line of the response. 
+ line, err := tp.ReadLine() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + proto, status, ok := strings.Cut(line, " ") + if !ok { + return nil, badStringError("malformed HTTP response", line) + } + resp.Proto = proto + resp.Status = strings.TrimLeft(status, " ") + + statusCode, _, _ := strings.Cut(resp.Status, " ") + if len(statusCode) != 3 { + return nil, badStringError("malformed HTTP status code", statusCode) + } + resp.StatusCode, err = strconv.Atoi(statusCode) + if err != nil || resp.StatusCode < 0 { + return nil, badStringError("malformed HTTP status code", statusCode) + } + if resp.ProtoMajor, resp.ProtoMinor, ok = http.ParseHTTPVersion(resp.Proto); !ok { + return nil, badStringError("malformed HTTP version", resp.Proto) + } + + // Parse the response headers. + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + resp.Header = http.Header(mimeHeader) + + fixPragmaCacheControl(resp.Header) + + //err = readTransfer(resp, r) + //if err != nil { + // return nil, err + //} + + return resp, nil +} + +// RFC 7234, section 5.4: Should treat +// +// Pragma: no-cache +// +// like +// +// Cache-Control: no-cache +func fixPragmaCacheControl(header http.Header) { + if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" { + if _, presentcc := header["Cache-Control"]; !presentcc { + header["Cache-Control"] = []string{"no-cache"} + } + } +} + +func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } diff --git a/libsq/core/ioz/httpcacheworking/respcache.go b/libsq/core/ioz/httpcacheworking/respcache.go new file mode 100644 index 000000000..e4c5c830a --- /dev/null +++ b/libsq/core/ioz/httpcacheworking/respcache.go @@ -0,0 +1,210 @@ +package httpcacheworking + +import ( + "bufio" + "bytes" + "context" + "github.com/neilotoole/sq/libsq/core/cleanup" + "github.com/neilotoole/sq/libsq/core/errz" + 
"github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "io" + "net/http" + "net/http/httputil" + "os" + "path/filepath" + "sync" +) + +// NewRespCache returns a new instance that stores responses in cacheDir. +// The caller should call RespCache.Close when finished with the cache. +func NewRespCache(cacheDir string) *RespCache { + c := &RespCache{ + Dir: cacheDir, + //Header: filepath.Join(cacheDir, "header"), + //Body: filepath.Join(cacheDir, "body"), + clnup: cleanup.New(), + } + return c +} + +// RespCache is a cache for a single http.Response. The response is +// stored in two files, one for the header and one for the body. +// The caller should call RespCache.Close when finished with the cache. +type RespCache struct { + mu sync.Mutex + clnup *cleanup.Cleanup + + Dir string +} + +// Paths returns the paths to the header and body files for req. +// It is not guaranteed that they exist. +func (rc *RespCache) Paths(req *http.Request) (header, body string) { + if req == nil || req.Method == http.MethodGet { + return filepath.Join(rc.Dir, "header"), filepath.Join(rc.Dir, "body") + } + + return filepath.Join(rc.Dir, req.Method+"_header"), + filepath.Join(rc.Dir, req.Method+"_body") +} + +// Exists returns true if the cache contains a response for req. +func (rc *RespCache) Exists(req *http.Request) bool { + rc.mu.Lock() + defer rc.mu.Unlock() + + fpHeader, _ := rc.Paths(req) + fi, err := os.Stat(fpHeader) + if err != nil { + return false + } + return fi.Size() > 0 +} + +// Get returns the cached http.Response for req if present, and nil +// otherwise. 
+func (rc *RespCache) Get(ctx context.Context, req *http.Request) (*http.Response, error) { + rc.mu.Lock() + defer rc.mu.Unlock() + + fpHeader, fpBody := rc.Paths(req) + + if !ioz.FileAccessible(fpHeader) { + return nil, nil + } + + headerBytes, err := os.ReadFile(fpHeader) + if err != nil { + return nil, err + } + + bodyFile, err := os.Open(fpBody) + if err != nil { + lg.FromContext(ctx).Error("failed to open cached response body", + lga.File, fpBody, lga.Err, err) + return nil, err + } + + // We need to explicitly close bodyFile at some later point. It won't be + // closed via a call to http.Response.Body.Close(). + rc.clnup.AddC(bodyFile) + // TODO: consider adding contextio.NewReader? + concatRdr := io.MultiReader(bytes.NewReader(headerBytes), bodyFile) + return http.ReadResponse(bufio.NewReader(concatRdr), req) +} + +// Close closes the cache, freeing any resources it holds. Note that +// it does not delete the cache: for that, see RespCache.Delete. +func (rc *RespCache) Close() error { + rc.mu.Lock() + defer rc.mu.Unlock() + + err := rc.clnup.Run() + rc.clnup = cleanup.New() + return err +} + +// Delete deletes the cache entries from disk. +func (rc *RespCache) Delete(ctx context.Context) error { + if rc == nil { + return nil + } + rc.mu.Lock() + defer rc.mu.Unlock() + + return rc.doDelete(ctx) +} + +func (rc *RespCache) doDelete(ctx context.Context) error { + cleanErr := rc.clnup.Run() + rc.clnup = cleanup.New() + deleteErr := errz.Wrap(os.RemoveAll(rc.Dir), "delete cache dir") + err := errz.Combine(cleanErr, deleteErr) + if err != nil { + lg.FromContext(ctx).Error("Delete cache dir", + lga.Dir, rc.Dir, lga.Err, err) + return err + } + + lg.FromContext(ctx).Info("Deleted cache dir", lga.Dir, rc.Dir) + return nil +} + +const msgDeleteCache = "Delete HTTP response cache" + +// Write writes resp to the cache. If copyWrtr is non-nil, the response +// bytes are copied to that destination also. 
+func (rc *RespCache) Write(ctx context.Context, resp *http.Response, copyWrtr io.WriteCloser) error { + rc.mu.Lock() + defer rc.mu.Unlock() + + log := lg.FromContext(ctx) + log.Debug("huzzah in write") + + err := rc.doWrite(ctx, resp, copyWrtr) + if err != nil { + lg.WarnIfError(lg.FromContext(ctx), msgDeleteCache, rc.doDelete(ctx)) + } + return err +} + +func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response, copyWrtr io.WriteCloser) error { + log := lg.FromContext(ctx) + + if err := ioz.RequireDir(rc.Dir); err != nil { + return err + } + + fpHeader, fpBody := rc.Paths(resp.Request) + + headerBytes, err := httputil.DumpResponse(resp, false) + if err != nil { + return err + } + + if _, err = ioz.WriteToFile(ctx, fpHeader, bytes.NewReader(headerBytes)); err != nil { + return err + } + + cacheFile, err := os.OpenFile(fpBody, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) + if err != nil { + return err + } + + var cr io.Reader + if copyWrtr == nil { + cr = contextio.NewReader(ctx, resp.Body) + } else { + tr := io.TeeReader(resp.Body, copyWrtr) + cr = contextio.NewReader(ctx, tr) + } + + //if copyWrtr != nil { + // cr = io.TeeReader(cr, copyWrtr) + //} + var written int64 + written, err = io.Copy(cacheFile, cr) + if err != nil { + lg.WarnIfCloseError(log, "Close cache body file", cacheFile) + return err + } + if copyWrtr != nil { + lg.WarnIfCloseError(log, "Close copy writer", copyWrtr) + } + + if err = cacheFile.Close(); err != nil { + return err + } + + log.Info("Wrote HTTP response to cache", lga.File, fpBody, lga.Size, written) + cacheFile, err = os.Open(fpBody) + if err != nil { + return err + } + + resp.Body = cacheFile + return nil +} From 72898de3d8f9e71d33f4ad9937555f41a4aac090 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 12 Dec 2023 14:23:58 -0700 Subject: [PATCH 098/195] saving progress: httpcacheworking --- libsq/source/dl.go | 159 ++++++++++++++++++++++++++++++++++++++++ libsq/source/dl_test.go | 78 ++++++++++++++++++++ 2 files 
changed, 237 insertions(+) create mode 100644 libsq/source/dl.go create mode 100644 libsq/source/dl_test.go diff --git a/libsq/source/dl.go b/libsq/source/dl.go new file mode 100644 index 000000000..9eea6bcda --- /dev/null +++ b/libsq/source/dl.go @@ -0,0 +1,159 @@ +package source + +import ( + "context" + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/neilotoole/sq/libsq/core/ioz/httpcache" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgm" + "io" + "log/slog" + "net/http" + "sync" +) + +// newDownloader creates a new downloader using cacheDir for the given url. +func newDownloader2(cacheDir, userAgent, dlURL string) (*downloader2, error) { + //dv := diskv.New(diskv.Options{ + // BasePath: filepath.Join(cacheDir, "cache"), + // TempDir: filepath.Join(cacheDir, "working"), + // CacheSizeMax: 10000 * 1024 * 1024, // 10000MB + //}) + if err := ioz.RequireDir(cacheDir); err != nil { + return nil, err + } + + //dc := diskcache.NewWithDiskv(dv) + rc := httpcache.NewRespCache(cacheDir) + tp := httpcache.NewTransport(rc) + + //respCache := httpcache.NewRespCache(cacheDir) + //tp.RespCache = respCache + //tp.BodyFilepath = filepath.Join(cacheDir, "body.data") + + c := &http.Client{Transport: tp} + + return &downloader2{ + c: c, + //dc: dc, + //dv: dv, + cacheDir: cacheDir, + url: dlURL, + userAgent: userAgent, + tp: tp, + }, nil +} + +type downloader2 struct { + c *http.Client + mu sync.Mutex + userAgent string + cacheDir string + url string + tp *httpcache.Transport +} + +func (d *downloader2) log(log *slog.Logger) *slog.Logger { + return log.With(lga.URL, d.url, lga.Dir, d.cacheDir) +} + +// ClearCache clears the cache dir. 
+func (d *downloader2) ClearCache(ctx context.Context) error { + d.mu.Lock() + defer d.mu.Unlock() + + if err := d.tp.RespCache.Delete(ctx); err != nil { + //log.Error("Failed to delete cache dir", lga.Dir, d.cacheDir, lga.Err, err) + return errz.Wrapf(err, "failed to clear cache dir: %s", d.cacheDir) + } + + return ioz.RequireDir(d.cacheDir) +} + +func (d *downloader2) Download(ctx context.Context, dest io.Writer) (written int64, fp string, err error) { + d.mu.Lock() + defer d.mu.Unlock() + + log := d.log(lg.FromContext(ctx)) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, nil) + if d.userAgent != "" { + req.Header.Set("User-Agent", d.userAgent) + } + + isCached := d.tp.IsCached(req) + _ = isCached + + isFresh := d.tp.IsFresh(req) + _ = isFresh + + resp, err := d.c.Do(req) + if err != nil { + return written, "", errz.Wrapf(err, "download failed for: %s", d.url) + } + defer func() { + if resp != nil && resp.Body != nil { + lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) + } + }() + + written, err = io.Copy( + contextio.NewWriter(ctx, dest), + contextio.NewReader(ctx, resp.Body), + ) + + return written, "", err +} + +func (d *downloader2) Download2(ctx context.Context, dest io.Writer) (written int64, fp string, err error) { + d.mu.Lock() + defer d.mu.Unlock() + + log := d.log(lg.FromContext(ctx)) + _ = log + + var destWrtr io.WriteCloser + var ok bool + if destWrtr, ok = dest.(io.WriteCloser); !ok { + destWrtr = ioz.WriteCloser(dest) + } + + log.Debug("huzzah Download2") + req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, nil) + if d.userAgent != "" { + req.Header.Set("User-Agent", d.userAgent) + } + + isCached := d.tp.IsCached(req) + _ = isCached + + isFresh := d.tp.IsFresh(req) + _ = isFresh + + var gotFp string + var gotErr error + //buf := &bytes.Buffer{} + cb := httpcache.CallbackHandler{ + HandleCached: func(cachedFilepath string) error { + gotFp = cachedFilepath + return nil + }, + HandleUncached: func() 
(wc io.WriteCloser, errFn func(error), err error) { + return destWrtr, func(err error) { + gotErr = err + }, nil + }, + HandleError: func(err error) { + gotErr = err + }, + } + + d.tp.Fetch(req, cb) + _ = gotFp + _ = gotErr + + return written, "", err +} diff --git a/libsq/source/dl_test.go b/libsq/source/dl_test.go new file mode 100644 index 000000000..e2c8adb61 --- /dev/null +++ b/libsq/source/dl_test.go @@ -0,0 +1,78 @@ +package source + +import ( + "bytes" + "context" + "github.com/neilotoole/slogt" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/stretchr/testify/require" + "net/url" + "path" + "path/filepath" + "testing" +) + +func TestDownloader2_Download(t *testing.T) { + log := slogt.New(t) + ctx := lg.NewContext(context.Background(), log) + const dlURL = urlActorCSV + const wantContentLength = sizeActorCSV + u, err := url.Parse(dlURL) + require.NoError(t, err) + wantFilename := path.Base(u.Path) + require.Equal(t, "actor.csv", wantFilename) + + cacheDir, err := filepath.Abs(filepath.Join("testdata", "downloader", "cache-dir-2")) + require.NoError(t, err) + t.Logf("cacheDir: %s", cacheDir) + + log.Debug("huzzah") + dl, err := newDownloader2(cacheDir, "sq/dev", dlURL) + require.NoError(t, err) + //require.NoError(t, dl.ClearCache(ctx)) + + buf := &bytes.Buffer{} + written, cachedFp, err := dl.Download2(ctx, buf) + _ = written + _ = cachedFp + require.NoError(t, err) + //require.Equal(t, wantContentLength, written) + //require.Equal(t, wantContentLength, int64(buf.Len())) + + buf.Reset() + written, cachedFp, err = dl.Download2(ctx, buf) + require.NoError(t, err) + //require.Equal(t, wantContentLength, written) + //require.Equal(t, wantContentLength, int64(buf.Len())) +} + +func TestDownloader2_Download_Legacy(t *testing.T) { + ctx := lg.NewContext(context.Background(), slogt.New(t)) + const dlURL = urlActorCSV + const wantContentLength = sizeActorCSV + u, err := url.Parse(dlURL) + require.NoError(t, err) + wantFilename := path.Base(u.Path) + 
require.Equal(t, "actor.csv", wantFilename) + + cacheDir, err := filepath.Abs(filepath.Join("testdata", "downloader", "cache-dir-2")) + require.NoError(t, err) + t.Logf("cacheDir: %s", cacheDir) + + dl, err := newDownloader2(cacheDir, "sq/dev", dlURL) + require.NoError(t, err) + //require.NoError(t, dl.ClearCache(ctx)) + + buf := &bytes.Buffer{} + written, cachedFp, err := dl.Download2(ctx, buf) + _ = cachedFp + require.NoError(t, err) + require.Equal(t, wantContentLength, written) + require.Equal(t, wantContentLength, int64(buf.Len())) + + buf.Reset() + written, cachedFp, err = dl.Download(ctx, buf) + require.NoError(t, err) + require.Equal(t, wantContentLength, written) + require.Equal(t, wantContentLength, int64(buf.Len())) +} From 5dcb53101b07d8b9c770eefdf726b8ab915bd969 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 12 Dec 2023 17:09:54 -0700 Subject: [PATCH 099/195] wip: refactoring httpcache --- libsq/core/ioz/httpcache/download_test.go | 82 + libsq/core/ioz/httpcache/httpcache.go | 636 +--- libsq/core/ioz/httpcache/httpcache_test.go | 2953 ++++++++--------- libsq/core/ioz/httpcache/httpz.go | 329 +- libsq/core/ioz/httpcacheworking/LICENSE.txt | 7 - libsq/core/ioz/httpcacheworking/README.md | 42 - libsq/core/ioz/httpcacheworking/httpcache.go | 717 ---- .../ioz/httpcacheworking/httpcache_test.go | 1486 --------- libsq/core/ioz/httpcacheworking/httpz.go | 88 - libsq/core/ioz/httpcacheworking/respcache.go | 210 -- libsq/source/dl.go | 13 +- 11 files changed, 1994 insertions(+), 4569 deletions(-) create mode 100644 libsq/core/ioz/httpcache/download_test.go delete mode 100644 libsq/core/ioz/httpcacheworking/LICENSE.txt delete mode 100644 libsq/core/ioz/httpcacheworking/README.md delete mode 100644 libsq/core/ioz/httpcacheworking/httpcache.go delete mode 100644 libsq/core/ioz/httpcacheworking/httpcache_test.go delete mode 100644 libsq/core/ioz/httpcacheworking/httpz.go delete mode 100644 libsq/core/ioz/httpcacheworking/respcache.go diff --git 
a/libsq/core/ioz/httpcache/download_test.go b/libsq/core/ioz/httpcache/download_test.go new file mode 100644 index 000000000..c8584fe74 --- /dev/null +++ b/libsq/core/ioz/httpcache/download_test.go @@ -0,0 +1,82 @@ +package httpcache_test + +import ( + "bytes" + "context" + "github.com/neilotoole/slogt" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/httpcache" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/stretchr/testify/require" + "io" + "os" + "path/filepath" + "testing" +) + +const ( + urlPaymentLargeCSV = "https://sqio-public.s3.amazonaws.com/testdata/payment-large.gen.csv" + urlActorCSV = "https://sq.io/testdata/actor.csv" + sizeActorCSV = int64(7641) + sizeGzipActorCSV = int64(1968) +) + +func TestTransport_Fetch(t *testing.T) { + log := slogt.New(t) + ctx := lg.NewContext(context.Background(), log) + const dlURL = urlActorCSV + + cacheDir, err := filepath.Abs(filepath.Join("testdata", "downloader", "cache-dir-2")) + require.NoError(t, err) + t.Logf("cacheDir: %s", cacheDir) + + dl := httpcache.NewTransport(cacheDir, httpcache.OptUserAgent("sq/dev")) + require.NoError(t, dl.Delete(ctx)) + + var ( + destBuf = &bytes.Buffer{} + gotFp string + gotErr error + ) + reset := func() { + destBuf.Reset() + gotFp = "" + gotErr = nil + } + + h := httpcache.Handler{ + Cached: func(cachedFilepath string) error { + gotFp = cachedFilepath + return nil + }, + Uncached: func() (wc io.WriteCloser, errFn func(error), err error) { + return ioz.WriteCloser(destBuf), + func(err error) { + gotErr = err + }, + nil + }, + Error: func(err error) { + gotErr = err + }, + } + + //req, err := http.NewRequestWithContext(ctx, http.MethodGet, dlURL, nil) + ////if d.userAgent != "" { + //// req.Header.Set("User-Agent", d.userAgent) + ////} + dl.Fetch(ctx, dlURL, h) + require.NoError(t, gotErr) + require.Empty(t, gotFp) + require.Equal(t, sizeActorCSV, int64(destBuf.Len())) + + reset() + dl.Fetch(ctx, dlURL, h) + 
require.NoError(t, gotErr) + require.Equal(t, 0, destBuf.Len()) + require.NotEmpty(t, gotFp) + gotFileBytes, err := os.ReadFile(gotFp) + require.NoError(t, err) + require.Equal(t, sizeActorCSV, int64(len(gotFileBytes))) + +} diff --git a/libsq/core/ioz/httpcache/httpcache.go b/libsq/core/ioz/httpcache/httpcache.go index 41f63706d..432d8464b 100644 --- a/libsq/core/ioz/httpcache/httpcache.go +++ b/libsq/core/ioz/httpcache/httpcache.go @@ -10,17 +10,14 @@ package httpcache import ( "bufio" - "bytes" - "errors" + "context" + "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "io" - "io/ioutil" "net/http" "os" - "strings" - "time" ) const ( @@ -31,13 +28,35 @@ const ( XFromCache = "X-From-Cache" ) -// TransportOpt is a configuration option for creating a new Transport -type TransportOpt func(t *Transport) +// Opt is a configuration option for creating a new Transport. +type Opt func(t *Transport) -// MarkCachedResponsesOpt configures a transport by setting MarkCachedResponses to true -func MarkCachedResponsesOpt(markCachedResponses bool) TransportOpt { +// OptMarkCacheResponses configures a Transport by setting +// Transport.markCachedResponses to true. +func OptMarkCacheResponses(markCachedResponses bool) Opt { return func(t *Transport) { - t.MarkCachedResponses = markCachedResponses + t.markCachedResponses = markCachedResponses + } +} + +// OptInsecureSkipVerify configures a Transport to skip TLS verification. +func OptInsecureSkipVerify(insecureSkipVerify bool) Opt { + return func(t *Transport) { + t.InsecureSkipVerify = insecureSkipVerify + } +} + +// OptDisableCaching disables the cache. +func OptDisableCaching(disable bool) Opt { + return func(t *Transport) { + t.disableCaching = disable + } +} + +// OptUserAgent sets the User-Agent header on requests. 
+func OptUserAgent(userAgent string) Opt { + return func(t *Transport) { + t.userAgent = userAgent } } @@ -46,126 +65,92 @@ func MarkCachedResponsesOpt(markCachedResponses bool) TransportOpt { // to repeated requests allowing servers to return 304 / Not Modified type Transport struct { // The RoundTripper interface actually used to make requests - // If nil, http.DefaultTransport is used - Transport http.RoundTripper + // If nil, http.DefaultTransport is used. + transport http.RoundTripper - RespCache *RespCache + // respCache is the cache used to store responses. + respCache *RespCache - // MarkCachedResponses, if true, indicates that responses returned from the + // markCachedResponses, if true, indicates that responses returned from the // cache will be given an extra header, X-From-Cache - MarkCachedResponses bool + markCachedResponses bool + + InsecureSkipVerify bool + + userAgent string + + disableCaching bool } // NewTransport returns a new Transport with the provided Cache and options. If // KeyFunc is not specified in opts then DefaultKeyFunc is used. -func NewTransport(rc *RespCache, opts ...TransportOpt) *Transport { +func NewTransport(cacheDir string, opts ...Opt) *Transport { t := &Transport{ - RespCache: rc, - MarkCachedResponses: true, + markCachedResponses: true, + disableCaching: false, + InsecureSkipVerify: false, } for _, opt := range opts { opt(t) } - return t -} - -// Client returns an *http.Client that caches responses. 
-func (t *Transport) Client() *http.Client { - return &http.Client{Transport: t} -} -// varyMatches will return false unless all the cached values for the -// headers listed in Vary match the new request -func varyMatches(cachedResp *http.Response, req *http.Request) bool { - for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") { - header = http.CanonicalHeaderKey(header) - if header != "" && req.Header.Get(header) != cachedResp.Header.Get("X-Varied-"+header) { - return false - } + if !t.disableCaching { + t.respCache = NewRespCache(cacheDir) } - return true + return t } -// IsCached returns true if there is a cache entry for req. This does not -// guarantee that the cache entry is fresh. -func (t *Transport) IsCached(req *http.Request) bool { - return t.RespCache.Exists(req) +type Handler struct { + Cached func(cachedFilepath string) error + Uncached func() (wc io.WriteCloser, errFn func(error), err error) + Error func(err error) } -// IsFresh returns true if there is a fresh cache entry for req. 
-func (t *Transport) IsFresh(req *http.Request) bool { - ctx := req.Context() - log := lg.FromContext(ctx) - - if !isCacheable(req) { - return false - } - - if !t.RespCache.Exists(req) { - return false - } - - fpHeader, _ := t.RespCache.Paths(req) - f, err := os.Open(fpHeader) +func (t *Transport) Fetch(ctx context.Context, dlURL string, h Handler) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, dlURL, nil) if err != nil { - log.Error("Failed to open cached response header file", lga.File, fpHeader, lga.Err, err) - return false + h.Error(err) + return } - - defer lg.WarnIfCloseError(log, "Close cached response header", f) - - cachedResp, err := ReadResponse(bufio.NewReader(f), nil, true) - if err != nil { - log.Error("Failed to read cached response", lga.Err, err) - return false + if t.userAgent != "" { + req.Header.Set("User-Agent", t.userAgent) } - freshness := getFreshness(cachedResp.Header, req.Header) - return freshness == fresh + t.FetchWith(req, h) } -func isCacheable(req *http.Request) bool { - return (req.Method == http.MethodGet || req.Method == http.MethodHead) && req.Header.Get("range") == "" -} - -type CallbackHandler struct { - HandleCached func(cachedFilepath string) error - HandleUncached func() (wc io.WriteCloser, errFn func(error), err error) - HandleError func(err error) -} - -func (t *Transport) Fetch(req *http.Request, cb CallbackHandler) { +func (t *Transport) FetchWith(req *http.Request, cb Handler) { ctx := req.Context() log := lg.FromContext(ctx) log.Info("Fetching download", lga.URL, req.URL.String()) _ = log - _, fpBody := t.RespCache.Paths(req) + _, fpBody := t.respCache.Paths(req) if t.IsFresh(req) { - _ = cb.HandleCached(fpBody) + _ = cb.Cached(fpBody) return } var err error - cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" + cacheable := t.isCacheable(req) var cachedResp *http.Response if cacheable { - cachedResp, err = t.RespCache.Get(req.Context(), req) + cachedResp, err = 
t.respCache.Get(req.Context(), req) } else { // Need to invalidate an existing value - if err = t.RespCache.Delete(req.Context()); err != nil { - cb.HandleError(err) + if err = t.respCache.Delete(req.Context()); err != nil { + cb.Error(err) return } } - transport := t.Transport + transport := t.transport if transport == nil { transport = http.DefaultTransport } var resp *http.Response if cacheable && cachedResp != nil && err == nil { - if t.MarkCachedResponses { + if t.markCachedResponses { cachedResp.Header.Set(XFromCache, "1") } @@ -173,7 +158,7 @@ func (t *Transport) Fetch(req *http.Request, cb CallbackHandler) { // Can only use cached value if the new request doesn't Vary significantly freshness := getFreshness(cachedResp.Header, req.Header) if freshness == fresh { - _ = cb.HandleCached(fpBody) + _ = cb.Cached(fpBody) return } @@ -212,14 +197,14 @@ func (t *Transport) Fetch(req *http.Request, cb CallbackHandler) { // In case of transport failure and stale-if-error activated, returns cached content // when available log.Warn("Returning cached response due to transport failure", lga.Err, err) - cb.HandleCached(fpBody) + cb.Cached(fpBody) return } else { if err != nil || resp.StatusCode != http.StatusOK { - lg.WarnIfError(log, msgDeleteCache, t.RespCache.Delete(req.Context())) + lg.WarnIfError(log, msgDeleteCache, t.respCache.Delete(req.Context())) } if err != nil { - cb.HandleError(err) + cb.Error(err) return } } @@ -230,7 +215,7 @@ func (t *Transport) Fetch(req *http.Request, cb CallbackHandler) { } else { resp, err = transport.RoundTrip(req) if err != nil { - cb.HandleError(err) + cb.Error(err) return } } @@ -245,473 +230,100 @@ func (t *Transport) Fetch(req *http.Request, cb CallbackHandler) { resp.Header.Set(fakeHeader, reqValue) } } - switch req.Method { - //case "GET": - // // Delay caching until EOF is reached. 
- // resp.Body = &cachingReadCloser{ - // R: resp.Body, - // OnEOF: func(r io.Reader) { - // resp := *resp - // resp.Body = ioutil.NopCloser(r) - // if err := t.RespCache.Write(req.Context(), &resp, nil); err != nil { - // log.Error("failed to write download cache", lga.Err, err) - // } - // }, - // } - default: - copyWrtr, errFn, err := cb.HandleUncached() - if err != nil { - cb.HandleError(err) - return - } - if err = t.RespCache.Write(req.Context(), resp, copyWrtr); err != nil { - log.Error("failed to write download cache", lga.Err, err) - errFn(err) - cb.HandleError(err) - } + copyWrtr, errFn, err := cb.Uncached() + if err != nil { + cb.Error(err) return } + + if err = t.respCache.Write(req.Context(), resp, copyWrtr); err != nil { + log.Error("failed to write download cache", lga.Err, err) + errFn(err) + cb.Error(err) + } + return } else { - lg.WarnIfError(log, "Delete resp cache", t.RespCache.Delete(req.Context())) + lg.WarnIfError(log, "Delete resp cache", t.respCache.Delete(req.Context())) } - // It's not cacheable, so we need to write it to the copyWrtr - copyWrtr, errFn, err := cb.HandleUncached() + // It's not cacheable, so we need to write it to the copyWrtr. + copyWrtr, errFn, err := cb.Uncached() if err != nil { - cb.HandleError(err) + cb.Error(err) return } cr := contextio.NewReader(ctx, resp.Body) _, err = io.Copy(copyWrtr, cr) if err != nil { errFn(err) - cb.HandleError(err) + cb.Error(err) return } if err = copyWrtr.Close(); err != nil { - cb.HandleError(err) + cb.Error(err) return } return } -// RoundTrip takes a Request and returns a Response -// -// If there is a fresh Response already in cache, then it will be returned without connecting to -// the server. -// -// If there is a stale Response, then any validators it contains will be set on the new request -// to give the server a chance to respond with NotModified. If this happens, then the cached Response -// will be returned. 
-func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { - log := lg.FromContext(req.Context()) - - cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" - var cachedResp *http.Response - if cacheable { - cachedResp, err = t.RespCache.Get(req.Context(), req) - } else { - // Need to invalidate an existing value - if err = t.RespCache.Delete(req.Context()); err != nil { - return nil, err - } - } - - transport := t.Transport - if transport == nil { - transport = http.DefaultTransport - } - - if cacheable && cachedResp != nil && err == nil { - if t.MarkCachedResponses { - cachedResp.Header.Set(XFromCache, "1") - } - - if varyMatches(cachedResp, req) { - // Can only use cached value if the new request doesn't Vary significantly - freshness := getFreshness(cachedResp.Header, req.Header) - if freshness == fresh { - return cachedResp, nil - } - - if freshness == stale { - var req2 *http.Request - // Add validators if caller hasn't already done so - etag := cachedResp.Header.Get("etag") - if etag != "" && req.Header.Get("etag") == "" { - req2 = cloneRequest(req) - req2.Header.Set("if-none-match", etag) - } - lastModified := cachedResp.Header.Get("last-modified") - if lastModified != "" && req.Header.Get("last-modified") == "" { - if req2 == nil { - req2 = cloneRequest(req) - } - req2.Header.Set("if-modified-since", lastModified) - } - if req2 != nil { - req = req2 - } - } - } - - resp, err = transport.RoundTrip(req) - if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified { - // Replace the 304 response with the one from cache, but update with some new headers - endToEndHeaders := getEndToEndHeaders(resp.Header) - for _, header := range endToEndHeaders { - cachedResp.Header[header] = resp.Header[header] - } - resp = cachedResp - } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) && - req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) { - // In 
case of transport failure and stale-if-error activated, returns cached content - // when available - return cachedResp, nil - } else { - if err != nil || resp.StatusCode != http.StatusOK { - lg.WarnIfError(log, msgDeleteCache, t.RespCache.Delete(req.Context())) - } - if err != nil { - return nil, err - } - } - } else { - reqCacheControl := parseCacheControl(req.Header) - if _, ok := reqCacheControl["only-if-cached"]; ok { - resp = newGatewayTimeoutResponse(req) - } else { - resp, err = transport.RoundTrip(req) - if err != nil { - return nil, err - } - } - } - - if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) { - for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") { - varyKey = http.CanonicalHeaderKey(varyKey) - fakeHeader := "X-Varied-" + varyKey - reqValue := req.Header.Get(varyKey) - if reqValue != "" { - resp.Header.Set(fakeHeader, reqValue) - } - } - switch req.Method { - case "GET": - // Delay caching until EOF is reached. - resp.Body = &cachingReadCloser{ - R: resp.Body, - OnEOF: func(r io.Reader) { - resp := *resp - resp.Body = ioutil.NopCloser(r) - if err := t.RespCache.Write(req.Context(), &resp, nil); err != nil { - log.Error("failed to write download cache", lga.Err, err) - } - }, - } - default: - if err = t.RespCache.Write(req.Context(), resp, nil); err != nil { - log.Error("failed to write download cache", lga.Err, err) - } - } - } else { - lg.WarnIfError(log, "Delete resp cache", t.RespCache.Delete(req.Context())) - } - return resp, nil -} - -// ErrNoDateHeader indicates that the HTTP headers contained no Date header. -var ErrNoDateHeader = errors.New("no Date header") - -// Date parses and returns the value of the Date header. 
-func Date(respHeaders http.Header) (date time.Time, err error) { - dateHeader := respHeaders.Get("date") - if dateHeader == "" { - err = ErrNoDateHeader - return - } - - return time.Parse(time.RFC1123, dateHeader) -} - -type realClock struct{} - -func (c *realClock) since(d time.Time) time.Duration { - return time.Since(d) -} - -type timer interface { - since(d time.Time) time.Duration +func (t *Transport) getClient() *http.Client { + return ioz.NewHTTPClient(t.InsecureSkipVerify) } -var clock timer = &realClock{} - -// getFreshness will return one of fresh/stale/transparent based on the cache-control -// values of the request and the response -// -// fresh indicates the response can be returned -// stale indicates that the response needs validating before it is returned -// transparent indicates the response should not be used to fulfil the request -// -// Because this is only a private cache, 'public' and 'private' in cache-control aren't -// signficant. Similarly, smax-age isn't used. -func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) { - respCacheControl := parseCacheControl(respHeaders) - reqCacheControl := parseCacheControl(reqHeaders) - if _, ok := reqCacheControl["no-cache"]; ok { - return transparent - } - if _, ok := respCacheControl["no-cache"]; ok { - return stale - } - if _, ok := reqCacheControl["only-if-cached"]; ok { - return fresh - } - - date, err := Date(respHeaders) - if err != nil { - return stale - } - currentAge := clock.since(date) - - var lifetime time.Duration - var zeroDuration time.Duration - - // If a response includes both an Expires header and a max-age directive, - // the max-age directive overrides the Expires header, even if the Expires header is more restrictive. 
- if maxAge, ok := respCacheControl["max-age"]; ok { - lifetime, err = time.ParseDuration(maxAge + "s") - if err != nil { - lifetime = zeroDuration - } - } else { - expiresHeader := respHeaders.Get("Expires") - if expiresHeader != "" { - expires, err := time.Parse(time.RFC1123, expiresHeader) - if err != nil { - lifetime = zeroDuration - } else { - lifetime = expires.Sub(date) - } - } - } - - if maxAge, ok := reqCacheControl["max-age"]; ok { - // the client is willing to accept a response whose age is no greater than the specified time in seconds - lifetime, err = time.ParseDuration(maxAge + "s") - if err != nil { - lifetime = zeroDuration - } - } - if minfresh, ok := reqCacheControl["min-fresh"]; ok { - // the client wants a response that will still be fresh for at least the specified number of seconds. - minfreshDuration, err := time.ParseDuration(minfresh + "s") - if err == nil { - currentAge = time.Duration(currentAge + minfreshDuration) - } - } - - if maxstale, ok := reqCacheControl["max-stale"]; ok { - // Indicates that the client is willing to accept a response that has exceeded its expiration time. - // If max-stale is assigned a value, then the client is willing to accept a response that has exceeded - // its expiration time by no more than the specified number of seconds. - // If no value is assigned to max-stale, then the client is willing to accept a stale response of any age. - // - // Responses served only because of a max-stale value are supposed to have a Warning header added to them, - // but that seems like a hassle, and is it actually useful? If so, then there needs to be a different - // return-value available here. - if maxstale == "" { - return fresh - } - maxstaleDuration, err := time.ParseDuration(maxstale + "s") - if err == nil { - currentAge = time.Duration(currentAge - maxstaleDuration) - } - } - - if lifetime > currentAge { - return fresh +// Delete deletes the cache. 
+func (t *Transport) Delete(ctx context.Context) error { + if t.respCache != nil { + return t.respCache.Delete(ctx) } - - return stale + return nil } -// Returns true if either the request or the response includes the stale-if-error -// cache control extension: https://tools.ietf.org/html/rfc5861 -func canStaleOnError(respHeaders, reqHeaders http.Header) bool { - respCacheControl := parseCacheControl(respHeaders) - reqCacheControl := parseCacheControl(reqHeaders) - - var err error - lifetime := time.Duration(-1) - - if staleMaxAge, ok := respCacheControl["stale-if-error"]; ok { - if staleMaxAge != "" { - lifetime, err = time.ParseDuration(staleMaxAge + "s") - if err != nil { - return false - } - } else { - return true - } - } - if staleMaxAge, ok := reqCacheControl["stale-if-error"]; ok { - if staleMaxAge != "" { - lifetime, err = time.ParseDuration(staleMaxAge + "s") - if err != nil { - return false - } - } else { - return true - } - } - - if lifetime >= 0 { - date, err := Date(respHeaders) - if err != nil { - return false - } - currentAge := clock.since(date) - if lifetime > currentAge { - return true - } +// IsCached returns true if there is a cache entry for req. This does not +// guarantee that the cache entry is fresh. See also: [Transport.IsFresh]. 
+func (t *Transport) IsCached(req *http.Request) bool { + if t.disableCaching { + return false } - - return false + return t.respCache.Exists(req) } -func getEndToEndHeaders(respHeaders http.Header) []string { - // These headers are always hop-by-hop - hopByHopHeaders := map[string]struct{}{ - "Connection": {}, - "Keep-Alive": {}, - "Proxy-Authenticate": {}, - "Proxy-Authorization": {}, - "Te": {}, - "Trailers": {}, - "Transfer-Encoding": {}, - "Upgrade": {}, - } - - for _, extra := range strings.Split(respHeaders.Get("connection"), ",") { - // any header listed in connection, if present, is also considered hop-by-hop - if strings.Trim(extra, " ") != "" { - hopByHopHeaders[http.CanonicalHeaderKey(extra)] = struct{}{} - } - } - endToEndHeaders := []string{} - for respHeader := range respHeaders { - if _, ok := hopByHopHeaders[respHeader]; !ok { - endToEndHeaders = append(endToEndHeaders, respHeader) - } - } - return endToEndHeaders -} +// IsFresh returns true if there is a fresh cache entry for req. +func (t *Transport) IsFresh(req *http.Request) bool { + ctx := req.Context() + log := lg.FromContext(ctx) -func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) { - if _, ok := respCacheControl["no-store"]; ok { + if !t.isCacheable(req) { return false } - if _, ok := reqCacheControl["no-store"]; ok { + + if !t.respCache.Exists(req) { return false } - return true -} -func newGatewayTimeoutResponse(req *http.Request) *http.Response { - var braw bytes.Buffer - braw.WriteString("HTTP/1.1 504 Gateway Timeout\r\n\r\n") - resp, err := http.ReadResponse(bufio.NewReader(&braw), req) + fpHeader, _ := t.respCache.Paths(req) + f, err := os.Open(fpHeader) if err != nil { - panic(err) - } - return resp -} - -// cloneRequest returns a clone of the provided *http.Request. -// The clone is a shallow copy of the struct and its Header map. 
-// (This function copyright goauth2 authors: https://code.google.com/p/goauth2) -func cloneRequest(r *http.Request) *http.Request { - // shallow copy of the struct - r2 := new(http.Request) - *r2 = *r - if ctx := r.Context(); ctx != nil { - r2 = r2.WithContext(ctx) - } - // deep copy of the Header - r2.Header = make(http.Header) - for k, s := range r.Header { - r2.Header[k] = s + log.Error("Failed to open cached response header file", lga.File, fpHeader, lga.Err, err) + return false } - return r2 -} -type cacheControl map[string]string - -func parseCacheControl(headers http.Header) cacheControl { - cc := cacheControl{} - ccHeader := headers.Get("Cache-Control") - for _, part := range strings.Split(ccHeader, ",") { - part = strings.Trim(part, " ") - if part == "" { - continue - } - if strings.ContainsRune(part, '=') { - keyval := strings.Split(part, "=") - cc[strings.Trim(keyval[0], " ")] = strings.Trim(keyval[1], ",") - } else { - cc[part] = "" - } - } - return cc -} + defer lg.WarnIfCloseError(log, "Close cached response header", f) -// headerAllCommaSepValues returns all comma-separated values (each -// with whitespace trimmed) for header name in headers. According to -// Section 4.2 of the HTTP/1.1 spec -// (http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2), -// values from multiple occurrences of a header should be concatenated, if -// the header's value is a comma-separated list. -func headerAllCommaSepValues(headers http.Header, name string) []string { - var vals []string - for _, val := range headers[http.CanonicalHeaderKey(name)] { - fields := strings.Split(val, ",") - for i, f := range fields { - fields[i] = strings.TrimSpace(f) - } - vals = append(vals, fields...) 
+ cachedResp, err := readResponseHeader(bufio.NewReader(f), nil) + if err != nil { + log.Error("Failed to read cached response", lga.Err, err) + return false } - return vals -} -// cachingReadCloser is a wrapper around ReadCloser R that calls OnEOF -// handler with a full copy of the content read from R when EOF is -// reached. -type cachingReadCloser struct { - // Underlying ReadCloser. - R io.ReadCloser - // OnEOF is called with a copy of the content of R when EOF is reached. - OnEOF func(io.Reader) - - buf bytes.Buffer // buf stores a copy of the content of R. + freshness := getFreshness(cachedResp.Header, req.Header) + return freshness == fresh } -// Read reads the next len(p) bytes from R or until R is drained. The -// return value n is the number of bytes read. If R has no data to -// return, err is io.EOF and OnEOF is called with a full copy of what -// has been read so far. -func (r *cachingReadCloser) Read(p []byte) (n int, err error) { - n, err = r.R.Read(p) - r.buf.Write(p[:n]) - if err == io.EOF { - r.OnEOF(bytes.NewReader(r.buf.Bytes())) +func (t *Transport) isCacheable(req *http.Request) bool { + if t.disableCaching { + return false } - return n, err -} - -func (r *cachingReadCloser) Close() error { - return r.R.Close() + return (req.Method == http.MethodGet || req.Method == http.MethodHead) && req.Header.Get("range") == "" } diff --git a/libsq/core/ioz/httpcache/httpcache_test.go b/libsq/core/ioz/httpcache/httpcache_test.go index 48b9a73fc..da7a93250 100644 --- a/libsq/core/ioz/httpcache/httpcache_test.go +++ b/libsq/core/ioz/httpcache/httpcache_test.go @@ -1,1486 +1,1471 @@ package httpcache -import ( - "bytes" - "errors" - "flag" - "github.com/neilotoole/sq/libsq/core/stringz" - "io" - "io/ioutil" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "strconv" - "testing" - "time" -) - -// newTestTransport returns a new Transport using the in-memory cache implementation -func newTestTransport(cacheDir string, opts ...TransportOpt) 
*Transport { - rc := NewRespCache(cacheDir) - t := NewTransport(rc, opts...) - return t -} - -var s struct { - server *httptest.Server - client http.Client - transport *Transport - done chan struct{} // Closed to unlock infinite handlers. -} - -type fakeClock struct { - elapsed time.Duration -} - -func (c *fakeClock) since(t time.Time) time.Duration { - return c.elapsed -} - -func TestMain(m *testing.M) { - flag.Parse() - setup() - code := m.Run() - teardown() - os.Exit(code) -} - -func setup() { - tp := newTestTransport(filepath.Join(os.TempDir(), stringz.Uniq8())) - client := http.Client{Transport: tp} - s.transport = tp - s.client = client - s.done = make(chan struct{}) - - mux := http.NewServeMux() - s.server = httptest.NewServer(mux) - - mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - })) - - mux.HandleFunc("/method", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - _, _ = w.Write([]byte(r.Method)) - })) - - mux.HandleFunc("/range", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - lm := "Fri, 14 Dec 2010 01:01:50 GMT" - if r.Header.Get("if-modified-since") == lm { - w.WriteHeader(http.StatusNotModified) - return - } - w.Header().Set("last-modified", lm) - if r.Header.Get("range") == "bytes=4-9" { - w.WriteHeader(http.StatusPartialContent) - _, _ = w.Write([]byte(" text ")) - return - } - _, _ = w.Write([]byte("Some text content")) - })) - - mux.HandleFunc("/nostore", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "no-store") - })) - - mux.HandleFunc("/etag", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - etag := "124567" - if r.Header.Get("if-none-match") == etag { - w.WriteHeader(http.StatusNotModified) - return - } - w.Header().Set("etag", etag) - })) - - mux.HandleFunc("/lastmodified", http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { - lm := "Fri, 14 Dec 2010 01:01:50 GMT" - if r.Header.Get("if-modified-since") == lm { - w.WriteHeader(http.StatusNotModified) - return - } - w.Header().Set("last-modified", lm) - })) - - mux.HandleFunc("/varyaccept", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - w.Header().Set("Content-Type", "text/plain") - w.Header().Set("Vary", "Accept") - _, _ = w.Write([]byte("Some text content")) - })) - - mux.HandleFunc("/doublevary", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - w.Header().Set("Content-Type", "text/plain") - w.Header().Set("Vary", "Accept, Accept-Language") - _, _ = w.Write([]byte("Some text content")) - })) - mux.HandleFunc("/2varyheaders", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - w.Header().Set("Content-Type", "text/plain") - w.Header().Add("Vary", "Accept") - w.Header().Add("Vary", "Accept-Language") - _, _ = w.Write([]byte("Some text content")) - })) - mux.HandleFunc("/varyunused", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - w.Header().Set("Content-Type", "text/plain") - w.Header().Set("Vary", "X-Madeup-Header") - _, _ = w.Write([]byte("Some text content")) - })) - - mux.HandleFunc("/cachederror", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - etag := "abc" - if r.Header.Get("if-none-match") == etag { - w.WriteHeader(http.StatusNotModified) - return - } - w.Header().Set("etag", etag) - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte("Not found")) - })) - - updateFieldsCounter := 0 - mux.HandleFunc("/updatefields", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Counter", strconv.Itoa(updateFieldsCounter)) - w.Header().Set("Etag", `"e"`) - updateFieldsCounter++ - if 
r.Header.Get("if-none-match") != "" { - w.WriteHeader(http.StatusNotModified) - return - } - _, _ = w.Write([]byte("Some text content")) - })) - - // Take 3 seconds to return 200 OK (for testing client timeouts). - mux.HandleFunc("/3seconds", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(3 * time.Second) - })) - - mux.HandleFunc("/infinite", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - for { - select { - case <-s.done: - return - default: - _, _ = w.Write([]byte{0}) - } - } - })) -} - -func teardown() { - close(s.done) - s.server.Close() -} - -func resetTest(t testing.TB) { - s.transport.RespCache = NewRespCache(t.TempDir()) - //s.transport.RespCache.Delete() - clock = &realClock{} -} - -// TestCacheableMethod ensures that uncacheable method does not get stored -// in cache and get incorrectly used for a following cacheable method request. -func TestCacheableMethod(t *testing.T) { - resetTest(t) - { - req, err := http.NewRequest("POST", s.server.URL+"/method", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), "POST"; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) - } - } - { - req, err := http.NewRequest("GET", s.server.URL+"/method", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), "GET"; got != want { - t.Errorf("got wrong body %q, want %q", got, want) - } - if resp.StatusCode != http.StatusOK { - 
t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) - } - if resp.Header.Get(XFromCache) != "" { - t.Errorf("XFromCache header isn't blank") - } - } -} - -func TestDontServeHeadResponseToGetRequest(t *testing.T) { - resetTest(t) - url := s.server.URL + "/" - req, err := http.NewRequest(http.MethodHead, url, nil) - if err != nil { - t.Fatal(err) - } - _, err = s.client.Do(req) - if err != nil { - t.Fatal(err) - } - req, err = http.NewRequest(http.MethodGet, url, nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - if resp.Header.Get(XFromCache) != "" { - t.Errorf("Cache should not match") - } -} - -func TestDontStorePartialRangeInCache(t *testing.T) { - resetTest(t) - { - req, err := http.NewRequest("GET", s.server.URL+"/range", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("range", "bytes=4-9") - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), " text "; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusPartialContent { - t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) - } - } - { - req, err := http.NewRequest("GET", s.server.URL+"/range", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), "Some text content"; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) - } - if resp.Header.Get(XFromCache) != "" { - t.Error("XFromCache header isn't 
blank") - } - } - { - req, err := http.NewRequest("GET", s.server.URL+"/range", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), "Some text content"; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) - } - if resp.Header.Get(XFromCache) != "1" { - t.Errorf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } - { - req, err := http.NewRequest("GET", s.server.URL+"/range", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("range", "bytes=4-9") - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), " text "; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusPartialContent { - t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) - } - } -} - -func TestCacheOnlyIfBodyRead(t *testing.T) { - resetTest(t) - { - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - // We do not read the body - resp.Body.Close() - } - { - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatalf("XFromCache header isn't blank") - } - } -} - -func 
TestOnlyReadBodyOnDemand(t *testing.T) { - resetTest(t) - - req, err := http.NewRequest("GET", s.server.URL+"/infinite", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) // This shouldn't hang forever. - if err != nil { - t.Fatal(err) - } - buf := make([]byte, 10) // Only partially read the body. - _, err = resp.Body.Read(buf) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() -} - -func TestGetOnlyIfCachedHit(t *testing.T) { - resetTest(t) - { - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("cache-control", "only-if-cached") - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - if resp.StatusCode != http.StatusOK { - t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) - } - } -} - -func TestGetOnlyIfCachedMiss(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("cache-control", "only-if-cached") - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - if resp.StatusCode != http.StatusGatewayTimeout { - t.Fatalf("response status code isn't 504 GatewayTimeout: %v", resp.StatusCode) - } -} - -func TestGetNoStoreRequest(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) 
- } - req.Header.Add("Cache-Control", "no-store") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } -} - -func TestGetNoStoreResponse(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL+"/nostore", nil) - if err != nil { - t.Fatal(err) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } -} - -func TestGetWithEtag(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL+"/etag", nil) - if err != nil { - t.Fatal(err) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - // additional assertions to verify that 304 response is converted properly - if resp.StatusCode != http.StatusOK { - t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) - } - if _, ok := resp.Header["Connection"]; ok { - t.Fatalf("Connection header isn't absent") - } - } -} - -func TestGetWithLastModified(t *testing.T) { - resetTest(t) - req, err := 
http.NewRequest("GET", s.server.URL+"/lastmodified", nil) - if err != nil { - t.Fatal(err) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } -} - -func TestGetWithVary(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL+"/varyaccept", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Accept", "text/plain") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get("Vary") != "Accept" { - t.Fatalf(`Vary header isn't "Accept": %v`, resp.Header.Get("Vary")) - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } - req.Header.Set("Accept", "text/html") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - req.Header.Set("Accept", "") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } -} - -func TestGetWithDoubleVary(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL+"/doublevary", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Accept", "text/plain") - 
req.Header.Set("Accept-Language", "da, en-gb;q=0.8, en;q=0.7") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get("Vary") == "" { - t.Fatalf(`Vary header is blank`) - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } - req.Header.Set("Accept-Language", "") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - req.Header.Set("Accept-Language", "da") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } -} - -func TestGetWith2VaryHeaders(t *testing.T) { - resetTest(t) - // Tests that multiple Vary headers' comma-separated lists are - // merged. See https://github.com/gregjones/httpcache/issues/27. 
- const ( - accept = "text/plain" - acceptLanguage = "da, en-gb;q=0.8, en;q=0.7" - ) - req, err := http.NewRequest("GET", s.server.URL+"/2varyheaders", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Accept", accept) - req.Header.Set("Accept-Language", acceptLanguage) - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get("Vary") == "" { - t.Fatalf(`Vary header is blank`) - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } - req.Header.Set("Accept-Language", "") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - req.Header.Set("Accept-Language", "da") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - req.Header.Set("Accept-Language", acceptLanguage) - req.Header.Set("Accept", "") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - req.Header.Set("Accept", "image/png") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, 
resp.Header.Get(XFromCache)) - } - } -} - -func TestGetVaryUnused(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL+"/varyunused", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Accept", "text/plain") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get("Vary") == "" { - t.Fatalf(`Vary header is blank`) - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } -} - -func TestUpdateFields(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL+"/updatefields", nil) - if err != nil { - t.Fatal(err) - } - var counter, counter2 string - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - counter = resp.Header.Get("x-counter") - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - counter2 = resp.Header.Get("x-counter") - } - if counter == counter2 { - t.Fatalf(`both "x-counter" values are equal: %v %v`, counter, counter2) - } -} - -// This tests the fix for https://github.com/gregjones/httpcache/issues/74. -// Previously, after validating a cached response, its StatusCode -// was incorrectly being replaced. 
-func TestCachedErrorsKeepStatus(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL+"/cachederror", nil) - if err != nil { - t.Fatal(err) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - _, _ = io.Copy(ioutil.Discard, resp.Body) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusNotFound { - t.Fatalf("Status code isn't 404: %d", resp.StatusCode) - } - } -} - -func TestParseCacheControl(t *testing.T) { - resetTest(t) - h := http.Header{} - for range parseCacheControl(h) { - t.Fatal("cacheControl should be empty") - } - - h.Set("cache-control", "no-cache") - { - cc := parseCacheControl(h) - if _, ok := cc["foo"]; ok { - t.Error(`Value "foo" shouldn't exist`) - } - noCache, ok := cc["no-cache"] - if !ok { - t.Fatalf(`"no-cache" value isn't set`) - } - if noCache != "" { - t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) - } - } - h.Set("cache-control", "no-cache, max-age=3600") - { - cc := parseCacheControl(h) - noCache, ok := cc["no-cache"] - if !ok { - t.Fatalf(`"no-cache" value isn't set`) - } - if noCache != "" { - t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) - } - if cc["max-age"] != "3600" { - t.Fatalf(`"max-age" value isn't "3600": %v`, cc["max-age"]) - } - } -} - -func TestNoCacheRequestExpiration(t *testing.T) { - resetTest(t) - respHeaders := http.Header{} - respHeaders.Set("Cache-Control", "max-age=7200") - - reqHeaders := http.Header{} - reqHeaders.Set("Cache-Control", "no-cache") - if getFreshness(respHeaders, reqHeaders) != transparent { - t.Fatal("freshness isn't transparent") - } -} - -func TestNoCacheResponseExpiration(t *testing.T) { - resetTest(t) - respHeaders := http.Header{} - respHeaders.Set("Cache-Control", "no-cache") - respHeaders.Set("Expires", "Wed, 19 Apr 3000 11:43:00 GMT") - - reqHeaders := http.Header{} - if getFreshness(respHeaders, reqHeaders) != 
stale { - t.Fatal("freshness isn't stale") - } -} - -func TestReqMustRevalidate(t *testing.T) { - resetTest(t) - // not paying attention to request setting max-stale means never returning stale - // responses, so always acting as if must-revalidate is set - respHeaders := http.Header{} - - reqHeaders := http.Header{} - reqHeaders.Set("Cache-Control", "must-revalidate") - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestRespMustRevalidate(t *testing.T) { - resetTest(t) - respHeaders := http.Header{} - respHeaders.Set("Cache-Control", "must-revalidate") - - reqHeaders := http.Header{} - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestFreshExpiration(t *testing.T) { - resetTest(t) - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) - - reqHeaders := http.Header{} - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - clock = &fakeClock{elapsed: 3 * time.Second} - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestMaxAge(t *testing.T) { - resetTest(t) - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("cache-control", "max-age=2") - - reqHeaders := http.Header{} - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - clock = &fakeClock{elapsed: 3 * time.Second} - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestMaxAgeZero(t *testing.T) { - resetTest(t) - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("cache-control", "max-age=0") - - reqHeaders := http.Header{} - if 
getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestBothMaxAge(t *testing.T) { - resetTest(t) - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("cache-control", "max-age=2") - - reqHeaders := http.Header{} - reqHeaders.Set("cache-control", "max-age=0") - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestMinFreshWithExpires(t *testing.T) { - resetTest(t) - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) - - reqHeaders := http.Header{} - reqHeaders.Set("cache-control", "min-fresh=1") - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - reqHeaders = http.Header{} - reqHeaders.Set("cache-control", "min-fresh=2") - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestEmptyMaxStale(t *testing.T) { - resetTest(t) - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("cache-control", "max-age=20") - - reqHeaders := http.Header{} - reqHeaders.Set("cache-control", "max-stale") - clock = &fakeClock{elapsed: 10 * time.Second} - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - clock = &fakeClock{elapsed: 60 * time.Second} - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } -} - -func TestMaxStaleValue(t *testing.T) { - resetTest(t) - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("cache-control", "max-age=10") - - reqHeaders := http.Header{} - reqHeaders.Set("cache-control", "max-stale=20") - clock = &fakeClock{elapsed: 5 * time.Second} - if 
getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - clock = &fakeClock{elapsed: 15 * time.Second} - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - clock = &fakeClock{elapsed: 30 * time.Second} - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func containsHeader(headers []string, header string) bool { - for _, v := range headers { - if http.CanonicalHeaderKey(v) == http.CanonicalHeaderKey(header) { - return true - } - } - return false -} - -func TestGetEndToEndHeaders(t *testing.T) { - resetTest(t) - var ( - headers http.Header - end2end []string - ) - - headers = http.Header{} - headers.Set("content-type", "text/html") - headers.Set("te", "deflate") - - end2end = getEndToEndHeaders(headers) - if !containsHeader(end2end, "content-type") { - t.Fatal(`doesn't contain "content-type" header`) - } - if containsHeader(end2end, "te") { - t.Fatal(`doesn't contain "te" header`) - } - - headers = http.Header{} - headers.Set("connection", "content-type") - headers.Set("content-type", "text/csv") - headers.Set("te", "deflate") - end2end = getEndToEndHeaders(headers) - if containsHeader(end2end, "connection") { - t.Fatal(`doesn't contain "connection" header`) - } - if containsHeader(end2end, "content-type") { - t.Fatal(`doesn't contain "content-type" header`) - } - if containsHeader(end2end, "te") { - t.Fatal(`doesn't contain "te" header`) - } - - headers = http.Header{} - end2end = getEndToEndHeaders(headers) - if len(end2end) != 0 { - t.Fatal(`non-zero end2end headers`) - } - - headers = http.Header{} - headers.Set("connection", "content-type") - end2end = getEndToEndHeaders(headers) - if len(end2end) != 0 { - t.Fatal(`non-zero end2end headers`) - } -} - -type transportMock struct { - response *http.Response - err error -} - -func (t transportMock) RoundTrip(req *http.Request) (resp *http.Response, err error) { - return t.response, t.err -} 
- -func TestStaleIfErrorRequest(t *testing.T) { - resetTest(t) - now := time.Now() - tmock := transportMock{ - response: &http.Response{ - Status: http.StatusText(http.StatusOK), - StatusCode: http.StatusOK, - Header: http.Header{ - "Date": []string{now.Format(time.RFC1123)}, - "Cache-Control": []string{"no-cache"}, - }, - Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), - }, - err: nil, - } - tp := newTestTransport(t.TempDir()) - tp.Transport = &tmock - - // First time, response is cached on success - r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) - r.Header.Set("Cache-Control", "stale-if-error") - resp, err := tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - // On failure, response is returned from the cache - tmock.response = nil - tmock.err = errors.New("some error") - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } -} - -func TestStaleIfErrorRequestLifetime(t *testing.T) { - resetTest(t) - now := time.Now() - tmock := transportMock{ - response: &http.Response{ - Status: http.StatusText(http.StatusOK), - StatusCode: http.StatusOK, - Header: http.Header{ - "Date": []string{now.Format(time.RFC1123)}, - "Cache-Control": []string{"no-cache"}, - }, - Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), - }, - err: nil, - } - tp := newTestTransport(t.TempDir()) - tp.Transport = &tmock - - // First time, response is cached on success - r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) - r.Header.Set("Cache-Control", "stale-if-error=100") - resp, err := tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - // On failure, response is returned from the cache - tmock.response = nil - tmock.err = errors.New("some 
error") - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - - // Same for http errors - tmock.response = &http.Response{StatusCode: http.StatusInternalServerError} - tmock.err = nil - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - - // If failure last more than max stale, error is returned - clock = &fakeClock{elapsed: 200 * time.Second} - _, err = tp.RoundTrip(r) - if err != tmock.err { - t.Fatalf("got err %v, want %v", err, tmock.err) - } -} - -func TestStaleIfErrorResponse(t *testing.T) { - resetTest(t) - now := time.Now() - tmock := transportMock{ - response: &http.Response{ - Status: http.StatusText(http.StatusOK), - StatusCode: http.StatusOK, - Header: http.Header{ - "Date": []string{now.Format(time.RFC1123)}, - "Cache-Control": []string{"no-cache, stale-if-error"}, - }, - Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), - }, - err: nil, - } - tp := newTestTransport(t.TempDir()) - tp.Transport = &tmock - - // First time, response is cached on success - r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) - resp, err := tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - // On failure, response is returned from the cache - tmock.response = nil - tmock.err = errors.New("some error") - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } -} - -func TestStaleIfErrorResponseLifetime(t *testing.T) { - resetTest(t) - now := time.Now() - tmock := transportMock{ - response: &http.Response{ - Status: http.StatusText(http.StatusOK), - StatusCode: http.StatusOK, - Header: http.Header{ - "Date": []string{now.Format(time.RFC1123)}, - "Cache-Control": []string{"no-cache, stale-if-error=100"}, - }, - Body: 
ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), - }, - err: nil, - } - tp := newTestTransport(t.TempDir()) - tp.Transport = &tmock - - // First time, response is cached on success - r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) - resp, err := tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - // On failure, response is returned from the cache - tmock.response = nil - tmock.err = errors.New("some error") - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - - // If failure last more than max stale, error is returned - clock = &fakeClock{elapsed: 200 * time.Second} - _, err = tp.RoundTrip(r) - if err != tmock.err { - t.Fatalf("got err %v, want %v", err, tmock.err) - } -} - -// This tests the fix for https://github.com/gregjones/httpcache/issues/74. -// Previously, after a stale response was used after encountering an error, -// its StatusCode was being incorrectly replaced. 
-func TestStaleIfErrorKeepsStatus(t *testing.T) { - resetTest(t) - now := time.Now() - tmock := transportMock{ - response: &http.Response{ - Status: http.StatusText(http.StatusNotFound), - StatusCode: http.StatusNotFound, - Header: http.Header{ - "Date": []string{now.Format(time.RFC1123)}, - "Cache-Control": []string{"no-cache"}, - }, - Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), - }, - err: nil, - } - tp := newTestTransport(t.TempDir()) - tp.Transport = &tmock - - // First time, response is cached on success - r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) - r.Header.Set("Cache-Control", "stale-if-error") - resp, err := tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - // On failure, response is returned from the cache - tmock.response = nil - tmock.err = errors.New("some error") - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - if resp.StatusCode != http.StatusNotFound { - t.Fatalf("Status wasn't 404: %d", resp.StatusCode) - } -} - -// Test that http.Client.Timeout is respected when cache transport is used. -// That is so as long as request cancellation is propagated correctly. -// In the past, that required CancelRequest to be implemented correctly, -// but modern http.Client uses Request.Cancel (or request context) instead, -// so we don't have to do anything. -func TestClientTimeout(t *testing.T) { - if testing.Short() { - t.Skip("skipping timeout test in short mode") // Because it takes at least 3 seconds to run. 
- } - resetTest(t) - - client := &http.Client{ - Transport: newTestTransport(t.TempDir()), - Timeout: time.Second, - } - started := time.Now() - resp, err := client.Get(s.server.URL + "/3seconds") - taken := time.Since(started) - if err == nil { - t.Error("got nil error, want timeout error") - } - if resp != nil { - t.Error("got non-nil resp, want nil resp") - } - if taken >= 2*time.Second { - t.Error("client.Do took 2+ seconds, want < 2 seconds") - } -} +// +//// newTestTransport returns a new Transport using the in-memory cache implementation +//func newTestTransport(cacheDir string, opts ...Opt) *Transport { +// t := NewTransport(cacheDir, opts...) +// return t +//} +// +//var s struct { +// server *httptest.Server +// client http.Client +// transport *Transport +// done chan struct{} // Closed to unlock infinite handlers. +//} +// +//type fakeClock struct { +// elapsed time.Duration +//} +// +//func (c *fakeClock) since(t time.Time) time.Duration { +// return c.elapsed +//} +// +//func TestMain(m *testing.M) { +// flag.Parse() +// setup() +// code := m.Run() +// teardown() +// os.Exit(code) +//} +// +//func setup() { +// tp := newTestTransport(filepath.Join(os.TempDir(), stringz.Uniq8())) +// client := http.Client{Transport: tp} +// s.transport = tp +// s.client = client +// s.done = make(chan struct{}) +// +// mux := http.NewServeMux() +// s.server = httptest.NewServer(mux) +// +// mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// w.Header().Set("Cache-Control", "max-age=3600") +// })) +// +// mux.HandleFunc("/method", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// w.Header().Set("Cache-Control", "max-age=3600") +// _, _ = w.Write([]byte(r.Method)) +// })) +// +// mux.HandleFunc("/range", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// lm := "Fri, 14 Dec 2010 01:01:50 GMT" +// if r.Header.Get("if-modified-since") == lm { +// w.WriteHeader(http.StatusNotModified) +// return +// } +// 
w.Header().Set("last-modified", lm) +// if r.Header.Get("range") == "bytes=4-9" { +// w.WriteHeader(http.StatusPartialContent) +// _, _ = w.Write([]byte(" text ")) +// return +// } +// _, _ = w.Write([]byte("Some text content")) +// })) +// +// mux.HandleFunc("/nostore", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// w.Header().Set("Cache-Control", "no-store") +// })) +// +// mux.HandleFunc("/etag", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// etag := "124567" +// if r.Header.Get("if-none-match") == etag { +// w.WriteHeader(http.StatusNotModified) +// return +// } +// w.Header().Set("etag", etag) +// })) +// +// mux.HandleFunc("/lastmodified", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// lm := "Fri, 14 Dec 2010 01:01:50 GMT" +// if r.Header.Get("if-modified-since") == lm { +// w.WriteHeader(http.StatusNotModified) +// return +// } +// w.Header().Set("last-modified", lm) +// })) +// +// mux.HandleFunc("/varyaccept", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// w.Header().Set("Cache-Control", "max-age=3600") +// w.Header().Set("Content-Type", "text/plain") +// w.Header().Set("Vary", "Accept") +// _, _ = w.Write([]byte("Some text content")) +// })) +// +// mux.HandleFunc("/doublevary", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// w.Header().Set("Cache-Control", "max-age=3600") +// w.Header().Set("Content-Type", "text/plain") +// w.Header().Set("Vary", "Accept, Accept-Language") +// _, _ = w.Write([]byte("Some text content")) +// })) +// mux.HandleFunc("/2varyheaders", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// w.Header().Set("Cache-Control", "max-age=3600") +// w.Header().Set("Content-Type", "text/plain") +// w.Header().Add("Vary", "Accept") +// w.Header().Add("Vary", "Accept-Language") +// _, _ = w.Write([]byte("Some text content")) +// })) +// mux.HandleFunc("/varyunused", http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { +// w.Header().Set("Cache-Control", "max-age=3600") +// w.Header().Set("Content-Type", "text/plain") +// w.Header().Set("Vary", "X-Madeup-Header") +// _, _ = w.Write([]byte("Some text content")) +// })) +// +// mux.HandleFunc("/cachederror", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// etag := "abc" +// if r.Header.Get("if-none-match") == etag { +// w.WriteHeader(http.StatusNotModified) +// return +// } +// w.Header().Set("etag", etag) +// w.WriteHeader(http.StatusNotFound) +// _, _ = w.Write([]byte("Not found")) +// })) +// +// updateFieldsCounter := 0 +// mux.HandleFunc("/updatefields", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// w.Header().Set("X-Counter", strconv.Itoa(updateFieldsCounter)) +// w.Header().Set("Etag", `"e"`) +// updateFieldsCounter++ +// if r.Header.Get("if-none-match") != "" { +// w.WriteHeader(http.StatusNotModified) +// return +// } +// _, _ = w.Write([]byte("Some text content")) +// })) +// +// // Take 3 seconds to return 200 OK (for testing client timeouts). +// mux.HandleFunc("/3seconds", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// time.Sleep(3 * time.Second) +// })) +// +// mux.HandleFunc("/infinite", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// for { +// select { +// case <-s.done: +// return +// default: +// _, _ = w.Write([]byte{0}) +// } +// } +// })) +//} +// +//func teardown() { +// close(s.done) +// s.server.Close() +//} +// +//func resetTest(t testing.TB) { +// s.transport.RespCache = NewRespCache(t.TempDir()) +// //s.transport.RespCache.Delete() +// clock = &realClock{} +//} +// +//// TestCacheableMethod ensures that uncacheable method does not get stored +//// in cache and get incorrectly used for a following cacheable method request. 
+//func TestCacheableMethod(t *testing.T) { +// resetTest(t) +// { +// req, err := http.NewRequest("POST", s.server.URL+"/method", nil) +// if err != nil { +// t.Fatal(err) +// } +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// var buf bytes.Buffer +// _, err = io.Copy(&buf, resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// err = resp.Body.Close() +// if err != nil { +// t.Fatal(err) +// } +// if got, want := buf.String(), "POST"; got != want { +// t.Errorf("got %q, want %q", got, want) +// } +// if resp.StatusCode != http.StatusOK { +// t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) +// } +// } +// { +// req, err := http.NewRequest("GET", s.server.URL+"/method", nil) +// if err != nil { +// t.Fatal(err) +// } +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// var buf bytes.Buffer +// _, err = io.Copy(&buf, resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// err = resp.Body.Close() +// if err != nil { +// t.Fatal(err) +// } +// if got, want := buf.String(), "GET"; got != want { +// t.Errorf("got wrong body %q, want %q", got, want) +// } +// if resp.StatusCode != http.StatusOK { +// t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) +// } +// if resp.Header.Get(XFromCache) != "" { +// t.Errorf("XFromCache header isn't blank") +// } +// } +//} +// +//func TestDontServeHeadResponseToGetRequest(t *testing.T) { +// resetTest(t) +// url := s.server.URL + "/" +// req, err := http.NewRequest(http.MethodHead, url, nil) +// if err != nil { +// t.Fatal(err) +// } +// _, err = s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// req, err = http.NewRequest(http.MethodGet, url, nil) +// if err != nil { +// t.Fatal(err) +// } +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// if resp.Header.Get(XFromCache) != "" { +// t.Errorf("Cache should not match") +// } +//} +// +//func TestDontStorePartialRangeInCache(t *testing.T) { +// 
resetTest(t) +// { +// req, err := http.NewRequest("GET", s.server.URL+"/range", nil) +// if err != nil { +// t.Fatal(err) +// } +// req.Header.Set("range", "bytes=4-9") +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// var buf bytes.Buffer +// _, err = io.Copy(&buf, resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// err = resp.Body.Close() +// if err != nil { +// t.Fatal(err) +// } +// if got, want := buf.String(), " text "; got != want { +// t.Errorf("got %q, want %q", got, want) +// } +// if resp.StatusCode != http.StatusPartialContent { +// t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) +// } +// } +// { +// req, err := http.NewRequest("GET", s.server.URL+"/range", nil) +// if err != nil { +// t.Fatal(err) +// } +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// var buf bytes.Buffer +// _, err = io.Copy(&buf, resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// err = resp.Body.Close() +// if err != nil { +// t.Fatal(err) +// } +// if got, want := buf.String(), "Some text content"; got != want { +// t.Errorf("got %q, want %q", got, want) +// } +// if resp.StatusCode != http.StatusOK { +// t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) +// } +// if resp.Header.Get(XFromCache) != "" { +// t.Error("XFromCache header isn't blank") +// } +// } +// { +// req, err := http.NewRequest("GET", s.server.URL+"/range", nil) +// if err != nil { +// t.Fatal(err) +// } +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// var buf bytes.Buffer +// _, err = io.Copy(&buf, resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// err = resp.Body.Close() +// if err != nil { +// t.Fatal(err) +// } +// if got, want := buf.String(), "Some text content"; got != want { +// t.Errorf("got %q, want %q", got, want) +// } +// if resp.StatusCode != http.StatusOK { +// t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) +// } +// 
if resp.Header.Get(XFromCache) != "1" { +// t.Errorf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) +// } +// } +// { +// req, err := http.NewRequest("GET", s.server.URL+"/range", nil) +// if err != nil { +// t.Fatal(err) +// } +// req.Header.Set("range", "bytes=4-9") +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// var buf bytes.Buffer +// _, err = io.Copy(&buf, resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// err = resp.Body.Close() +// if err != nil { +// t.Fatal(err) +// } +// if got, want := buf.String(), " text "; got != want { +// t.Errorf("got %q, want %q", got, want) +// } +// if resp.StatusCode != http.StatusPartialContent { +// t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) +// } +// } +//} +// +//func TestCacheOnlyIfBodyRead(t *testing.T) { +// resetTest(t) +// { +// req, err := http.NewRequest("GET", s.server.URL, nil) +// if err != nil { +// t.Fatal(err) +// } +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// if resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// // We do not read the body +// resp.Body.Close() +// } +// { +// req, err := http.NewRequest("GET", s.server.URL, nil) +// if err != nil { +// t.Fatal(err) +// } +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "" { +// t.Fatalf("XFromCache header isn't blank") +// } +// } +//} +// +//func TestOnlyReadBodyOnDemand(t *testing.T) { +// resetTest(t) +// +// req, err := http.NewRequest("GET", s.server.URL+"/infinite", nil) +// if err != nil { +// t.Fatal(err) +// } +// resp, err := s.client.Do(req) // This shouldn't hang forever. +// if err != nil { +// t.Fatal(err) +// } +// buf := make([]byte, 10) // Only partially read the body. 
+// _, err = resp.Body.Read(buf) +// if err != nil { +// t.Fatal(err) +// } +// resp.Body.Close() +//} +// +//func TestGetOnlyIfCachedHit(t *testing.T) { +// resetTest(t) +// { +// req, err := http.NewRequest("GET", s.server.URL, nil) +// if err != nil { +// t.Fatal(err) +// } +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// _, err = ioutil.ReadAll(resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// } +// { +// req, err := http.NewRequest("GET", s.server.URL, nil) +// if err != nil { +// t.Fatal(err) +// } +// req.Header.Add("cache-control", "only-if-cached") +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "1" { +// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) +// } +// if resp.StatusCode != http.StatusOK { +// t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) +// } +// } +//} +// +//func TestGetOnlyIfCachedMiss(t *testing.T) { +// resetTest(t) +// req, err := http.NewRequest("GET", s.server.URL, nil) +// if err != nil { +// t.Fatal(err) +// } +// req.Header.Add("cache-control", "only-if-cached") +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// if resp.StatusCode != http.StatusGatewayTimeout { +// t.Fatalf("response status code isn't 504 GatewayTimeout: %v", resp.StatusCode) +// } +//} +// +//func TestGetNoStoreRequest(t *testing.T) { +// resetTest(t) +// req, err := http.NewRequest("GET", s.server.URL, nil) +// if err != nil { +// t.Fatal(err) +// } +// req.Header.Add("Cache-Control", "no-store") +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if 
resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// } +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// } +//} +// +//func TestGetNoStoreResponse(t *testing.T) { +// resetTest(t) +// req, err := http.NewRequest("GET", s.server.URL+"/nostore", nil) +// if err != nil { +// t.Fatal(err) +// } +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// } +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// } +//} +// +//func TestGetWithEtag(t *testing.T) { +// resetTest(t) +// req, err := http.NewRequest("GET", s.server.URL+"/etag", nil) +// if err != nil { +// t.Fatal(err) +// } +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// _, err = ioutil.ReadAll(resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// +// } +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "1" { +// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) +// } +// // additional assertions to verify that 304 response is converted properly +// if resp.StatusCode != http.StatusOK { +// t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) +// } +// if _, ok := resp.Header["Connection"]; ok { +// t.Fatalf("Connection header isn't absent") +// } +// } +//} +// +//func TestGetWithLastModified(t *testing.T) { +// resetTest(t) 
+// req, err := http.NewRequest("GET", s.server.URL+"/lastmodified", nil) +// if err != nil { +// t.Fatal(err) +// } +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// _, err = ioutil.ReadAll(resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// } +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "1" { +// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) +// } +// } +//} +// +//func TestGetWithVary(t *testing.T) { +// resetTest(t) +// req, err := http.NewRequest("GET", s.server.URL+"/varyaccept", nil) +// if err != nil { +// t.Fatal(err) +// } +// req.Header.Set("Accept", "text/plain") +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get("Vary") != "Accept" { +// t.Fatalf(`Vary header isn't "Accept": %v`, resp.Header.Get("Vary")) +// } +// _, err = ioutil.ReadAll(resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// } +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "1" { +// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) +// } +// } +// req.Header.Set("Accept", "text/html") +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// } +// req.Header.Set("Accept", "") +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// } +//} +// +//func TestGetWithDoubleVary(t *testing.T) { +// 
resetTest(t) +// req, err := http.NewRequest("GET", s.server.URL+"/doublevary", nil) +// if err != nil { +// t.Fatal(err) +// } +// req.Header.Set("Accept", "text/plain") +// req.Header.Set("Accept-Language", "da, en-gb;q=0.8, en;q=0.7") +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get("Vary") == "" { +// t.Fatalf(`Vary header is blank`) +// } +// _, err = ioutil.ReadAll(resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// } +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "1" { +// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) +// } +// } +// req.Header.Set("Accept-Language", "") +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// } +// req.Header.Set("Accept-Language", "da") +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// } +//} +// +//func TestGetWith2VaryHeaders(t *testing.T) { +// resetTest(t) +// // Tests that multiple Vary headers' comma-separated lists are +// // merged. See https://github.com/gregjones/httpcache/issues/27. 
+// const ( +// accept = "text/plain" +// acceptLanguage = "da, en-gb;q=0.8, en;q=0.7" +// ) +// req, err := http.NewRequest("GET", s.server.URL+"/2varyheaders", nil) +// if err != nil { +// t.Fatal(err) +// } +// req.Header.Set("Accept", accept) +// req.Header.Set("Accept-Language", acceptLanguage) +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get("Vary") == "" { +// t.Fatalf(`Vary header is blank`) +// } +// _, err = ioutil.ReadAll(resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// } +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "1" { +// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) +// } +// } +// req.Header.Set("Accept-Language", "") +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// } +// req.Header.Set("Accept-Language", "da") +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// } +// req.Header.Set("Accept-Language", acceptLanguage) +// req.Header.Set("Accept", "") +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// } +// req.Header.Set("Accept", "image/png") +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "" { +// t.Fatal("XFromCache header isn't blank") +// } +// _, err = ioutil.ReadAll(resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// } +// { +// resp, err := 
s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "1" { +// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) +// } +// } +//} +// +//func TestGetVaryUnused(t *testing.T) { +// resetTest(t) +// req, err := http.NewRequest("GET", s.server.URL+"/varyunused", nil) +// if err != nil { +// t.Fatal(err) +// } +// req.Header.Set("Accept", "text/plain") +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get("Vary") == "" { +// t.Fatalf(`Vary header is blank`) +// } +// _, err = ioutil.ReadAll(resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// } +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "1" { +// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) +// } +// } +//} +// +//func TestUpdateFields(t *testing.T) { +// resetTest(t) +// req, err := http.NewRequest("GET", s.server.URL+"/updatefields", nil) +// if err != nil { +// t.Fatal(err) +// } +// var counter, counter2 string +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// counter = resp.Header.Get("x-counter") +// _, err = ioutil.ReadAll(resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// } +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.Header.Get(XFromCache) != "1" { +// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) +// } +// counter2 = resp.Header.Get("x-counter") +// } +// if counter == counter2 { +// t.Fatalf(`both "x-counter" values are equal: %v %v`, counter, counter2) +// } +//} +// +//// This tests the fix for https://github.com/gregjones/httpcache/issues/74. 
+//// Previously, after validating a cached response, its StatusCode +//// was incorrectly being replaced. +//func TestCachedErrorsKeepStatus(t *testing.T) { +// resetTest(t) +// req, err := http.NewRequest("GET", s.server.URL+"/cachederror", nil) +// if err != nil { +// t.Fatal(err) +// } +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// _, _ = io.Copy(ioutil.Discard, resp.Body) +// } +// { +// resp, err := s.client.Do(req) +// if err != nil { +// t.Fatal(err) +// } +// defer resp.Body.Close() +// if resp.StatusCode != http.StatusNotFound { +// t.Fatalf("Status code isn't 404: %d", resp.StatusCode) +// } +// } +//} +// +//func TestParseCacheControl(t *testing.T) { +// resetTest(t) +// h := http.Header{} +// for range parseCacheControl(h) { +// t.Fatal("cacheControl should be empty") +// } +// +// h.Set("cache-control", "no-cache") +// { +// cc := parseCacheControl(h) +// if _, ok := cc["foo"]; ok { +// t.Error(`Value "foo" shouldn't exist`) +// } +// noCache, ok := cc["no-cache"] +// if !ok { +// t.Fatalf(`"no-cache" value isn't set`) +// } +// if noCache != "" { +// t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) +// } +// } +// h.Set("cache-control", "no-cache, max-age=3600") +// { +// cc := parseCacheControl(h) +// noCache, ok := cc["no-cache"] +// if !ok { +// t.Fatalf(`"no-cache" value isn't set`) +// } +// if noCache != "" { +// t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) +// } +// if cc["max-age"] != "3600" { +// t.Fatalf(`"max-age" value isn't "3600": %v`, cc["max-age"]) +// } +// } +//} +// +//func TestNoCacheRequestExpiration(t *testing.T) { +// resetTest(t) +// respHeaders := http.Header{} +// respHeaders.Set("Cache-Control", "max-age=7200") +// +// reqHeaders := http.Header{} +// reqHeaders.Set("Cache-Control", "no-cache") +// if getFreshness(respHeaders, reqHeaders) != transparent { +// t.Fatal("freshness isn't transparent") +// } +//} +// +//func 
TestNoCacheResponseExpiration(t *testing.T) { +// resetTest(t) +// respHeaders := http.Header{} +// respHeaders.Set("Cache-Control", "no-cache") +// respHeaders.Set("Expires", "Wed, 19 Apr 3000 11:43:00 GMT") +// +// reqHeaders := http.Header{} +// if getFreshness(respHeaders, reqHeaders) != stale { +// t.Fatal("freshness isn't stale") +// } +//} +// +//func TestReqMustRevalidate(t *testing.T) { +// resetTest(t) +// // not paying attention to request setting max-stale means never returning stale +// // responses, so always acting as if must-revalidate is set +// respHeaders := http.Header{} +// +// reqHeaders := http.Header{} +// reqHeaders.Set("Cache-Control", "must-revalidate") +// if getFreshness(respHeaders, reqHeaders) != stale { +// t.Fatal("freshness isn't stale") +// } +//} +// +//func TestRespMustRevalidate(t *testing.T) { +// resetTest(t) +// respHeaders := http.Header{} +// respHeaders.Set("Cache-Control", "must-revalidate") +// +// reqHeaders := http.Header{} +// if getFreshness(respHeaders, reqHeaders) != stale { +// t.Fatal("freshness isn't stale") +// } +//} +// +//func TestFreshExpiration(t *testing.T) { +// resetTest(t) +// now := time.Now() +// respHeaders := http.Header{} +// respHeaders.Set("date", now.Format(time.RFC1123)) +// respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) +// +// reqHeaders := http.Header{} +// if getFreshness(respHeaders, reqHeaders) != fresh { +// t.Fatal("freshness isn't fresh") +// } +// +// clock = &fakeClock{elapsed: 3 * time.Second} +// if getFreshness(respHeaders, reqHeaders) != stale { +// t.Fatal("freshness isn't stale") +// } +//} +// +//func TestMaxAge(t *testing.T) { +// resetTest(t) +// now := time.Now() +// respHeaders := http.Header{} +// respHeaders.Set("date", now.Format(time.RFC1123)) +// respHeaders.Set("cache-control", "max-age=2") +// +// reqHeaders := http.Header{} +// if getFreshness(respHeaders, reqHeaders) != fresh { +// t.Fatal("freshness isn't fresh") +// } 
+// +// clock = &fakeClock{elapsed: 3 * time.Second} +// if getFreshness(respHeaders, reqHeaders) != stale { +// t.Fatal("freshness isn't stale") +// } +//} +// +//func TestMaxAgeZero(t *testing.T) { +// resetTest(t) +// now := time.Now() +// respHeaders := http.Header{} +// respHeaders.Set("date", now.Format(time.RFC1123)) +// respHeaders.Set("cache-control", "max-age=0") +// +// reqHeaders := http.Header{} +// if getFreshness(respHeaders, reqHeaders) != stale { +// t.Fatal("freshness isn't stale") +// } +//} +// +//func TestBothMaxAge(t *testing.T) { +// resetTest(t) +// now := time.Now() +// respHeaders := http.Header{} +// respHeaders.Set("date", now.Format(time.RFC1123)) +// respHeaders.Set("cache-control", "max-age=2") +// +// reqHeaders := http.Header{} +// reqHeaders.Set("cache-control", "max-age=0") +// if getFreshness(respHeaders, reqHeaders) != stale { +// t.Fatal("freshness isn't stale") +// } +//} +// +//func TestMinFreshWithExpires(t *testing.T) { +// resetTest(t) +// now := time.Now() +// respHeaders := http.Header{} +// respHeaders.Set("date", now.Format(time.RFC1123)) +// respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) +// +// reqHeaders := http.Header{} +// reqHeaders.Set("cache-control", "min-fresh=1") +// if getFreshness(respHeaders, reqHeaders) != fresh { +// t.Fatal("freshness isn't fresh") +// } +// +// reqHeaders = http.Header{} +// reqHeaders.Set("cache-control", "min-fresh=2") +// if getFreshness(respHeaders, reqHeaders) != stale { +// t.Fatal("freshness isn't stale") +// } +//} +// +//func TestEmptyMaxStale(t *testing.T) { +// resetTest(t) +// now := time.Now() +// respHeaders := http.Header{} +// respHeaders.Set("date", now.Format(time.RFC1123)) +// respHeaders.Set("cache-control", "max-age=20") +// +// reqHeaders := http.Header{} +// reqHeaders.Set("cache-control", "max-stale") +// clock = &fakeClock{elapsed: 10 * time.Second} +// if getFreshness(respHeaders, reqHeaders) != fresh { +// 
t.Fatal("freshness isn't fresh") +// } +// +// clock = &fakeClock{elapsed: 60 * time.Second} +// if getFreshness(respHeaders, reqHeaders) != fresh { +// t.Fatal("freshness isn't fresh") +// } +//} +// +//func TestMaxStaleValue(t *testing.T) { +// resetTest(t) +// now := time.Now() +// respHeaders := http.Header{} +// respHeaders.Set("date", now.Format(time.RFC1123)) +// respHeaders.Set("cache-control", "max-age=10") +// +// reqHeaders := http.Header{} +// reqHeaders.Set("cache-control", "max-stale=20") +// clock = &fakeClock{elapsed: 5 * time.Second} +// if getFreshness(respHeaders, reqHeaders) != fresh { +// t.Fatal("freshness isn't fresh") +// } +// +// clock = &fakeClock{elapsed: 15 * time.Second} +// if getFreshness(respHeaders, reqHeaders) != fresh { +// t.Fatal("freshness isn't fresh") +// } +// +// clock = &fakeClock{elapsed: 30 * time.Second} +// if getFreshness(respHeaders, reqHeaders) != stale { +// t.Fatal("freshness isn't stale") +// } +//} +// +//func containsHeader(headers []string, header string) bool { +// for _, v := range headers { +// if http.CanonicalHeaderKey(v) == http.CanonicalHeaderKey(header) { +// return true +// } +// } +// return false +//} +// +//func TestGetEndToEndHeaders(t *testing.T) { +// resetTest(t) +// var ( +// headers http.Header +// end2end []string +// ) +// +// headers = http.Header{} +// headers.Set("content-type", "text/html") +// headers.Set("te", "deflate") +// +// end2end = getEndToEndHeaders(headers) +// if !containsHeader(end2end, "content-type") { +// t.Fatal(`doesn't contain "content-type" header`) +// } +// if containsHeader(end2end, "te") { +// t.Fatal(`doesn't contain "te" header`) +// } +// +// headers = http.Header{} +// headers.Set("connection", "content-type") +// headers.Set("content-type", "text/csv") +// headers.Set("te", "deflate") +// end2end = getEndToEndHeaders(headers) +// if containsHeader(end2end, "connection") { +// t.Fatal(`doesn't contain "connection" header`) +// } +// if 
containsHeader(end2end, "content-type") { +// t.Fatal(`doesn't contain "content-type" header`) +// } +// if containsHeader(end2end, "te") { +// t.Fatal(`doesn't contain "te" header`) +// } +// +// headers = http.Header{} +// end2end = getEndToEndHeaders(headers) +// if len(end2end) != 0 { +// t.Fatal(`non-zero end2end headers`) +// } +// +// headers = http.Header{} +// headers.Set("connection", "content-type") +// end2end = getEndToEndHeaders(headers) +// if len(end2end) != 0 { +// t.Fatal(`non-zero end2end headers`) +// } +//} +// +//type transportMock struct { +// response *http.Response +// err error +//} +// +//func (t transportMock) RoundTrip(req *http.Request) (resp *http.Response, err error) { +// return t.response, t.err +//} + +// +//func TestStaleIfErrorRequest(t *testing.T) { +// resetTest(t) +// now := time.Now() +// tmock := transportMock{ +// response: &http.Response{ +// Status: http.StatusText(http.StatusOK), +// StatusCode: http.StatusOK, +// Header: http.Header{ +// "Date": []string{now.Format(time.RFC1123)}, +// "Cache-Control": []string{"no-cache"}, +// }, +// Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), +// }, +// err: nil, +// } +// tp := newTestTransport(t.TempDir()) +// tp.Transport = &tmock +// +// // First time, response is cached on success +// r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) +// r.Header.Set("Cache-Control", "stale-if-error") +// resp, err := tp.RoundTrip(r) +// if err != nil { +// t.Fatal(err) +// } +// if resp == nil { +// t.Fatal("resp is nil") +// } +// _, err = ioutil.ReadAll(resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// +// // On failure, response is returned from the cache +// tmock.response = nil +// tmock.err = errors.New("some error") +// resp, err = tp.RoundTrip(r) +// if err != nil { +// t.Fatal(err) +// } +// if resp == nil { +// t.Fatal("resp is nil") +// } +//} +// +//func TestStaleIfErrorRequestLifetime(t *testing.T) { +// resetTest(t) +// now := time.Now() +// 
tmock := transportMock{ +// response: &http.Response{ +// Status: http.StatusText(http.StatusOK), +// StatusCode: http.StatusOK, +// Header: http.Header{ +// "Date": []string{now.Format(time.RFC1123)}, +// "Cache-Control": []string{"no-cache"}, +// }, +// Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), +// }, +// err: nil, +// } +// tp := newTestTransport(t.TempDir()) +// tp.Transport = &tmock +// +// // First time, response is cached on success +// r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) +// r.Header.Set("Cache-Control", "stale-if-error=100") +// resp, err := tp.RoundTrip(r) +// if err != nil { +// t.Fatal(err) +// } +// if resp == nil { +// t.Fatal("resp is nil") +// } +// _, err = ioutil.ReadAll(resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// +// // On failure, response is returned from the cache +// tmock.response = nil +// tmock.err = errors.New("some error") +// resp, err = tp.RoundTrip(r) +// if err != nil { +// t.Fatal(err) +// } +// if resp == nil { +// t.Fatal("resp is nil") +// } +// +// // Same for http errors +// tmock.response = &http.Response{StatusCode: http.StatusInternalServerError} +// tmock.err = nil +// resp, err = tp.RoundTrip(r) +// if err != nil { +// t.Fatal(err) +// } +// if resp == nil { +// t.Fatal("resp is nil") +// } +// +// // If failure last more than max stale, error is returned +// clock = &fakeClock{elapsed: 200 * time.Second} +// _, err = tp.RoundTrip(r) +// if err != tmock.err { +// t.Fatalf("got err %v, want %v", err, tmock.err) +// } +//} +// +//func TestStaleIfErrorResponse(t *testing.T) { +// resetTest(t) +// now := time.Now() +// tmock := transportMock{ +// response: &http.Response{ +// Status: http.StatusText(http.StatusOK), +// StatusCode: http.StatusOK, +// Header: http.Header{ +// "Date": []string{now.Format(time.RFC1123)}, +// "Cache-Control": []string{"no-cache, stale-if-error"}, +// }, +// Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), +// }, +// err: nil, 
+// } +// tp := newTestTransport(t.TempDir()) +// tp.Transport = &tmock +// +// // First time, response is cached on success +// r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) +// resp, err := tp.RoundTrip(r) +// if err != nil { +// t.Fatal(err) +// } +// if resp == nil { +// t.Fatal("resp is nil") +// } +// _, err = ioutil.ReadAll(resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// +// // On failure, response is returned from the cache +// tmock.response = nil +// tmock.err = errors.New("some error") +// resp, err = tp.RoundTrip(r) +// if err != nil { +// t.Fatal(err) +// } +// if resp == nil { +// t.Fatal("resp is nil") +// } +//} +// +//func TestStaleIfErrorResponseLifetime(t *testing.T) { +// resetTest(t) +// now := time.Now() +// tmock := transportMock{ +// response: &http.Response{ +// Status: http.StatusText(http.StatusOK), +// StatusCode: http.StatusOK, +// Header: http.Header{ +// "Date": []string{now.Format(time.RFC1123)}, +// "Cache-Control": []string{"no-cache, stale-if-error=100"}, +// }, +// Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), +// }, +// err: nil, +// } +// tp := newTestTransport(t.TempDir()) +// tp.Transport = &tmock +// +// // First time, response is cached on success +// r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) +// resp, err := tp.RoundTrip(r) +// if err != nil { +// t.Fatal(err) +// } +// if resp == nil { +// t.Fatal("resp is nil") +// } +// _, err = ioutil.ReadAll(resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// +// // On failure, response is returned from the cache +// tmock.response = nil +// tmock.err = errors.New("some error") +// resp, err = tp.RoundTrip(r) +// if err != nil { +// t.Fatal(err) +// } +// if resp == nil { +// t.Fatal("resp is nil") +// } +// +// // If failure last more than max stale, error is returned +// clock = &fakeClock{elapsed: 200 * time.Second} +// _, err = tp.RoundTrip(r) +// if err != tmock.err { +// t.Fatalf("got err %v, want %v", err, 
tmock.err) +// } +//} +// +//// This tests the fix for https://github.com/gregjones/httpcache/issues/74. +//// Previously, after a stale response was used after encountering an error, +//// its StatusCode was being incorrectly replaced. +//func TestStaleIfErrorKeepsStatus(t *testing.T) { +// resetTest(t) +// now := time.Now() +// tmock := transportMock{ +// response: &http.Response{ +// Status: http.StatusText(http.StatusNotFound), +// StatusCode: http.StatusNotFound, +// Header: http.Header{ +// "Date": []string{now.Format(time.RFC1123)}, +// "Cache-Control": []string{"no-cache"}, +// }, +// Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), +// }, +// err: nil, +// } +// tp := newTestTransport(t.TempDir()) +// tp.Transport = &tmock +// +// // First time, response is cached on success +// r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) +// r.Header.Set("Cache-Control", "stale-if-error") +// resp, err := tp.RoundTrip(r) +// if err != nil { +// t.Fatal(err) +// } +// if resp == nil { +// t.Fatal("resp is nil") +// } +// _, err = ioutil.ReadAll(resp.Body) +// if err != nil { +// t.Fatal(err) +// } +// +// // On failure, response is returned from the cache +// tmock.response = nil +// tmock.err = errors.New("some error") +// resp, err = tp.RoundTrip(r) +// if err != nil { +// t.Fatal(err) +// } +// if resp == nil { +// t.Fatal("resp is nil") +// } +// if resp.StatusCode != http.StatusNotFound { +// t.Fatalf("Status wasn't 404: %d", resp.StatusCode) +// } +//} +// +//// Test that http.Client.Timeout is respected when cache transport is used. +//// That is so as long as request cancellation is propagated correctly. +//// In the past, that required CancelRequest to be implemented correctly, +//// but modern http.Client uses Request.Cancel (or request context) instead, +//// so we don't have to do anything. 
+//func TestClientTimeout(t *testing.T) { +// if testing.Short() { +// t.Skip("skipping timeout test in short mode") // Because it takes at least 3 seconds to run. +// } +// resetTest(t) +// +// client := &http.Client{ +// Transport: newTestTransport(t.TempDir()), +// Timeout: time.Second, +// } +// started := time.Now() +// resp, err := client.Get(s.server.URL + "/3seconds") +// taken := time.Since(started) +// if err == nil { +// t.Error("got nil error, want timeout error") +// } +// if resp != nil { +// t.Error("got non-nil resp, want nil resp") +// } +// if taken >= 2*time.Second { +// t.Error("client.Do took 2+ seconds, want < 2 seconds") +// } +//} diff --git a/libsq/core/ioz/httpcache/httpz.go b/libsq/core/ioz/httpcache/httpz.go index 8e5337684..c4507e078 100644 --- a/libsq/core/ioz/httpcache/httpz.go +++ b/libsq/core/ioz/httpcache/httpz.go @@ -2,26 +2,23 @@ package httpcache import ( "bufio" + "bytes" + "errors" "fmt" "io" "net/http" "net/textproto" "strconv" "strings" + "time" ) -// ReadResponse is a copy of http.ReadResponse, but with the option -// to read only the response header, and not the body. When only reading -// the header, note that resp.Body will be nil, and that the resp is -// generally not functional. -func ReadResponse(r *bufio.Reader, req *http.Request, headerOnly bool) (*http.Response, error) { - if !headerOnly { - return http.ReadResponse(r, req) - } +// readResponseHeader is a fork of http.ReadResponse that reads only the +// header from req and not the body. Note that resp.Body will be nil, and +// that the resp object is borked for general use. +func readResponseHeader(r *bufio.Reader, req *http.Request) (resp *http.Response, err error) { tp := textproto.NewReader(r) - resp := &http.Response{ - Request: req, - } + resp = &http.Response{Request: req} // Parse the first line of the response. 
line, err := tp.ReadLine() @@ -62,11 +59,6 @@ func ReadResponse(r *bufio.Reader, req *http.Request, headerOnly bool) (*http.Re fixPragmaCacheControl(resp.Header) - //err = readTransfer(resp, r) - //if err != nil { - // return nil, err - //} - return resp, nil } @@ -86,3 +78,308 @@ func fixPragmaCacheControl(header http.Header) { } func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } + +// errNoDateHeader indicates that the HTTP headers contained no Date header. +var errNoDateHeader = errors.New("no Date header") + +// Date parses and returns the value of the Date header. +func getDate(respHeaders http.Header) (date time.Time, err error) { + dateHeader := respHeaders.Get("date") + if dateHeader == "" { + err = errNoDateHeader + return + } + + return time.Parse(time.RFC1123, dateHeader) +} + +type realClock struct{} + +func (c *realClock) since(d time.Time) time.Duration { + return time.Since(d) +} + +type timer interface { + since(d time.Time) time.Duration +} + +var clock timer = &realClock{} + +// getFreshness will return one of fresh/stale/transparent based on the cache-control +// values of the request and the response +// +// fresh indicates the response can be returned +// stale indicates that the response needs validating before it is returned +// transparent indicates the response should not be used to fulfil the request +// +// Because this is only a private cache, 'public' and 'private' in cache-control aren't +// significant. Similarly, smax-age isn't used. 
+func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) { + respCacheControl := parseCacheControl(respHeaders) + reqCacheControl := parseCacheControl(reqHeaders) + if _, ok := reqCacheControl["no-cache"]; ok { + return transparent + } + if _, ok := respCacheControl["no-cache"]; ok { + return stale + } + if _, ok := reqCacheControl["only-if-cached"]; ok { + return fresh + } + + date, err := getDate(respHeaders) + if err != nil { + return stale + } + currentAge := clock.since(date) + + var lifetime time.Duration + var zeroDuration time.Duration + + // If a response includes both an Expires header and a max-age directive, + // the max-age directive overrides the Expires header, even if the Expires header is more restrictive. + if maxAge, ok := respCacheControl["max-age"]; ok { + lifetime, err = time.ParseDuration(maxAge + "s") + if err != nil { + lifetime = zeroDuration + } + } else { + expiresHeader := respHeaders.Get("Expires") + if expiresHeader != "" { + expires, err := time.Parse(time.RFC1123, expiresHeader) + if err != nil { + lifetime = zeroDuration + } else { + lifetime = expires.Sub(date) + } + } + } + + if maxAge, ok := reqCacheControl["max-age"]; ok { + // the client is willing to accept a response whose age is no greater than the specified time in seconds + lifetime, err = time.ParseDuration(maxAge + "s") + if err != nil { + lifetime = zeroDuration + } + } + if minfresh, ok := reqCacheControl["min-fresh"]; ok { + // the client wants a response that will still be fresh for at least the specified number of seconds. + minfreshDuration, err := time.ParseDuration(minfresh + "s") + if err == nil { + currentAge = time.Duration(currentAge + minfreshDuration) + } + } + + if maxstale, ok := reqCacheControl["max-stale"]; ok { + // Indicates that the client is willing to accept a response that has exceeded its expiration time. 
+ // If max-stale is assigned a value, then the client is willing to accept a response that has exceeded + // its expiration time by no more than the specified number of seconds. + // If no value is assigned to max-stale, then the client is willing to accept a stale response of any age. + // + // Responses served only because of a max-stale value are supposed to have a Warning header added to them, + // but that seems like a hassle, and is it actually useful? If so, then there needs to be a different + // return-value available here. + if maxstale == "" { + return fresh + } + maxstaleDuration, err := time.ParseDuration(maxstale + "s") + if err == nil { + currentAge = time.Duration(currentAge - maxstaleDuration) + } + } + + if lifetime > currentAge { + return fresh + } + + return stale +} + +// Returns true if either the request or the response includes the stale-if-error +// cache control extension: https://tools.ietf.org/html/rfc5861 +func canStaleOnError(respHeaders, reqHeaders http.Header) bool { + respCacheControl := parseCacheControl(respHeaders) + reqCacheControl := parseCacheControl(reqHeaders) + + var err error + lifetime := time.Duration(-1) + + if staleMaxAge, ok := respCacheControl["stale-if-error"]; ok { + if staleMaxAge != "" { + lifetime, err = time.ParseDuration(staleMaxAge + "s") + if err != nil { + return false + } + } else { + return true + } + } + if staleMaxAge, ok := reqCacheControl["stale-if-error"]; ok { + if staleMaxAge != "" { + lifetime, err = time.ParseDuration(staleMaxAge + "s") + if err != nil { + return false + } + } else { + return true + } + } + + if lifetime >= 0 { + date, err := getDate(respHeaders) + if err != nil { + return false + } + currentAge := clock.since(date) + if lifetime > currentAge { + return true + } + } + + return false +} + +func getEndToEndHeaders(respHeaders http.Header) []string { + // These headers are always hop-by-hop + hopByHopHeaders := map[string]struct{}{ + "Connection": {}, + "Keep-Alive": {}, + 
"Proxy-Authenticate": {}, + "Proxy-Authorization": {}, + "Te": {}, + "Trailers": {}, + "Transfer-Encoding": {}, + "Upgrade": {}, + } + + for _, extra := range strings.Split(respHeaders.Get("connection"), ",") { + // any header listed in connection, if present, is also considered hop-by-hop + if strings.Trim(extra, " ") != "" { + hopByHopHeaders[http.CanonicalHeaderKey(extra)] = struct{}{} + } + } + endToEndHeaders := []string{} + for respHeader := range respHeaders { + if _, ok := hopByHopHeaders[respHeader]; !ok { + endToEndHeaders = append(endToEndHeaders, respHeader) + } + } + return endToEndHeaders +} + +func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) { + if _, ok := respCacheControl["no-store"]; ok { + return false + } + if _, ok := reqCacheControl["no-store"]; ok { + return false + } + return true +} + +func newGatewayTimeoutResponse(req *http.Request) *http.Response { + var braw bytes.Buffer + braw.WriteString("HTTP/1.1 504 Gateway Timeout\r\n\r\n") + resp, err := http.ReadResponse(bufio.NewReader(&braw), req) + if err != nil { + panic(err) + } + return resp +} + +// cloneRequest returns a clone of the provided *http.Request. +// The clone is a shallow copy of the struct and its Header map. 
+// (This function copyright goauth2 authors: https://code.google.com/p/goauth2) +func cloneRequest(r *http.Request) *http.Request { + // shallow copy of the struct + r2 := new(http.Request) + *r2 = *r + if ctx := r.Context(); ctx != nil { + r2 = r2.WithContext(ctx) + } + // deep copy of the Header + r2.Header = make(http.Header) + for k, s := range r.Header { + r2.Header[k] = s + } + return r2 +} + +type cacheControl map[string]string + +func parseCacheControl(headers http.Header) cacheControl { + cc := cacheControl{} + ccHeader := headers.Get("Cache-Control") + for _, part := range strings.Split(ccHeader, ",") { + part = strings.Trim(part, " ") + if part == "" { + continue + } + if strings.ContainsRune(part, '=') { + keyval := strings.Split(part, "=") + cc[strings.Trim(keyval[0], " ")] = strings.Trim(keyval[1], ",") + } else { + cc[part] = "" + } + } + return cc +} + +// headerAllCommaSepValues returns all comma-separated values (each +// with whitespace trimmed) for header name in headers. According to +// Section 4.2 of the HTTP/1.1 spec +// (http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2), +// values from multiple occurrences of a header should be concatenated, if +// the header's value is a comma-separated list. +func headerAllCommaSepValues(headers http.Header, name string) []string { + var vals []string + for _, val := range headers[http.CanonicalHeaderKey(name)] { + fields := strings.Split(val, ",") + for i, f := range fields { + fields[i] = strings.TrimSpace(f) + } + vals = append(vals, fields...) + } + return vals +} + +// cachingReadCloser is a wrapper around ReadCloser R that calls OnEOF +// handler with a full copy of the content read from R when EOF is +// reached. +type cachingReadCloser struct { + // Underlying ReadCloser. + R io.ReadCloser + // OnEOF is called with a copy of the content of R when EOF is reached. + OnEOF func(io.Reader) + + buf bytes.Buffer // buf stores a copy of the content of R. 
+} + +// Read reads the next len(p) bytes from R or until R is drained. The +// return value n is the number of bytes read. If R has no data to +// return, err is io.EOF and OnEOF is called with a full copy of what +// has been read so far. +func (r *cachingReadCloser) Read(p []byte) (n int, err error) { + n, err = r.R.Read(p) + r.buf.Write(p[:n]) + if err == io.EOF { + r.OnEOF(bytes.NewReader(r.buf.Bytes())) + } + return n, err +} + +func (r *cachingReadCloser) Close() error { + return r.R.Close() +} + +// varyMatches will return false unless all the cached values for the +// headers listed in Vary match the new request +func varyMatches(cachedResp *http.Response, req *http.Request) bool { + for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") { + header = http.CanonicalHeaderKey(header) + if header != "" && req.Header.Get(header) != cachedResp.Header.Get("X-Varied-"+header) { + return false + } + } + return true +} diff --git a/libsq/core/ioz/httpcacheworking/LICENSE.txt b/libsq/core/ioz/httpcacheworking/LICENSE.txt deleted file mode 100644 index 81316beb0..000000000 --- a/libsq/core/ioz/httpcacheworking/LICENSE.txt +++ /dev/null @@ -1,7 +0,0 @@ -Copyright © 2012 Greg Jones (greg.jones@gmail.com) - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
- -THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/libsq/core/ioz/httpcacheworking/README.md b/libsq/core/ioz/httpcacheworking/README.md deleted file mode 100644 index 58cda222f..000000000 --- a/libsq/core/ioz/httpcacheworking/README.md +++ /dev/null @@ -1,42 +0,0 @@ -# httpcache - -[![GoDoc](https://godoc.org/github.com/bitcomplete/httpcache?status.svg)](https://godoc.org/github.com/bitcomplete/httpcache) - -Package httpcache provides a http.RoundTripper implementation that works as a -mostly [RFC 7234](https://tools.ietf.org/html/rfc7234) compliant cache for http -responses. This incarnation of the library is an active fork of -[github.com/gregjones/httpcache](https://github.com/gregjones/httpcache) which -is unmaintained. - -It is only suitable for use as a 'private' cache (i.e. for a web-browser or an -API-client and not for a shared proxy). - -## Cache Backends - -- The built-in 'memory' cache stores responses in an in-memory map. - - [`github.com/bitcomplete/httpcache/diskcache`](https://github.com/bitcomplete/httpcache/tree/master/diskcache) - provides a filesystem-backed cache using the - [diskv](https://github.com/peterbourgon/diskv) library. - - [`github.com/bitcomplete/httpcache/memcache`](https://github.com/bitcomplete/httpcache/tree/master/memcache) - provides memcache implementations, for both App Engine and 'normal' memcache - servers. - - [`sourcegraph.com/sourcegraph/s3cache`](https://sourcegraph.com/github.com/sourcegraph/s3cache) - uses Amazon S3 for storage. 
- - [`github.com/bitcomplete/httpcache/leveldbcache`](https://github.com/bitcomplete/httpcache/tree/master/leveldbcache) - provides a filesystem-backed cache using - [leveldb](https://github.com/syndtr/goleveldb/leveldb). - - [`github.com/die-net/lrucache`](https://github.com/die-net/lrucache) provides an - in-memory cache that will evict least-recently used entries. - - [`github.com/die-net/lrucache/twotier`](https://github.com/die-net/lrucache/tree/master/twotier) - allows caches to be combined, for example to use lrucache above with a - persistent disk-cache. - - [`github.com/birkelund/boltdbcache`](https://github.com/birkelund/boltdbcache) - provides a BoltDB implementation (based on the - [bbolt](https://github.com/coreos/bbolt) fork). - -If you implement any other backend and wish it to be linked here, please send a -PR editing this file. - -## License - -- [MIT License](LICENSE.txt) diff --git a/libsq/core/ioz/httpcacheworking/httpcache.go b/libsq/core/ioz/httpcacheworking/httpcache.go deleted file mode 100644 index b8bbb9ef9..000000000 --- a/libsq/core/ioz/httpcacheworking/httpcache.go +++ /dev/null @@ -1,717 +0,0 @@ -// Package httpcacheworking provides a http.RoundTripper implementation that -// works as a mostly RFC-compliant cache for http responses. -// -// FIXME: move httpcache to internal/httpcache, because its use -// is so specialized? -// -// Acknowledgement: This package is a heavily customized fork -// of https://github.com/gregjones/httpcache, via bitcomplete/httpcache. 
-package httpcacheworking - -import ( - "bufio" - "bytes" - "errors" - "github.com/neilotoole/sq/libsq/core/ioz/contextio" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lga" - "io" - "io/ioutil" - "net/http" - "os" - "strings" - "time" -) - -const ( - stale = iota - fresh - transparent - // XFromCache is the header added to responses that are returned from the cache - XFromCache = "X-From-Cache" -) - -// TransportOpt is a configuration option for creating a new Transport -type TransportOpt func(t *Transport) - -// MarkCachedResponsesOpt configures a transport by setting MarkCachedResponses to true -func MarkCachedResponsesOpt(markCachedResponses bool) TransportOpt { - return func(t *Transport) { - t.MarkCachedResponses = markCachedResponses - } -} - -// Transport is an implementation of http.RoundTripper that will return values from a cache -// where possible (avoiding a network request) and will additionally add validators (etag/if-modified-since) -// to repeated requests allowing servers to return 304 / Not Modified -type Transport struct { - // The RoundTripper interface actually used to make requests - // If nil, http.DefaultTransport is used - Transport http.RoundTripper - - RespCache *RespCache - - // MarkCachedResponses, if true, indicates that responses returned from the - // cache will be given an extra header, X-From-Cache - MarkCachedResponses bool -} - -// NewTransport returns a new Transport with the provided Cache and options. If -// KeyFunc is not specified in opts then DefaultKeyFunc is used. -func NewTransport(rc *RespCache, opts ...TransportOpt) *Transport { - t := &Transport{ - RespCache: rc, - MarkCachedResponses: true, - } - for _, opt := range opts { - opt(t) - } - return t -} - -// Client returns an *http.Client that caches responses. 
-func (t *Transport) Client() *http.Client { - return &http.Client{Transport: t} -} - -// varyMatches will return false unless all the cached values for the -// headers listed in Vary match the new request -func varyMatches(cachedResp *http.Response, req *http.Request) bool { - for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") { - header = http.CanonicalHeaderKey(header) - if header != "" && req.Header.Get(header) != cachedResp.Header.Get("X-Varied-"+header) { - return false - } - } - return true -} - -// IsCached returns true if there is a cache entry for req. This does not -// guarantee that the cache entry is fresh. -func (t *Transport) IsCached(req *http.Request) bool { - return t.RespCache.Exists(req) -} - -// IsFresh returns true if there is a fresh cache entry for req. -func (t *Transport) IsFresh(req *http.Request) bool { - ctx := req.Context() - log := lg.FromContext(ctx) - - if !isCacheable(req) { - return false - } - - if !t.RespCache.Exists(req) { - return false - } - - fpHeader, _ := t.RespCache.Paths(req) - f, err := os.Open(fpHeader) - if err != nil { - log.Error("Failed to open cached response header file", lga.File, fpHeader, lga.Err, err) - return false - } - - defer lg.WarnIfCloseError(log, "Close cached response header", f) - - cachedResp, err := ReadResponse(bufio.NewReader(f), nil, true) - if err != nil { - log.Error("Failed to read cached response", lga.Err, err) - return false - } - - freshness := getFreshness(cachedResp.Header, req.Header) - return freshness == fresh -} - -func isCacheable(req *http.Request) bool { - return (req.Method == http.MethodGet || req.Method == http.MethodHead) && req.Header.Get("range") == "" -} - -type CallbackHandler struct { - HandleCached func(cachedFilepath string) error - HandleUncached func() (wc io.WriteCloser, errFn func(error), err error) - HandleError func(err error) -} - -func (t *Transport) Fetch(req *http.Request, cb CallbackHandler) { - ctx := req.Context() - log := 
lg.FromContext(ctx) - log.Info("Fetching download", lga.URL, req.URL.String()) - _ = log - _, fpBody := t.RespCache.Paths(req) - - if t.IsFresh(req) { - _ = cb.HandleCached(fpBody) - return - } - - var err error - cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" - var cachedResp *http.Response - if cacheable { - cachedResp, err = t.RespCache.Get(req.Context(), req) - } else { - // Need to invalidate an existing value - if err = t.RespCache.Delete(req.Context()); err != nil { - cb.HandleError(err) - return - } - } - - transport := t.Transport - if transport == nil { - transport = http.DefaultTransport - } - var resp *http.Response - if cacheable && cachedResp != nil && err == nil { - if t.MarkCachedResponses { - cachedResp.Header.Set(XFromCache, "1") - } - - if varyMatches(cachedResp, req) { - // Can only use cached value if the new request doesn't Vary significantly - freshness := getFreshness(cachedResp.Header, req.Header) - if freshness == fresh { - _ = cb.HandleCached(fpBody) - return - } - - if freshness == stale { - var req2 *http.Request - // Add validators if caller hasn't already done so - etag := cachedResp.Header.Get("etag") - if etag != "" && req.Header.Get("etag") == "" { - req2 = cloneRequest(req) - req2.Header.Set("if-none-match", etag) - } - lastModified := cachedResp.Header.Get("last-modified") - if lastModified != "" && req.Header.Get("last-modified") == "" { - if req2 == nil { - req2 = cloneRequest(req) - } - req2.Header.Set("if-modified-since", lastModified) - } - if req2 != nil { - req = req2 - } - } - } - - // FIXME: Use an http client here - resp, err = transport.RoundTrip(req) - if err == nil && req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified { - // Replace the 304 response with the one from cache, but update with some new headers - endToEndHeaders := getEndToEndHeaders(resp.Header) - for _, header := range endToEndHeaders { - cachedResp.Header[header] = resp.Header[header] 
- } - resp = cachedResp - } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) && - req.Method == http.MethodGet && canStaleOnError(cachedResp.Header, req.Header) { - // In case of transport failure and stale-if-error activated, returns cached content - // when available - log.Warn("Returning cached response due to transport failure", lga.Err, err) - cb.HandleCached(fpBody) - return - } else { - if err != nil || resp.StatusCode != http.StatusOK { - lg.WarnIfError(log, msgDeleteCache, t.RespCache.Delete(req.Context())) - } - if err != nil { - cb.HandleError(err) - return - } - } - } else { - reqCacheControl := parseCacheControl(req.Header) - if _, ok := reqCacheControl["only-if-cached"]; ok { - resp = newGatewayTimeoutResponse(req) - } else { - resp, err = transport.RoundTrip(req) - if err != nil { - cb.HandleError(err) - return - } - } - } - - if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) { - for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") { - varyKey = http.CanonicalHeaderKey(varyKey) - fakeHeader := "X-Varied-" + varyKey - reqValue := req.Header.Get(varyKey) - if reqValue != "" { - resp.Header.Set(fakeHeader, reqValue) - } - } - switch req.Method { - //case "GET": - // // Delay caching until EOF is reached. 
- // resp.Body = &cachingReadCloser{ - // R: resp.Body, - // OnEOF: func(r io.Reader) { - // resp := *resp - // resp.Body = ioutil.NopCloser(r) - // if err := t.RespCache.Write(req.Context(), &resp, nil); err != nil { - // log.Error("failed to write download cache", lga.Err, err) - // } - // }, - // } - default: - copyWrtr, errFn, err := cb.HandleUncached() - if err != nil { - cb.HandleError(err) - return - } - - if err = t.RespCache.Write(req.Context(), resp, copyWrtr); err != nil { - log.Error("failed to write download cache", lga.Err, err) - errFn(err) - cb.HandleError(err) - } - return - } - } else { - lg.WarnIfError(log, "Delete resp cache", t.RespCache.Delete(req.Context())) - } - - // It's not cacheable, so we need to write it to the copyWrtr - copyWrtr, errFn, err := cb.HandleUncached() - if err != nil { - cb.HandleError(err) - return - } - cr := contextio.NewReader(ctx, resp.Body) - _, err = io.Copy(copyWrtr, cr) - if err != nil { - errFn(err) - cb.HandleError(err) - return - } - if err = copyWrtr.Close(); err != nil { - cb.HandleError(err) - return - } - - return -} - -// RoundTrip takes a Request and returns a Response -// -// If there is a fresh Response already in cache, then it will be returned without connecting to -// the server. -// -// If there is a stale Response, then any validators it contains will be set on the new request -// to give the server a chance to respond with NotModified. If this happens, then the cached Response -// will be returned. 
-func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) { - log := lg.FromContext(req.Context()) - - cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == "" - var cachedResp *http.Response - if cacheable { - cachedResp, err = t.RespCache.Get(req.Context(), req) - } else { - // Need to invalidate an existing value - if err = t.RespCache.Delete(req.Context()); err != nil { - return nil, err - } - } - - transport := t.Transport - if transport == nil { - transport = http.DefaultTransport - } - - if cacheable && cachedResp != nil && err == nil { - if t.MarkCachedResponses { - cachedResp.Header.Set(XFromCache, "1") - } - - if varyMatches(cachedResp, req) { - // Can only use cached value if the new request doesn't Vary significantly - freshness := getFreshness(cachedResp.Header, req.Header) - if freshness == fresh { - return cachedResp, nil - } - - if freshness == stale { - var req2 *http.Request - // Add validators if caller hasn't already done so - etag := cachedResp.Header.Get("etag") - if etag != "" && req.Header.Get("etag") == "" { - req2 = cloneRequest(req) - req2.Header.Set("if-none-match", etag) - } - lastModified := cachedResp.Header.Get("last-modified") - if lastModified != "" && req.Header.Get("last-modified") == "" { - if req2 == nil { - req2 = cloneRequest(req) - } - req2.Header.Set("if-modified-since", lastModified) - } - if req2 != nil { - req = req2 - } - } - } - - resp, err = transport.RoundTrip(req) - if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified { - // Replace the 304 response with the one from cache, but update with some new headers - endToEndHeaders := getEndToEndHeaders(resp.Header) - for _, header := range endToEndHeaders { - cachedResp.Header[header] = resp.Header[header] - } - resp = cachedResp - } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) && - req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) { - // In 
case of transport failure and stale-if-error activated, returns cached content - // when available - return cachedResp, nil - } else { - if err != nil || resp.StatusCode != http.StatusOK { - lg.WarnIfError(log, msgDeleteCache, t.RespCache.Delete(req.Context())) - } - if err != nil { - return nil, err - } - } - } else { - reqCacheControl := parseCacheControl(req.Header) - if _, ok := reqCacheControl["only-if-cached"]; ok { - resp = newGatewayTimeoutResponse(req) - } else { - resp, err = transport.RoundTrip(req) - if err != nil { - return nil, err - } - } - } - - if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) { - for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") { - varyKey = http.CanonicalHeaderKey(varyKey) - fakeHeader := "X-Varied-" + varyKey - reqValue := req.Header.Get(varyKey) - if reqValue != "" { - resp.Header.Set(fakeHeader, reqValue) - } - } - switch req.Method { - case "GET": - // Delay caching until EOF is reached. - resp.Body = &cachingReadCloser{ - R: resp.Body, - OnEOF: func(r io.Reader) { - resp := *resp - resp.Body = ioutil.NopCloser(r) - if err := t.RespCache.Write(req.Context(), &resp, nil); err != nil { - log.Error("failed to write download cache", lga.Err, err) - } - }, - } - default: - if err = t.RespCache.Write(req.Context(), resp, nil); err != nil { - log.Error("failed to write download cache", lga.Err, err) - } - } - } else { - lg.WarnIfError(log, "Delete resp cache", t.RespCache.Delete(req.Context())) - } - return resp, nil -} - -// ErrNoDateHeader indicates that the HTTP headers contained no Date header. -var ErrNoDateHeader = errors.New("no Date header") - -// Date parses and returns the value of the Date header. 
-func Date(respHeaders http.Header) (date time.Time, err error) { - dateHeader := respHeaders.Get("date") - if dateHeader == "" { - err = ErrNoDateHeader - return - } - - return time.Parse(time.RFC1123, dateHeader) -} - -type realClock struct{} - -func (c *realClock) since(d time.Time) time.Duration { - return time.Since(d) -} - -type timer interface { - since(d time.Time) time.Duration -} - -var clock timer = &realClock{} - -// getFreshness will return one of fresh/stale/transparent based on the cache-control -// values of the request and the response -// -// fresh indicates the response can be returned -// stale indicates that the response needs validating before it is returned -// transparent indicates the response should not be used to fulfil the request -// -// Because this is only a private cache, 'public' and 'private' in cache-control aren't -// signficant. Similarly, smax-age isn't used. -func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) { - respCacheControl := parseCacheControl(respHeaders) - reqCacheControl := parseCacheControl(reqHeaders) - if _, ok := reqCacheControl["no-cache"]; ok { - return transparent - } - if _, ok := respCacheControl["no-cache"]; ok { - return stale - } - if _, ok := reqCacheControl["only-if-cached"]; ok { - return fresh - } - - date, err := Date(respHeaders) - if err != nil { - return stale - } - currentAge := clock.since(date) - - var lifetime time.Duration - var zeroDuration time.Duration - - // If a response includes both an Expires header and a max-age directive, - // the max-age directive overrides the Expires header, even if the Expires header is more restrictive. 
- if maxAge, ok := respCacheControl["max-age"]; ok { - lifetime, err = time.ParseDuration(maxAge + "s") - if err != nil { - lifetime = zeroDuration - } - } else { - expiresHeader := respHeaders.Get("Expires") - if expiresHeader != "" { - expires, err := time.Parse(time.RFC1123, expiresHeader) - if err != nil { - lifetime = zeroDuration - } else { - lifetime = expires.Sub(date) - } - } - } - - if maxAge, ok := reqCacheControl["max-age"]; ok { - // the client is willing to accept a response whose age is no greater than the specified time in seconds - lifetime, err = time.ParseDuration(maxAge + "s") - if err != nil { - lifetime = zeroDuration - } - } - if minfresh, ok := reqCacheControl["min-fresh"]; ok { - // the client wants a response that will still be fresh for at least the specified number of seconds. - minfreshDuration, err := time.ParseDuration(minfresh + "s") - if err == nil { - currentAge = time.Duration(currentAge + minfreshDuration) - } - } - - if maxstale, ok := reqCacheControl["max-stale"]; ok { - // Indicates that the client is willing to accept a response that has exceeded its expiration time. - // If max-stale is assigned a value, then the client is willing to accept a response that has exceeded - // its expiration time by no more than the specified number of seconds. - // If no value is assigned to max-stale, then the client is willing to accept a stale response of any age. - // - // Responses served only because of a max-stale value are supposed to have a Warning header added to them, - // but that seems like a hassle, and is it actually useful? If so, then there needs to be a different - // return-value available here. 
- if maxstale == "" { - return fresh - } - maxstaleDuration, err := time.ParseDuration(maxstale + "s") - if err == nil { - currentAge = time.Duration(currentAge - maxstaleDuration) - } - } - - if lifetime > currentAge { - return fresh - } - - return stale -} - -// Returns true if either the request or the response includes the stale-if-error -// cache control extension: https://tools.ietf.org/html/rfc5861 -func canStaleOnError(respHeaders, reqHeaders http.Header) bool { - respCacheControl := parseCacheControl(respHeaders) - reqCacheControl := parseCacheControl(reqHeaders) - - var err error - lifetime := time.Duration(-1) - - if staleMaxAge, ok := respCacheControl["stale-if-error"]; ok { - if staleMaxAge != "" { - lifetime, err = time.ParseDuration(staleMaxAge + "s") - if err != nil { - return false - } - } else { - return true - } - } - if staleMaxAge, ok := reqCacheControl["stale-if-error"]; ok { - if staleMaxAge != "" { - lifetime, err = time.ParseDuration(staleMaxAge + "s") - if err != nil { - return false - } - } else { - return true - } - } - - if lifetime >= 0 { - date, err := Date(respHeaders) - if err != nil { - return false - } - currentAge := clock.since(date) - if lifetime > currentAge { - return true - } - } - - return false -} - -func getEndToEndHeaders(respHeaders http.Header) []string { - // These headers are always hop-by-hop - hopByHopHeaders := map[string]struct{}{ - "Connection": {}, - "Keep-Alive": {}, - "Proxy-Authenticate": {}, - "Proxy-Authorization": {}, - "Te": {}, - "Trailers": {}, - "Transfer-Encoding": {}, - "Upgrade": {}, - } - - for _, extra := range strings.Split(respHeaders.Get("connection"), ",") { - // any header listed in connection, if present, is also considered hop-by-hop - if strings.Trim(extra, " ") != "" { - hopByHopHeaders[http.CanonicalHeaderKey(extra)] = struct{}{} - } - } - endToEndHeaders := []string{} - for respHeader := range respHeaders { - if _, ok := hopByHopHeaders[respHeader]; !ok { - endToEndHeaders = 
append(endToEndHeaders, respHeader) - } - } - return endToEndHeaders -} - -func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) { - if _, ok := respCacheControl["no-store"]; ok { - return false - } - if _, ok := reqCacheControl["no-store"]; ok { - return false - } - return true -} - -func newGatewayTimeoutResponse(req *http.Request) *http.Response { - var braw bytes.Buffer - braw.WriteString("HTTP/1.1 504 Gateway Timeout\r\n\r\n") - resp, err := http.ReadResponse(bufio.NewReader(&braw), req) - if err != nil { - panic(err) - } - return resp -} - -// cloneRequest returns a clone of the provided *http.Request. -// The clone is a shallow copy of the struct and its Header map. -// (This function copyright goauth2 authors: https://code.google.com/p/goauth2) -func cloneRequest(r *http.Request) *http.Request { - // shallow copy of the struct - r2 := new(http.Request) - *r2 = *r - if ctx := r.Context(); ctx != nil { - r2 = r2.WithContext(ctx) - } - // deep copy of the Header - r2.Header = make(http.Header) - for k, s := range r.Header { - r2.Header[k] = s - } - return r2 -} - -type cacheControl map[string]string - -func parseCacheControl(headers http.Header) cacheControl { - cc := cacheControl{} - ccHeader := headers.Get("Cache-Control") - for _, part := range strings.Split(ccHeader, ",") { - part = strings.Trim(part, " ") - if part == "" { - continue - } - if strings.ContainsRune(part, '=') { - keyval := strings.Split(part, "=") - cc[strings.Trim(keyval[0], " ")] = strings.Trim(keyval[1], ",") - } else { - cc[part] = "" - } - } - return cc -} - -// headerAllCommaSepValues returns all comma-separated values (each -// with whitespace trimmed) for header name in headers. According to -// Section 4.2 of the HTTP/1.1 spec -// (http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2), -// values from multiple occurrences of a header should be concatenated, if -// the header's value is a comma-separated list. 
-func headerAllCommaSepValues(headers http.Header, name string) []string { - var vals []string - for _, val := range headers[http.CanonicalHeaderKey(name)] { - fields := strings.Split(val, ",") - for i, f := range fields { - fields[i] = strings.TrimSpace(f) - } - vals = append(vals, fields...) - } - return vals -} - -// cachingReadCloser is a wrapper around ReadCloser R that calls OnEOF -// handler with a full copy of the content read from R when EOF is -// reached. -type cachingReadCloser struct { - // Underlying ReadCloser. - R io.ReadCloser - // OnEOF is called with a copy of the content of R when EOF is reached. - OnEOF func(io.Reader) - - buf bytes.Buffer // buf stores a copy of the content of R. -} - -// Read reads the next len(p) bytes from R or until R is drained. The -// return value n is the number of bytes read. If R has no data to -// return, err is io.EOF and OnEOF is called with a full copy of what -// has been read so far. -func (r *cachingReadCloser) Read(p []byte) (n int, err error) { - n, err = r.R.Read(p) - r.buf.Write(p[:n]) - if err == io.EOF { - r.OnEOF(bytes.NewReader(r.buf.Bytes())) - } - return n, err -} - -func (r *cachingReadCloser) Close() error { - return r.R.Close() -} diff --git a/libsq/core/ioz/httpcacheworking/httpcache_test.go b/libsq/core/ioz/httpcacheworking/httpcache_test.go deleted file mode 100644 index 51a245237..000000000 --- a/libsq/core/ioz/httpcacheworking/httpcache_test.go +++ /dev/null @@ -1,1486 +0,0 @@ -package httpcacheworking - -import ( - "bytes" - "errors" - "flag" - "github.com/neilotoole/sq/libsq/core/stringz" - "io" - "io/ioutil" - "net/http" - "net/http/httptest" - "os" - "path/filepath" - "strconv" - "testing" - "time" -) - -// newTestTransport returns a new Transport using the in-memory cache implementation -func newTestTransport(cacheDir string, opts ...TransportOpt) *Transport { - rc := NewRespCache(cacheDir) - t := NewTransport(rc, opts...) 
- return t -} - -var s struct { - server *httptest.Server - client http.Client - transport *Transport - done chan struct{} // Closed to unlock infinite handlers. -} - -type fakeClock struct { - elapsed time.Duration -} - -func (c *fakeClock) since(t time.Time) time.Duration { - return c.elapsed -} - -func TestMain(m *testing.M) { - flag.Parse() - setup() - code := m.Run() - teardown() - os.Exit(code) -} - -func setup() { - tp := newTestTransport(filepath.Join(os.TempDir(), stringz.Uniq8())) - client := http.Client{Transport: tp} - s.transport = tp - s.client = client - s.done = make(chan struct{}) - - mux := http.NewServeMux() - s.server = httptest.NewServer(mux) - - mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - })) - - mux.HandleFunc("/method", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - _, _ = w.Write([]byte(r.Method)) - })) - - mux.HandleFunc("/range", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - lm := "Fri, 14 Dec 2010 01:01:50 GMT" - if r.Header.Get("if-modified-since") == lm { - w.WriteHeader(http.StatusNotModified) - return - } - w.Header().Set("last-modified", lm) - if r.Header.Get("range") == "bytes=4-9" { - w.WriteHeader(http.StatusPartialContent) - _, _ = w.Write([]byte(" text ")) - return - } - _, _ = w.Write([]byte("Some text content")) - })) - - mux.HandleFunc("/nostore", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "no-store") - })) - - mux.HandleFunc("/etag", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - etag := "124567" - if r.Header.Get("if-none-match") == etag { - w.WriteHeader(http.StatusNotModified) - return - } - w.Header().Set("etag", etag) - })) - - mux.HandleFunc("/lastmodified", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - lm := "Fri, 14 Dec 2010 01:01:50 GMT" - if 
r.Header.Get("if-modified-since") == lm { - w.WriteHeader(http.StatusNotModified) - return - } - w.Header().Set("last-modified", lm) - })) - - mux.HandleFunc("/varyaccept", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - w.Header().Set("Content-Type", "text/plain") - w.Header().Set("Vary", "Accept") - _, _ = w.Write([]byte("Some text content")) - })) - - mux.HandleFunc("/doublevary", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - w.Header().Set("Content-Type", "text/plain") - w.Header().Set("Vary", "Accept, Accept-Language") - _, _ = w.Write([]byte("Some text content")) - })) - mux.HandleFunc("/2varyheaders", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - w.Header().Set("Content-Type", "text/plain") - w.Header().Add("Vary", "Accept") - w.Header().Add("Vary", "Accept-Language") - _, _ = w.Write([]byte("Some text content")) - })) - mux.HandleFunc("/varyunused", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Cache-Control", "max-age=3600") - w.Header().Set("Content-Type", "text/plain") - w.Header().Set("Vary", "X-Madeup-Header") - _, _ = w.Write([]byte("Some text content")) - })) - - mux.HandleFunc("/cachederror", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - etag := "abc" - if r.Header.Get("if-none-match") == etag { - w.WriteHeader(http.StatusNotModified) - return - } - w.Header().Set("etag", etag) - w.WriteHeader(http.StatusNotFound) - _, _ = w.Write([]byte("Not found")) - })) - - updateFieldsCounter := 0 - mux.HandleFunc("/updatefields", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-Counter", strconv.Itoa(updateFieldsCounter)) - w.Header().Set("Etag", `"e"`) - updateFieldsCounter++ - if r.Header.Get("if-none-match") != "" { - w.WriteHeader(http.StatusNotModified) - return 
- } - _, _ = w.Write([]byte("Some text content")) - })) - - // Take 3 seconds to return 200 OK (for testing client timeouts). - mux.HandleFunc("/3seconds", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(3 * time.Second) - })) - - mux.HandleFunc("/infinite", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - for { - select { - case <-s.done: - return - default: - _, _ = w.Write([]byte{0}) - } - } - })) -} - -func teardown() { - close(s.done) - s.server.Close() -} - -func resetTest(t testing.TB) { - s.transport.RespCache = NewRespCache(t.TempDir()) - //s.transport.RespCache.Delete() - clock = &realClock{} -} - -// TestCacheableMethod ensures that uncacheable method does not get stored -// in cache and get incorrectly used for a following cacheable method request. -func TestCacheableMethod(t *testing.T) { - resetTest(t) - { - req, err := http.NewRequest("POST", s.server.URL+"/method", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), "POST"; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) - } - } - { - req, err := http.NewRequest("GET", s.server.URL+"/method", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), "GET"; got != want { - t.Errorf("got wrong body %q, want %q", got, want) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) - } - if 
resp.Header.Get(XFromCache) != "" { - t.Errorf("XFromCache header isn't blank") - } - } -} - -func TestDontServeHeadResponseToGetRequest(t *testing.T) { - resetTest(t) - url := s.server.URL + "/" - req, err := http.NewRequest(http.MethodHead, url, nil) - if err != nil { - t.Fatal(err) - } - _, err = s.client.Do(req) - if err != nil { - t.Fatal(err) - } - req, err = http.NewRequest(http.MethodGet, url, nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - if resp.Header.Get(XFromCache) != "" { - t.Errorf("Cache should not match") - } -} - -func TestDontStorePartialRangeInCache(t *testing.T) { - resetTest(t) - { - req, err := http.NewRequest("GET", s.server.URL+"/range", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("range", "bytes=4-9") - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), " text "; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusPartialContent { - t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) - } - } - { - req, err := http.NewRequest("GET", s.server.URL+"/range", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), "Some text content"; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) - } - if resp.Header.Get(XFromCache) != "" { - t.Error("XFromCache header isn't blank") - } - } - { - req, err := http.NewRequest("GET", 
s.server.URL+"/range", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), "Some text content"; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusOK { - t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) - } - if resp.Header.Get(XFromCache) != "1" { - t.Errorf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } - { - req, err := http.NewRequest("GET", s.server.URL+"/range", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("range", "bytes=4-9") - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - var buf bytes.Buffer - _, err = io.Copy(&buf, resp.Body) - if err != nil { - t.Fatal(err) - } - err = resp.Body.Close() - if err != nil { - t.Fatal(err) - } - if got, want := buf.String(), " text "; got != want { - t.Errorf("got %q, want %q", got, want) - } - if resp.StatusCode != http.StatusPartialContent { - t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) - } - } -} - -func TestCacheOnlyIfBodyRead(t *testing.T) { - resetTest(t) - { - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - // We do not read the body - resp.Body.Close() - } - { - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatalf("XFromCache header isn't blank") - } - } -} - -func TestOnlyReadBodyOnDemand(t *testing.T) { - resetTest(t) - - req, err := 
http.NewRequest("GET", s.server.URL+"/infinite", nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) // This shouldn't hang forever. - if err != nil { - t.Fatal(err) - } - buf := make([]byte, 10) // Only partially read the body. - _, err = resp.Body.Read(buf) - if err != nil { - t.Fatal(err) - } - resp.Body.Close() -} - -func TestGetOnlyIfCachedHit(t *testing.T) { - resetTest(t) - { - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("cache-control", "only-if-cached") - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - if resp.StatusCode != http.StatusOK { - t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) - } - } -} - -func TestGetOnlyIfCachedMiss(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("cache-control", "only-if-cached") - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - if resp.StatusCode != http.StatusGatewayTimeout { - t.Fatalf("response status code isn't 504 GatewayTimeout: %v", resp.StatusCode) - } -} - -func TestGetNoStoreRequest(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Add("Cache-Control", "no-store") - { - resp, err := 
s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } -} - -func TestGetNoStoreResponse(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL+"/nostore", nil) - if err != nil { - t.Fatal(err) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } -} - -func TestGetWithEtag(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL+"/etag", nil) - if err != nil { - t.Fatal(err) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - // additional assertions to verify that 304 response is converted properly - if resp.StatusCode != http.StatusOK { - t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) - } - if _, ok := resp.Header["Connection"]; ok { - t.Fatalf("Connection header isn't absent") - } - } -} - -func TestGetWithLastModified(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL+"/lastmodified", nil) - if err != 
nil { - t.Fatal(err) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } -} - -func TestGetWithVary(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL+"/varyaccept", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Accept", "text/plain") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get("Vary") != "Accept" { - t.Fatalf(`Vary header isn't "Accept": %v`, resp.Header.Get("Vary")) - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } - req.Header.Set("Accept", "text/html") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - req.Header.Set("Accept", "") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } -} - -func TestGetWithDoubleVary(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL+"/doublevary", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Accept", "text/plain") - req.Header.Set("Accept-Language", "da, en-gb;q=0.8, en;q=0.7") - { - resp, err := 
s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get("Vary") == "" { - t.Fatalf(`Vary header is blank`) - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } - req.Header.Set("Accept-Language", "") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - req.Header.Set("Accept-Language", "da") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } -} - -func TestGetWith2VaryHeaders(t *testing.T) { - resetTest(t) - // Tests that multiple Vary headers' comma-separated lists are - // merged. See https://github.com/gregjones/httpcache/issues/27. 
- const ( - accept = "text/plain" - acceptLanguage = "da, en-gb;q=0.8, en;q=0.7" - ) - req, err := http.NewRequest("GET", s.server.URL+"/2varyheaders", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Accept", accept) - req.Header.Set("Accept-Language", acceptLanguage) - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get("Vary") == "" { - t.Fatalf(`Vary header is blank`) - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } - req.Header.Set("Accept-Language", "") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - req.Header.Set("Accept-Language", "da") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - req.Header.Set("Accept-Language", acceptLanguage) - req.Header.Set("Accept", "") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - } - req.Header.Set("Accept", "image/png") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "" { - t.Fatal("XFromCache header isn't blank") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, 
resp.Header.Get(XFromCache)) - } - } -} - -func TestGetVaryUnused(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL+"/varyunused", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Accept", "text/plain") - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get("Vary") == "" { - t.Fatalf(`Vary header is blank`) - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - } -} - -func TestUpdateFields(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL+"/updatefields", nil) - if err != nil { - t.Fatal(err) - } - var counter, counter2 string - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - counter = resp.Header.Get("x-counter") - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.Header.Get(XFromCache) != "1" { - t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) - } - counter2 = resp.Header.Get("x-counter") - } - if counter == counter2 { - t.Fatalf(`both "x-counter" values are equal: %v %v`, counter, counter2) - } -} - -// This tests the fix for https://github.com/gregjones/httpcache/issues/74. -// Previously, after validating a cached response, its StatusCode -// was incorrectly being replaced. 
-func TestCachedErrorsKeepStatus(t *testing.T) { - resetTest(t) - req, err := http.NewRequest("GET", s.server.URL+"/cachederror", nil) - if err != nil { - t.Fatal(err) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - _, _ = io.Copy(ioutil.Discard, resp.Body) - } - { - resp, err := s.client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusNotFound { - t.Fatalf("Status code isn't 404: %d", resp.StatusCode) - } - } -} - -func TestParseCacheControl(t *testing.T) { - resetTest(t) - h := http.Header{} - for range parseCacheControl(h) { - t.Fatal("cacheControl should be empty") - } - - h.Set("cache-control", "no-cache") - { - cc := parseCacheControl(h) - if _, ok := cc["foo"]; ok { - t.Error(`Value "foo" shouldn't exist`) - } - noCache, ok := cc["no-cache"] - if !ok { - t.Fatalf(`"no-cache" value isn't set`) - } - if noCache != "" { - t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) - } - } - h.Set("cache-control", "no-cache, max-age=3600") - { - cc := parseCacheControl(h) - noCache, ok := cc["no-cache"] - if !ok { - t.Fatalf(`"no-cache" value isn't set`) - } - if noCache != "" { - t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) - } - if cc["max-age"] != "3600" { - t.Fatalf(`"max-age" value isn't "3600": %v`, cc["max-age"]) - } - } -} - -func TestNoCacheRequestExpiration(t *testing.T) { - resetTest(t) - respHeaders := http.Header{} - respHeaders.Set("Cache-Control", "max-age=7200") - - reqHeaders := http.Header{} - reqHeaders.Set("Cache-Control", "no-cache") - if getFreshness(respHeaders, reqHeaders) != transparent { - t.Fatal("freshness isn't transparent") - } -} - -func TestNoCacheResponseExpiration(t *testing.T) { - resetTest(t) - respHeaders := http.Header{} - respHeaders.Set("Cache-Control", "no-cache") - respHeaders.Set("Expires", "Wed, 19 Apr 3000 11:43:00 GMT") - - reqHeaders := http.Header{} - if getFreshness(respHeaders, reqHeaders) != 
stale { - t.Fatal("freshness isn't stale") - } -} - -func TestReqMustRevalidate(t *testing.T) { - resetTest(t) - // not paying attention to request setting max-stale means never returning stale - // responses, so always acting as if must-revalidate is set - respHeaders := http.Header{} - - reqHeaders := http.Header{} - reqHeaders.Set("Cache-Control", "must-revalidate") - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestRespMustRevalidate(t *testing.T) { - resetTest(t) - respHeaders := http.Header{} - respHeaders.Set("Cache-Control", "must-revalidate") - - reqHeaders := http.Header{} - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestFreshExpiration(t *testing.T) { - resetTest(t) - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) - - reqHeaders := http.Header{} - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - clock = &fakeClock{elapsed: 3 * time.Second} - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestMaxAge(t *testing.T) { - resetTest(t) - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("cache-control", "max-age=2") - - reqHeaders := http.Header{} - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - clock = &fakeClock{elapsed: 3 * time.Second} - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestMaxAgeZero(t *testing.T) { - resetTest(t) - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("cache-control", "max-age=0") - - reqHeaders := http.Header{} - if 
getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestBothMaxAge(t *testing.T) { - resetTest(t) - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("cache-control", "max-age=2") - - reqHeaders := http.Header{} - reqHeaders.Set("cache-control", "max-age=0") - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestMinFreshWithExpires(t *testing.T) { - resetTest(t) - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) - - reqHeaders := http.Header{} - reqHeaders.Set("cache-control", "min-fresh=1") - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - reqHeaders = http.Header{} - reqHeaders.Set("cache-control", "min-fresh=2") - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func TestEmptyMaxStale(t *testing.T) { - resetTest(t) - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("cache-control", "max-age=20") - - reqHeaders := http.Header{} - reqHeaders.Set("cache-control", "max-stale") - clock = &fakeClock{elapsed: 10 * time.Second} - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - clock = &fakeClock{elapsed: 60 * time.Second} - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } -} - -func TestMaxStaleValue(t *testing.T) { - resetTest(t) - now := time.Now() - respHeaders := http.Header{} - respHeaders.Set("date", now.Format(time.RFC1123)) - respHeaders.Set("cache-control", "max-age=10") - - reqHeaders := http.Header{} - reqHeaders.Set("cache-control", "max-stale=20") - clock = &fakeClock{elapsed: 5 * time.Second} - if 
getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - clock = &fakeClock{elapsed: 15 * time.Second} - if getFreshness(respHeaders, reqHeaders) != fresh { - t.Fatal("freshness isn't fresh") - } - - clock = &fakeClock{elapsed: 30 * time.Second} - if getFreshness(respHeaders, reqHeaders) != stale { - t.Fatal("freshness isn't stale") - } -} - -func containsHeader(headers []string, header string) bool { - for _, v := range headers { - if http.CanonicalHeaderKey(v) == http.CanonicalHeaderKey(header) { - return true - } - } - return false -} - -func TestGetEndToEndHeaders(t *testing.T) { - resetTest(t) - var ( - headers http.Header - end2end []string - ) - - headers = http.Header{} - headers.Set("content-type", "text/html") - headers.Set("te", "deflate") - - end2end = getEndToEndHeaders(headers) - if !containsHeader(end2end, "content-type") { - t.Fatal(`doesn't contain "content-type" header`) - } - if containsHeader(end2end, "te") { - t.Fatal(`doesn't contain "te" header`) - } - - headers = http.Header{} - headers.Set("connection", "content-type") - headers.Set("content-type", "text/csv") - headers.Set("te", "deflate") - end2end = getEndToEndHeaders(headers) - if containsHeader(end2end, "connection") { - t.Fatal(`doesn't contain "connection" header`) - } - if containsHeader(end2end, "content-type") { - t.Fatal(`doesn't contain "content-type" header`) - } - if containsHeader(end2end, "te") { - t.Fatal(`doesn't contain "te" header`) - } - - headers = http.Header{} - end2end = getEndToEndHeaders(headers) - if len(end2end) != 0 { - t.Fatal(`non-zero end2end headers`) - } - - headers = http.Header{} - headers.Set("connection", "content-type") - end2end = getEndToEndHeaders(headers) - if len(end2end) != 0 { - t.Fatal(`non-zero end2end headers`) - } -} - -type transportMock struct { - response *http.Response - err error -} - -func (t transportMock) RoundTrip(req *http.Request) (resp *http.Response, err error) { - return t.response, t.err -} 
- -func TestStaleIfErrorRequest(t *testing.T) { - resetTest(t) - now := time.Now() - tmock := transportMock{ - response: &http.Response{ - Status: http.StatusText(http.StatusOK), - StatusCode: http.StatusOK, - Header: http.Header{ - "Date": []string{now.Format(time.RFC1123)}, - "Cache-Control": []string{"no-cache"}, - }, - Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), - }, - err: nil, - } - tp := newTestTransport(t.TempDir()) - tp.Transport = &tmock - - // First time, response is cached on success - r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) - r.Header.Set("Cache-Control", "stale-if-error") - resp, err := tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - // On failure, response is returned from the cache - tmock.response = nil - tmock.err = errors.New("some error") - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } -} - -func TestStaleIfErrorRequestLifetime(t *testing.T) { - resetTest(t) - now := time.Now() - tmock := transportMock{ - response: &http.Response{ - Status: http.StatusText(http.StatusOK), - StatusCode: http.StatusOK, - Header: http.Header{ - "Date": []string{now.Format(time.RFC1123)}, - "Cache-Control": []string{"no-cache"}, - }, - Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), - }, - err: nil, - } - tp := newTestTransport(t.TempDir()) - tp.Transport = &tmock - - // First time, response is cached on success - r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) - r.Header.Set("Cache-Control", "stale-if-error=100") - resp, err := tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - // On failure, response is returned from the cache - tmock.response = nil - tmock.err = errors.New("some 
error") - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - - // Same for http errors - tmock.response = &http.Response{StatusCode: http.StatusInternalServerError} - tmock.err = nil - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - - // If failure last more than max stale, error is returned - clock = &fakeClock{elapsed: 200 * time.Second} - _, err = tp.RoundTrip(r) - if err != tmock.err { - t.Fatalf("got err %v, want %v", err, tmock.err) - } -} - -func TestStaleIfErrorResponse(t *testing.T) { - resetTest(t) - now := time.Now() - tmock := transportMock{ - response: &http.Response{ - Status: http.StatusText(http.StatusOK), - StatusCode: http.StatusOK, - Header: http.Header{ - "Date": []string{now.Format(time.RFC1123)}, - "Cache-Control": []string{"no-cache, stale-if-error"}, - }, - Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), - }, - err: nil, - } - tp := newTestTransport(t.TempDir()) - tp.Transport = &tmock - - // First time, response is cached on success - r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) - resp, err := tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - // On failure, response is returned from the cache - tmock.response = nil - tmock.err = errors.New("some error") - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } -} - -func TestStaleIfErrorResponseLifetime(t *testing.T) { - resetTest(t) - now := time.Now() - tmock := transportMock{ - response: &http.Response{ - Status: http.StatusText(http.StatusOK), - StatusCode: http.StatusOK, - Header: http.Header{ - "Date": []string{now.Format(time.RFC1123)}, - "Cache-Control": []string{"no-cache, stale-if-error=100"}, - }, - Body: 
ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), - }, - err: nil, - } - tp := newTestTransport(t.TempDir()) - tp.Transport = &tmock - - // First time, response is cached on success - r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) - resp, err := tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - // On failure, response is returned from the cache - tmock.response = nil - tmock.err = errors.New("some error") - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - - // If failure last more than max stale, error is returned - clock = &fakeClock{elapsed: 200 * time.Second} - _, err = tp.RoundTrip(r) - if err != tmock.err { - t.Fatalf("got err %v, want %v", err, tmock.err) - } -} - -// This tests the fix for https://github.com/gregjones/httpcache/issues/74. -// Previously, after a stale response was used after encountering an error, -// its StatusCode was being incorrectly replaced. 
-func TestStaleIfErrorKeepsStatus(t *testing.T) { - resetTest(t) - now := time.Now() - tmock := transportMock{ - response: &http.Response{ - Status: http.StatusText(http.StatusNotFound), - StatusCode: http.StatusNotFound, - Header: http.Header{ - "Date": []string{now.Format(time.RFC1123)}, - "Cache-Control": []string{"no-cache"}, - }, - Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), - }, - err: nil, - } - tp := newTestTransport(t.TempDir()) - tp.Transport = &tmock - - // First time, response is cached on success - r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) - r.Header.Set("Cache-Control", "stale-if-error") - resp, err := tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - _, err = ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - // On failure, response is returned from the cache - tmock.response = nil - tmock.err = errors.New("some error") - resp, err = tp.RoundTrip(r) - if err != nil { - t.Fatal(err) - } - if resp == nil { - t.Fatal("resp is nil") - } - if resp.StatusCode != http.StatusNotFound { - t.Fatalf("Status wasn't 404: %d", resp.StatusCode) - } -} - -// Test that http.Client.Timeout is respected when cache transport is used. -// That is so as long as request cancellation is propagated correctly. -// In the past, that required CancelRequest to be implemented correctly, -// but modern http.Client uses Request.Cancel (or request context) instead, -// so we don't have to do anything. -func TestClientTimeout(t *testing.T) { - if testing.Short() { - t.Skip("skipping timeout test in short mode") // Because it takes at least 3 seconds to run. 
- } - resetTest(t) - - client := &http.Client{ - Transport: newTestTransport(t.TempDir()), - Timeout: time.Second, - } - started := time.Now() - resp, err := client.Get(s.server.URL + "/3seconds") - taken := time.Since(started) - if err == nil { - t.Error("got nil error, want timeout error") - } - if resp != nil { - t.Error("got non-nil resp, want nil resp") - } - if taken >= 2*time.Second { - t.Error("client.Do took 2+ seconds, want < 2 seconds") - } -} diff --git a/libsq/core/ioz/httpcacheworking/httpz.go b/libsq/core/ioz/httpcacheworking/httpz.go deleted file mode 100644 index 526d2565c..000000000 --- a/libsq/core/ioz/httpcacheworking/httpz.go +++ /dev/null @@ -1,88 +0,0 @@ -package httpcacheworking - -import ( - "bufio" - "fmt" - "io" - "net/http" - "net/textproto" - "strconv" - "strings" -) - -// ReadResponse is a copy of http.ReadResponse, but with the option -// to read only the response header, and not the body. When only reading -// the header, note that resp.Body will be nil, and that the resp is -// generally not functional. -func ReadResponse(r *bufio.Reader, req *http.Request, headerOnly bool) (*http.Response, error) { - if !headerOnly { - return http.ReadResponse(r, req) - } - tp := textproto.NewReader(r) - resp := &http.Response{ - Request: req, - } - - // Parse the first line of the response. 
- line, err := tp.ReadLine() - if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return nil, err - } - proto, status, ok := strings.Cut(line, " ") - if !ok { - return nil, badStringError("malformed HTTP response", line) - } - resp.Proto = proto - resp.Status = strings.TrimLeft(status, " ") - - statusCode, _, _ := strings.Cut(resp.Status, " ") - if len(statusCode) != 3 { - return nil, badStringError("malformed HTTP status code", statusCode) - } - resp.StatusCode, err = strconv.Atoi(statusCode) - if err != nil || resp.StatusCode < 0 { - return nil, badStringError("malformed HTTP status code", statusCode) - } - if resp.ProtoMajor, resp.ProtoMinor, ok = http.ParseHTTPVersion(resp.Proto); !ok { - return nil, badStringError("malformed HTTP version", resp.Proto) - } - - // Parse the response headers. - mimeHeader, err := tp.ReadMIMEHeader() - if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return nil, err - } - resp.Header = http.Header(mimeHeader) - - fixPragmaCacheControl(resp.Header) - - //err = readTransfer(resp, r) - //if err != nil { - // return nil, err - //} - - return resp, nil -} - -// RFC 7234, section 5.4: Should treat -// -// Pragma: no-cache -// -// like -// -// Cache-Control: no-cache -func fixPragmaCacheControl(header http.Header) { - if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" { - if _, presentcc := header["Cache-Control"]; !presentcc { - header["Cache-Control"] = []string{"no-cache"} - } - } -} - -func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } diff --git a/libsq/core/ioz/httpcacheworking/respcache.go b/libsq/core/ioz/httpcacheworking/respcache.go deleted file mode 100644 index e4c5c830a..000000000 --- a/libsq/core/ioz/httpcacheworking/respcache.go +++ /dev/null @@ -1,210 +0,0 @@ -package httpcacheworking - -import ( - "bufio" - "bytes" - "context" - "github.com/neilotoole/sq/libsq/core/cleanup" - "github.com/neilotoole/sq/libsq/core/errz" 
- "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/neilotoole/sq/libsq/core/ioz/contextio" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lga" - "io" - "net/http" - "net/http/httputil" - "os" - "path/filepath" - "sync" -) - -// NewRespCache returns a new instance that stores responses in cacheDir. -// The caller should call RespCache.Close when finished with the cache. -func NewRespCache(cacheDir string) *RespCache { - c := &RespCache{ - Dir: cacheDir, - //Header: filepath.Join(cacheDir, "header"), - //Body: filepath.Join(cacheDir, "body"), - clnup: cleanup.New(), - } - return c -} - -// RespCache is a cache for a single http.Response. The response is -// stored in two files, one for the header and one for the body. -// The caller should call RespCache.Close when finished with the cache. -type RespCache struct { - mu sync.Mutex - clnup *cleanup.Cleanup - - Dir string -} - -// Paths returns the paths to the header and body files for req. -// It is not guaranteed that they exist. -func (rc *RespCache) Paths(req *http.Request) (header, body string) { - if req == nil || req.Method == http.MethodGet { - return filepath.Join(rc.Dir, "header"), filepath.Join(rc.Dir, "body") - } - - return filepath.Join(rc.Dir, req.Method+"_header"), - filepath.Join(rc.Dir, req.Method+"_body") -} - -// Exists returns true if the cache contains a response for req. -func (rc *RespCache) Exists(req *http.Request) bool { - rc.mu.Lock() - defer rc.mu.Unlock() - - fpHeader, _ := rc.Paths(req) - fi, err := os.Stat(fpHeader) - if err != nil { - return false - } - return fi.Size() > 0 -} - -// Get returns the cached http.Response for req if present, and nil -// otherwise. 
-func (rc *RespCache) Get(ctx context.Context, req *http.Request) (*http.Response, error) { - rc.mu.Lock() - defer rc.mu.Unlock() - - fpHeader, fpBody := rc.Paths(req) - - if !ioz.FileAccessible(fpHeader) { - return nil, nil - } - - headerBytes, err := os.ReadFile(fpHeader) - if err != nil { - return nil, err - } - - bodyFile, err := os.Open(fpBody) - if err != nil { - lg.FromContext(ctx).Error("failed to open cached response body", - lga.File, fpBody, lga.Err, err) - return nil, err - } - - // We need to explicitly close bodyFile at some later point. It won't be - // closed via a call to http.Response.Body.Close(). - rc.clnup.AddC(bodyFile) - // TODO: consider adding contextio.NewReader? - concatRdr := io.MultiReader(bytes.NewReader(headerBytes), bodyFile) - return http.ReadResponse(bufio.NewReader(concatRdr), req) -} - -// Close closes the cache, freeing any resources it holds. Note that -// it does not delete the cache: for that, see RespCache.Delete. -func (rc *RespCache) Close() error { - rc.mu.Lock() - defer rc.mu.Unlock() - - err := rc.clnup.Run() - rc.clnup = cleanup.New() - return err -} - -// Delete deletes the cache entries from disk. -func (rc *RespCache) Delete(ctx context.Context) error { - if rc == nil { - return nil - } - rc.mu.Lock() - defer rc.mu.Unlock() - - return rc.doDelete(ctx) -} - -func (rc *RespCache) doDelete(ctx context.Context) error { - cleanErr := rc.clnup.Run() - rc.clnup = cleanup.New() - deleteErr := errz.Wrap(os.RemoveAll(rc.Dir), "delete cache dir") - err := errz.Combine(cleanErr, deleteErr) - if err != nil { - lg.FromContext(ctx).Error("Delete cache dir", - lga.Dir, rc.Dir, lga.Err, err) - return err - } - - lg.FromContext(ctx).Info("Deleted cache dir", lga.Dir, rc.Dir) - return nil -} - -const msgDeleteCache = "Delete HTTP response cache" - -// Write writes resp to the cache. If copyWrtr is non-nil, the response -// bytes are copied to that destination also. 
-func (rc *RespCache) Write(ctx context.Context, resp *http.Response, copyWrtr io.WriteCloser) error { - rc.mu.Lock() - defer rc.mu.Unlock() - - log := lg.FromContext(ctx) - log.Debug("huzzah in write") - - err := rc.doWrite(ctx, resp, copyWrtr) - if err != nil { - lg.WarnIfError(lg.FromContext(ctx), msgDeleteCache, rc.doDelete(ctx)) - } - return err -} - -func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response, copyWrtr io.WriteCloser) error { - log := lg.FromContext(ctx) - - if err := ioz.RequireDir(rc.Dir); err != nil { - return err - } - - fpHeader, fpBody := rc.Paths(resp.Request) - - headerBytes, err := httputil.DumpResponse(resp, false) - if err != nil { - return err - } - - if _, err = ioz.WriteToFile(ctx, fpHeader, bytes.NewReader(headerBytes)); err != nil { - return err - } - - cacheFile, err := os.OpenFile(fpBody, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) - if err != nil { - return err - } - - var cr io.Reader - if copyWrtr == nil { - cr = contextio.NewReader(ctx, resp.Body) - } else { - tr := io.TeeReader(resp.Body, copyWrtr) - cr = contextio.NewReader(ctx, tr) - } - - //if copyWrtr != nil { - // cr = io.TeeReader(cr, copyWrtr) - //} - var written int64 - written, err = io.Copy(cacheFile, cr) - if err != nil { - lg.WarnIfCloseError(log, "Close cache body file", cacheFile) - return err - } - if copyWrtr != nil { - lg.WarnIfCloseError(log, "Close copy writer", copyWrtr) - } - - if err = cacheFile.Close(); err != nil { - return err - } - - log.Info("Wrote HTTP response to cache", lga.File, fpBody, lga.Size, written) - cacheFile, err = os.Open(fpBody) - if err != nil { - return err - } - - resp.Body = cacheFile - return nil -} diff --git a/libsq/source/dl.go b/libsq/source/dl.go index 9eea6bcda..200c4df85 100644 --- a/libsq/source/dl.go +++ b/libsq/source/dl.go @@ -65,8 +65,7 @@ func (d *downloader2) ClearCache(ctx context.Context) error { d.mu.Lock() defer d.mu.Unlock() - if err := d.tp.RespCache.Delete(ctx); err != nil { - 
//log.Error("Failed to delete cache dir", lga.Dir, d.cacheDir, lga.Err, err) + if err := d.tp.Delete(ctx); err != nil { return errz.Wrapf(err, "failed to clear cache dir: %s", d.cacheDir) } @@ -136,22 +135,22 @@ func (d *downloader2) Download2(ctx context.Context, dest io.Writer) (written in var gotFp string var gotErr error //buf := &bytes.Buffer{} - cb := httpcache.CallbackHandler{ - HandleCached: func(cachedFilepath string) error { + cb := httpcache.Handler{ + Cached: func(cachedFilepath string) error { gotFp = cachedFilepath return nil }, - HandleUncached: func() (wc io.WriteCloser, errFn func(error), err error) { + Uncached: func() (wc io.WriteCloser, errFn func(error), err error) { return destWrtr, func(err error) { gotErr = err }, nil }, - HandleError: func(err error) { + Error: func(err error) { gotErr = err }, } - d.tp.Fetch(req, cb) + d.tp.FetchWith(req, cb) _ = gotFp _ = gotErr From 2a9fb3f9b5c32e760ea075fe4a3c48ee538a4d01 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 13 Dec 2023 05:41:06 -0700 Subject: [PATCH 100/195] wip: refactoring download --- cli/cmd_version.go | 2 +- cli/complete.go | 2 +- .../ioz/{httpcache => download}/LICENSE.txt | 0 .../ioz/{httpcache => download}/README.md | 0 libsq/core/ioz/download/download.go | 401 ++++++++++++++++++ .../{httpcache => download}/download_test.go | 36 +- .../{httpcache => download}/httpcache_test.go | 24 +- .../core/ioz/{httpcache => download}/httpz.go | 38 +- .../ioz/{httpcache => download}/respcache.go | 17 +- libsq/core/ioz/httpcache/httpcache.go | 329 -------------- libsq/source/dl.go | 28 +- 11 files changed, 468 insertions(+), 409 deletions(-) rename libsq/core/ioz/{httpcache => download}/LICENSE.txt (100%) rename libsq/core/ioz/{httpcache => download}/README.md (100%) create mode 100644 libsq/core/ioz/download/download.go rename libsq/core/ioz/{httpcache => download}/download_test.go (66%) rename libsq/core/ioz/{httpcache => download}/httpcache_test.go (98%) rename libsq/core/ioz/{httpcache => 
download}/httpz.go (93%) rename libsq/core/ioz/{httpcache => download}/respcache.go (92%) delete mode 100644 libsq/core/ioz/httpcache/httpcache.go diff --git a/cli/cmd_version.go b/cli/cmd_version.go index 716f96aaa..e037a1b93 100644 --- a/cli/cmd_version.go +++ b/cli/cmd_version.go @@ -90,7 +90,7 @@ func execVersion(cmd *cobra.Command, _ []string) error { var err error v, err := fetchBrewVersion(ctx) if err != nil { - lg.Error(logFrom(cmd), "Fetch brew version", err) + lg.Error(logFrom(cmd), "Get brew version", err) } // OK if v is empty diff --git a/cli/complete.go b/cli/complete.go index 4db48df42..abd0218cd 100644 --- a/cli/complete.go +++ b/cli/complete.go @@ -675,7 +675,7 @@ func (c *handleTableCompleter) completeHandle(ctx context.Context, ru *run.Run, // This means that we aren't able to get metadata for this source. // This could be because the source is temporarily offline. The // best we can do is just to return the handle, without the tables. - lg.WarnIfError(lg.FromContext(ctx), "Fetch metadata", err) + lg.WarnIfError(lg.FromContext(ctx), "Get metadata", err) return matchingHandles, cobra.ShellCompDirectiveNoFileComp | cobra.ShellCompDirectiveNoSpace } diff --git a/libsq/core/ioz/httpcache/LICENSE.txt b/libsq/core/ioz/download/LICENSE.txt similarity index 100% rename from libsq/core/ioz/httpcache/LICENSE.txt rename to libsq/core/ioz/download/LICENSE.txt diff --git a/libsq/core/ioz/httpcache/README.md b/libsq/core/ioz/download/README.md similarity index 100% rename from libsq/core/ioz/httpcache/README.md rename to libsq/core/ioz/download/README.md diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go new file mode 100644 index 000000000..23274ed88 --- /dev/null +++ b/libsq/core/ioz/download/download.go @@ -0,0 +1,401 @@ +// Package download provides a http.RoundTripper implementation that +// works as a mostly RFC-compliant cache for http responses. 
+// +// FIXME: move download to internal/download, because its use +// is so specialized? +// +// Acknowledgement: This package is a heavily customized fork +// of https://github.com/gregjones/httpcache, via bitcomplete/download. +package download + +import ( + "bufio" + "context" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "io" + "net/http" + "os" +) + +// State is an enumeration of caching states based on the cache-control +// values of the request and the response. +// +// - Uncached indicates the item is not cached. +// - Fresh indicates that the cached item can be returned. +// - Stale indicates that the cached item needs validating before it is returned. +// - Transparent indicates the cached item should not be used to fulfil the request. +// +// Because this is only a private cache, 'public' and 'private' in cache-control aren't +// significant. Similarly, smax-age isn't used. +type State int + +const ( + // Uncached indicates that the item is not cached. + Uncached State = iota + + // Stale indicates that the cached item needs validating before it is returned. + Stale + + // Fresh indicates the cached item can be returned. + Fresh + + // Transparent indicates the cached item should not be used to fulfil the request. + Transparent +) + +// XFromCache is the header added to responses that are returned from the cache +const XFromCache = "X-From-Cache" + +// Opt is a configuration option for creating a new Download. +type Opt func(t *Download) + +// OptMarkCacheResponses configures a Download by setting +// Download.markCachedResponses to true. +func OptMarkCacheResponses(markCachedResponses bool) Opt { + return func(t *Download) { + t.markCachedResponses = markCachedResponses + } +} + +// OptInsecureSkipVerify configures a Download to skip TLS verification. 
+func OptInsecureSkipVerify(insecureSkipVerify bool) Opt { + return func(t *Download) { + t.InsecureSkipVerify = insecureSkipVerify + } +} + +// OptDisableCaching disables the cache. +func OptDisableCaching(disable bool) Opt { + return func(t *Download) { + t.disableCaching = disable + } +} + +// OptUserAgent sets the User-Agent header on requests. +func OptUserAgent(userAgent string) Opt { + return func(t *Download) { + t.userAgent = userAgent + } +} + +// Download is aan implementation of http.RoundTripper that will return values from a cache +// where possible (avoiding a network request) and will additionally add validators (etag/if-modified-since) +// to repeated requests allowing servers to return 304 / Not Modified +type Download struct { + // FIXME: Does Download need a sync.Mutex? + + // FIXME: implement url mechanism + // url is the URL of the download. + url string + + // The RoundTripper interface actually used to make requests + // If nil, http.DefaultTransport is used. + transport http.RoundTripper + + // respCache is the cache used to store responses. + respCache *RespCache + + // markCachedResponses, if true, indicates that responses returned from the + // cache will be given an extra header, X-From-Cache + markCachedResponses bool + + InsecureSkipVerify bool + + userAgent string + + disableCaching bool +} + +// New returns a new Download that uses cacheDir as the cache directory. +func New(cacheDir string, opts ...Opt) *Download { + t := &Download{ + markCachedResponses: true, + disableCaching: false, + InsecureSkipVerify: false, + } + for _, opt := range opts { + opt(t) + } + + if !t.disableCaching { + t.respCache = NewRespCache(cacheDir) + } + return t +} + +// Handler is a callback invoked by Download.Get. Exactly one of the +// handler functions will be invoked, one time. +type Handler struct { + // Cached is invoked when the download is already cached on disk. The + // fp arg is the path to the downloaded file. 
+ Cached func(fp string) + + // Uncached is invoked when the download is not cached. The handler must + // return an io.WriterCloser, which the download contents will be written + // to (as well as being written to the disk cache). On success, the dest + // io.WriteCloser is closed. If an error occurs during download or + // writing, errFn is invoked, and dest is not closed. + Uncached func() (dest io.WriteCloser, errFn func(error)) + + // Error is invoked if an + Error func(err error) +} + +// Get gets the download at url, invoking h as appropriate. +func (dl *Download) Get(ctx context.Context, url string, h Handler) { + req, err := dl.newRequest(ctx, url) + if err != nil { + h.Error(err) + return + } + + dl.get(req, h) +} + +func (dl *Download) get(req *http.Request, cb Handler) { + ctx := req.Context() + log := lg.FromContext(ctx) + log.Info("Fetching download", lga.URL, req.URL.String()) + _, fpBody := dl.respCache.Paths(req) + + if dl.state(req) == Fresh { + cb.Cached(fpBody) + return + } + + var err error + cacheable := dl.isCacheable(req) + var cachedResp *http.Response + if cacheable { + cachedResp, err = dl.respCache.Get(req.Context(), req) + } else { + // Need to invalidate an existing value + if err = dl.respCache.Clear(req.Context()); err != nil { + cb.Error(err) + return + } + } + + transport := dl.transport + if transport == nil { + transport = http.DefaultTransport + } + var resp *http.Response + if cacheable && cachedResp != nil && err == nil { + if dl.markCachedResponses { + cachedResp.Header.Set(XFromCache, "1") + } + + if varyMatches(cachedResp, req) { + // Can only use cached value if the new request doesn't Vary significantly + freshness := getFreshness(cachedResp.Header, req.Header) + if freshness == Fresh { + cb.Cached(fpBody) + return + } + + if freshness == Stale { + var req2 *http.Request + // Add validators if caller hasn't already done so + etag := cachedResp.Header.Get("etag") + if etag != "" && req.Header.Get("etag") == "" { + req2 = 
cloneRequest(req) + req2.Header.Set("if-none-match", etag) + } + lastModified := cachedResp.Header.Get("last-modified") + if lastModified != "" && req.Header.Get("last-modified") == "" { + if req2 == nil { + req2 = cloneRequest(req) + } + req2.Header.Set("if-modified-since", lastModified) + } + if req2 != nil { + req = req2 + } + } + } + + // FIXME: Use an http client here + resp, err = transport.RoundTrip(req) + if err == nil && req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified { + // Replace the 304 response with the one from cache, but update with some new headers + endToEndHeaders := getEndToEndHeaders(resp.Header) + for _, header := range endToEndHeaders { + cachedResp.Header[header] = resp.Header[header] + } + resp = cachedResp + } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) && + req.Method == http.MethodGet && canStaleOnError(cachedResp.Header, req.Header) { + // In case of transport failure and stale-if-error activated, returns cached content + // when available + log.Warn("Returning cached response due to transport failure", lga.Err, err) + cb.Cached(fpBody) + return + } else { + if err != nil || resp.StatusCode != http.StatusOK { + lg.WarnIfError(log, msgDeleteCache, dl.respCache.Clear(req.Context())) + } + if err != nil { + cb.Error(err) + return + } + } + } else { + reqCacheControl := parseCacheControl(req.Header) + if _, ok := reqCacheControl["only-if-cached"]; ok { + resp = newGatewayTimeoutResponse(req) + } else { + resp, err = transport.RoundTrip(req) + if err != nil { + cb.Error(err) + return + } + } + } + + if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) { + for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") { + varyKey = http.CanonicalHeaderKey(varyKey) + fakeHeader := "X-Varied-" + varyKey + reqValue := req.Header.Get(varyKey) + if reqValue != "" { + resp.Header.Set(fakeHeader, reqValue) + } + } + + copyWrtr, errFn := cb.Uncached() + if 
copyWrtr == nil { + log.Warn("nil copy writer from download handler; returning") + return + } + + if err = dl.respCache.Write(req.Context(), resp, copyWrtr); err != nil { + log.Error("failed to write download cache", lga.Err, err) + errFn(err) + cb.Error(err) + } + return + } else { + lg.WarnIfError(log, "Delete resp cache", dl.respCache.Clear(req.Context())) + } + + // It's not cacheable, so we need to write it to the copyWrtr. + copyWrtr, errFn := cb.Uncached() + if copyWrtr == nil { + log.Warn("nil copy writer from download handler; returning") + return + } + + cr := contextio.NewReader(ctx, resp.Body) + _, err = io.Copy(copyWrtr, cr) + if err != nil { + errFn(err) + cb.Error(err) + return + } + if err = copyWrtr.Close(); err != nil { + cb.Error(err) + return + } + + return +} + +// Close frees any resources held by the Download. It does not delete +// the cache from disk. For that, see Download.Clear. +func (dl *Download) Close() error { + if dl.respCache != nil { + return dl.respCache.Close() + } + return nil +} + +func (dl *Download) newRequest(ctx context.Context, url string) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + lg.FromContext(ctx).Error("Failed to create request", lga.URL, url, lga.Err, err) + return nil, err + } + if dl.userAgent != "" { + req.Header.Set("User-Agent", dl.userAgent) + } + return req, nil +} + +func (dl *Download) getClient() *http.Client { + return ioz.NewHTTPClient(dl.InsecureSkipVerify) +} + +// Clear deletes the cache. +func (dl *Download) Clear(ctx context.Context) error { + if dl.respCache != nil { + return dl.respCache.Clear(ctx) + } + return nil +} + +// IsCached returns true if there is a cache entry for url. Success does not +// guarantee that the cache entry is fresh. See also: [Download.IsFresh]. 
+func (dl *Download) IsCached(ctx context.Context, url string) bool { + req, err := dl.newRequest(ctx, url) + if err != nil { + return false + } + return dl.isCached(req) +} + +func (dl *Download) isCached(req *http.Request) bool { + if dl.disableCaching { + return false + } + return dl.respCache.Exists(req) +} + +// State returns the cache state of url. +func (dl *Download) State(ctx context.Context, url string) State { + req, err := dl.newRequest(ctx, url) + if err != nil { + return Uncached + } + return dl.state(req) +} + +func (dl *Download) state(req *http.Request) State { + if !dl.isCacheable(req) { + return Uncached + } + + ctx := req.Context() + log := lg.FromContext(ctx) + + if !dl.respCache.Exists(req) { + return Uncached + } + + fpHeader, _ := dl.respCache.Paths(req) + f, err := os.Open(fpHeader) + if err != nil { + log.Error("Failed to open cached response header file", lga.File, fpHeader, lga.Err, err) + return Uncached + } + + defer lg.WarnIfCloseError(log, "Close cached response header file", f) + + cachedResp, err := readResponseHeader(bufio.NewReader(f), nil) + if err != nil { + log.Error("Failed to read cached response header", lga.Err, err) + return Uncached + } + + return getFreshness(cachedResp.Header, req.Header) +} + +func (dl *Download) isCacheable(req *http.Request) bool { + if dl.disableCaching { + return false + } + return (req.Method == http.MethodGet || req.Method == http.MethodHead) && req.Header.Get("range") == "" +} diff --git a/libsq/core/ioz/httpcache/download_test.go b/libsq/core/ioz/download/download_test.go similarity index 66% rename from libsq/core/ioz/httpcache/download_test.go rename to libsq/core/ioz/download/download_test.go index c8584fe74..f1700df9e 100644 --- a/libsq/core/ioz/httpcache/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -1,11 +1,11 @@ -package httpcache_test +package download_test import ( "bytes" "context" "github.com/neilotoole/slogt" "github.com/neilotoole/sq/libsq/core/ioz" - 
"github.com/neilotoole/sq/libsq/core/ioz/httpcache" + "github.com/neilotoole/sq/libsq/core/ioz/download" "github.com/neilotoole/sq/libsq/core/lg" "github.com/stretchr/testify/require" "io" @@ -21,17 +21,18 @@ const ( sizeGzipActorCSV = int64(1968) ) -func TestTransport_Fetch(t *testing.T) { +func TestDownload(t *testing.T) { log := slogt.New(t) ctx := lg.NewContext(context.Background(), log) const dlURL = urlActorCSV + // FIXME: switch to temp dir cacheDir, err := filepath.Abs(filepath.Join("testdata", "downloader", "cache-dir-2")) require.NoError(t, err) t.Logf("cacheDir: %s", cacheDir) - dl := httpcache.NewTransport(cacheDir, httpcache.OptUserAgent("sq/dev")) - require.NoError(t, dl.Delete(ctx)) + dl := download.New(cacheDir, download.OptUserAgent("sq/dev")) + require.NoError(t, dl.Clear(ctx)) var ( destBuf = &bytes.Buffer{} @@ -44,34 +45,31 @@ func TestTransport_Fetch(t *testing.T) { gotErr = nil } - h := httpcache.Handler{ - Cached: func(cachedFilepath string) error { + h := download.Handler{ + Cached: func(cachedFilepath string) { gotFp = cachedFilepath - return nil }, - Uncached: func() (wc io.WriteCloser, errFn func(error), err error) { + Uncached: func() (wc io.WriteCloser, errFn func(error)) { return ioz.WriteCloser(destBuf), func(err error) { gotErr = err - }, - nil + } }, Error: func(err error) { gotErr = err }, } - //req, err := http.NewRequestWithContext(ctx, http.MethodGet, dlURL, nil) - ////if d.userAgent != "" { - //// req.Header.Set("User-Agent", d.userAgent) - ////} - dl.Fetch(ctx, dlURL, h) + require.Equal(t, download.Uncached, dl.State(ctx, dlURL)) + dl.Get(ctx, dlURL, h) require.NoError(t, gotErr) require.Empty(t, gotFp) require.Equal(t, sizeActorCSV, int64(destBuf.Len())) + require.Equal(t, download.Fresh, dl.State(ctx, dlURL)) + reset() - dl.Fetch(ctx, dlURL, h) + dl.Get(ctx, dlURL, h) require.NoError(t, gotErr) require.Equal(t, 0, destBuf.Len()) require.NotEmpty(t, gotFp) @@ -79,4 +77,8 @@ func TestTransport_Fetch(t *testing.T) { 
require.NoError(t, err) require.Equal(t, sizeActorCSV, int64(len(gotFileBytes))) + require.Equal(t, download.Fresh, dl.State(ctx, dlURL)) + + require.NoError(t, dl.Clear(ctx)) + require.Equal(t, download.Uncached, dl.State(ctx, dlURL)) } diff --git a/libsq/core/ioz/httpcache/httpcache_test.go b/libsq/core/ioz/download/httpcache_test.go similarity index 98% rename from libsq/core/ioz/httpcache/httpcache_test.go rename to libsq/core/ioz/download/httpcache_test.go index da7a93250..ebb081461 100644 --- a/libsq/core/ioz/httpcache/httpcache_test.go +++ b/libsq/core/ioz/download/httpcache_test.go @@ -1,16 +1,16 @@ -package httpcache +package download // -//// newTestTransport returns a new Transport using the in-memory cache implementation -//func newTestTransport(cacheDir string, opts ...Opt) *Transport { -// t := NewTransport(cacheDir, opts...) +//// newTestTransport returns a new Download using the in-memory cache implementation +//func newTestTransport(cacheDir string, opts ...Opt) *Download { +// t := New(cacheDir, opts...) // return t //} // //var s struct { // server *httptest.Server // client http.Client -// transport *Transport +// transport *Download // done chan struct{} // Closed to unlock infinite handlers. 
//} // @@ -32,7 +32,7 @@ package httpcache // //func setup() { // tp := newTestTransport(filepath.Join(os.TempDir(), stringz.Uniq8())) -// client := http.Client{Transport: tp} +// client := http.Client{Download: tp} // s.transport = tp // s.client = client // s.done = make(chan struct{}) @@ -1203,7 +1203,7 @@ package httpcache // err: nil, // } // tp := newTestTransport(t.TempDir()) -// tp.Transport = &tmock +// tp.Download = &tmock // // // First time, response is cached on success // r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) @@ -1248,7 +1248,7 @@ package httpcache // err: nil, // } // tp := newTestTransport(t.TempDir()) -// tp.Transport = &tmock +// tp.Download = &tmock // // // First time, response is cached on success // r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) @@ -1311,7 +1311,7 @@ package httpcache // err: nil, // } // tp := newTestTransport(t.TempDir()) -// tp.Transport = &tmock +// tp.Download = &tmock // // // First time, response is cached on success // r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) @@ -1355,7 +1355,7 @@ package httpcache // err: nil, // } // tp := newTestTransport(t.TempDir()) -// tp.Transport = &tmock +// tp.Download = &tmock // // // First time, response is cached on success // r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) @@ -1409,7 +1409,7 @@ package httpcache // err: nil, // } // tp := newTestTransport(t.TempDir()) -// tp.Transport = &tmock +// tp.Download = &tmock // // // First time, response is cached on success // r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) @@ -1453,7 +1453,7 @@ package httpcache // resetTest(t) // // client := &http.Client{ -// Transport: newTestTransport(t.TempDir()), +// Download: newTestTransport(t.TempDir()), // Timeout: time.Second, // } // started := time.Now() diff --git a/libsq/core/ioz/httpcache/httpz.go b/libsq/core/ioz/download/httpz.go similarity index 93% rename from libsq/core/ioz/httpcache/httpz.go rename to 
libsq/core/ioz/download/httpz.go index c4507e078..84c4134bd 100644 --- a/libsq/core/ioz/httpcache/httpz.go +++ b/libsq/core/ioz/download/httpz.go @@ -1,4 +1,4 @@ -package httpcache +package download import ( "bufio" @@ -105,31 +105,31 @@ type timer interface { var clock timer = &realClock{} -// getFreshness will return one of fresh/stale/transparent based on the cache-control +// getFreshness will return one of Fresh/stale/transparent based on the cache-control // values of the request and the response // -// fresh indicates the response can be returned +// Fresh indicates the response can be returned // stale indicates that the response needs validating before it is returned -// transparent indicates the response should not be used to fulfil the request +// Transparent indicates the response should not be used to fulfil the request // // Because this is only a private cache, 'public' and 'private' in cache-control aren't // significant. Similarly, smax-age isn't used. -func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) { +func getFreshness(respHeaders, reqHeaders http.Header) (freshness State) { respCacheControl := parseCacheControl(respHeaders) reqCacheControl := parseCacheControl(reqHeaders) if _, ok := reqCacheControl["no-cache"]; ok { - return transparent + return Transparent } if _, ok := respCacheControl["no-cache"]; ok { - return stale + return Stale } if _, ok := reqCacheControl["only-if-cached"]; ok { - return fresh + return Fresh } date, err := getDate(respHeaders) if err != nil { - return stale + return Stale } currentAge := clock.since(date) @@ -162,15 +162,15 @@ func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) { lifetime = zeroDuration } } - if minfresh, ok := reqCacheControl["min-fresh"]; ok { + if minFresh, ok := reqCacheControl["min-fresh"]; ok { // the client wants a response that will still be fresh for at least the specified number of seconds. 
- minfreshDuration, err := time.ParseDuration(minfresh + "s") + minFreshDuration, err := time.ParseDuration(minFresh + "s") if err == nil { - currentAge = time.Duration(currentAge + minfreshDuration) + currentAge = time.Duration(currentAge + minFreshDuration) } } - if maxstale, ok := reqCacheControl["max-stale"]; ok { + if maxStale, ok := reqCacheControl["max-stale"]; ok { // Indicates that the client is willing to accept a response that has exceeded its expiration time. // If max-stale is assigned a value, then the client is willing to accept a response that has exceeded // its expiration time by no more than the specified number of seconds. @@ -179,20 +179,20 @@ func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) { // Responses served only because of a max-stale value are supposed to have a Warning header added to them, // but that seems like a hassle, and is it actually useful? If so, then there needs to be a different // return-value available here. - if maxstale == "" { - return fresh + if maxStale == "" { + return Fresh } - maxstaleDuration, err := time.ParseDuration(maxstale + "s") + maxStaleDuration, err := time.ParseDuration(maxStale + "s") if err == nil { - currentAge = time.Duration(currentAge - maxstaleDuration) + currentAge = time.Duration(currentAge - maxStaleDuration) } } if lifetime > currentAge { - return fresh + return Fresh } - return stale + return Stale } // Returns true if either the request or the response includes the stale-if-error diff --git a/libsq/core/ioz/httpcache/respcache.go b/libsq/core/ioz/download/respcache.go similarity index 92% rename from libsq/core/ioz/httpcache/respcache.go rename to libsq/core/ioz/download/respcache.go index 6b964c2ed..2aa88eab1 100644 --- a/libsq/core/ioz/httpcache/respcache.go +++ b/libsq/core/ioz/download/respcache.go @@ -1,4 +1,4 @@ -package httpcache +package download import ( "bufio" @@ -107,24 +107,24 @@ func (rc *RespCache) Close() error { return err } -// Delete deletes the cache 
entries from disk. -func (rc *RespCache) Delete(ctx context.Context) error { +// Clear deletes the cache entries from disk. +func (rc *RespCache) Clear(ctx context.Context) error { if rc == nil { return nil } rc.mu.Lock() defer rc.mu.Unlock() - return rc.doDelete(ctx) + return rc.doClear(ctx) } -func (rc *RespCache) doDelete(ctx context.Context) error { +func (rc *RespCache) doClear(ctx context.Context) error { cleanErr := rc.clnup.Run() rc.clnup = cleanup.New() deleteErr := errz.Wrap(os.RemoveAll(rc.Dir), "delete cache dir") err := errz.Combine(cleanErr, deleteErr) if err != nil { - lg.FromContext(ctx).Error("Delete cache dir", + lg.FromContext(ctx).Error(msgDeleteCache, lga.Dir, rc.Dir, lga.Err, err) return err } @@ -141,12 +141,9 @@ func (rc *RespCache) Write(ctx context.Context, resp *http.Response, copyWrtr io rc.mu.Lock() defer rc.mu.Unlock() - log := lg.FromContext(ctx) - log.Debug("huzzah in write") - err := rc.doWrite(ctx, resp, copyWrtr) if err != nil { - lg.WarnIfError(lg.FromContext(ctx), msgDeleteCache, rc.doDelete(ctx)) + lg.WarnIfError(lg.FromContext(ctx), msgDeleteCache, rc.doClear(ctx)) } return err } diff --git a/libsq/core/ioz/httpcache/httpcache.go b/libsq/core/ioz/httpcache/httpcache.go deleted file mode 100644 index 432d8464b..000000000 --- a/libsq/core/ioz/httpcache/httpcache.go +++ /dev/null @@ -1,329 +0,0 @@ -// Package httpcache provides a http.RoundTripper implementation that -// works as a mostly RFC-compliant cache for http responses. -// -// FIXME: move httpcache to internal/httpcache, because its use -// is so specialized? -// -// Acknowledgement: This package is a heavily customized fork -// of https://github.com/gregjones/httpcache, via bitcomplete/httpcache. 
-package httpcache - -import ( - "bufio" - "context" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/neilotoole/sq/libsq/core/ioz/contextio" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lga" - "io" - "net/http" - "os" -) - -const ( - stale = iota - fresh - transparent - // XFromCache is the header added to responses that are returned from the cache - XFromCache = "X-From-Cache" -) - -// Opt is a configuration option for creating a new Transport. -type Opt func(t *Transport) - -// OptMarkCacheResponses configures a Transport by setting -// Transport.markCachedResponses to true. -func OptMarkCacheResponses(markCachedResponses bool) Opt { - return func(t *Transport) { - t.markCachedResponses = markCachedResponses - } -} - -// OptInsecureSkipVerify configures a Transport to skip TLS verification. -func OptInsecureSkipVerify(insecureSkipVerify bool) Opt { - return func(t *Transport) { - t.InsecureSkipVerify = insecureSkipVerify - } -} - -// OptDisableCaching disables the cache. -func OptDisableCaching(disable bool) Opt { - return func(t *Transport) { - t.disableCaching = disable - } -} - -// OptUserAgent sets the User-Agent header on requests. -func OptUserAgent(userAgent string) Opt { - return func(t *Transport) { - t.userAgent = userAgent - } -} - -// Transport is an implementation of http.RoundTripper that will return values from a cache -// where possible (avoiding a network request) and will additionally add validators (etag/if-modified-since) -// to repeated requests allowing servers to return 304 / Not Modified -type Transport struct { - // The RoundTripper interface actually used to make requests - // If nil, http.DefaultTransport is used. - transport http.RoundTripper - - // respCache is the cache used to store responses. 
- respCache *RespCache - - // markCachedResponses, if true, indicates that responses returned from the - // cache will be given an extra header, X-From-Cache - markCachedResponses bool - - InsecureSkipVerify bool - - userAgent string - - disableCaching bool -} - -// NewTransport returns a new Transport with the provided Cache and options. If -// KeyFunc is not specified in opts then DefaultKeyFunc is used. -func NewTransport(cacheDir string, opts ...Opt) *Transport { - t := &Transport{ - markCachedResponses: true, - disableCaching: false, - InsecureSkipVerify: false, - } - for _, opt := range opts { - opt(t) - } - - if !t.disableCaching { - t.respCache = NewRespCache(cacheDir) - } - return t -} - -type Handler struct { - Cached func(cachedFilepath string) error - Uncached func() (wc io.WriteCloser, errFn func(error), err error) - Error func(err error) -} - -func (t *Transport) Fetch(ctx context.Context, dlURL string, h Handler) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, dlURL, nil) - if err != nil { - h.Error(err) - return - } - if t.userAgent != "" { - req.Header.Set("User-Agent", t.userAgent) - } - - t.FetchWith(req, h) -} - -func (t *Transport) FetchWith(req *http.Request, cb Handler) { - ctx := req.Context() - log := lg.FromContext(ctx) - log.Info("Fetching download", lga.URL, req.URL.String()) - _ = log - _, fpBody := t.respCache.Paths(req) - - if t.IsFresh(req) { - _ = cb.Cached(fpBody) - return - } - - var err error - cacheable := t.isCacheable(req) - var cachedResp *http.Response - if cacheable { - cachedResp, err = t.respCache.Get(req.Context(), req) - } else { - // Need to invalidate an existing value - if err = t.respCache.Delete(req.Context()); err != nil { - cb.Error(err) - return - } - } - - transport := t.transport - if transport == nil { - transport = http.DefaultTransport - } - var resp *http.Response - if cacheable && cachedResp != nil && err == nil { - if t.markCachedResponses { - cachedResp.Header.Set(XFromCache, "1") - } - 
- if varyMatches(cachedResp, req) { - // Can only use cached value if the new request doesn't Vary significantly - freshness := getFreshness(cachedResp.Header, req.Header) - if freshness == fresh { - _ = cb.Cached(fpBody) - return - } - - if freshness == stale { - var req2 *http.Request - // Add validators if caller hasn't already done so - etag := cachedResp.Header.Get("etag") - if etag != "" && req.Header.Get("etag") == "" { - req2 = cloneRequest(req) - req2.Header.Set("if-none-match", etag) - } - lastModified := cachedResp.Header.Get("last-modified") - if lastModified != "" && req.Header.Get("last-modified") == "" { - if req2 == nil { - req2 = cloneRequest(req) - } - req2.Header.Set("if-modified-since", lastModified) - } - if req2 != nil { - req = req2 - } - } - } - - // FIXME: Use an http client here - resp, err = transport.RoundTrip(req) - if err == nil && req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified { - // Replace the 304 response with the one from cache, but update with some new headers - endToEndHeaders := getEndToEndHeaders(resp.Header) - for _, header := range endToEndHeaders { - cachedResp.Header[header] = resp.Header[header] - } - resp = cachedResp - } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) && - req.Method == http.MethodGet && canStaleOnError(cachedResp.Header, req.Header) { - // In case of transport failure and stale-if-error activated, returns cached content - // when available - log.Warn("Returning cached response due to transport failure", lga.Err, err) - cb.Cached(fpBody) - return - } else { - if err != nil || resp.StatusCode != http.StatusOK { - lg.WarnIfError(log, msgDeleteCache, t.respCache.Delete(req.Context())) - } - if err != nil { - cb.Error(err) - return - } - } - } else { - reqCacheControl := parseCacheControl(req.Header) - if _, ok := reqCacheControl["only-if-cached"]; ok { - resp = newGatewayTimeoutResponse(req) - } else { - resp, err = transport.RoundTrip(req) - if err != nil { 
- cb.Error(err) - return - } - } - } - - if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) { - for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") { - varyKey = http.CanonicalHeaderKey(varyKey) - fakeHeader := "X-Varied-" + varyKey - reqValue := req.Header.Get(varyKey) - if reqValue != "" { - resp.Header.Set(fakeHeader, reqValue) - } - } - - copyWrtr, errFn, err := cb.Uncached() - if err != nil { - cb.Error(err) - return - } - - if err = t.respCache.Write(req.Context(), resp, copyWrtr); err != nil { - log.Error("failed to write download cache", lga.Err, err) - errFn(err) - cb.Error(err) - } - return - } else { - lg.WarnIfError(log, "Delete resp cache", t.respCache.Delete(req.Context())) - } - - // It's not cacheable, so we need to write it to the copyWrtr. - copyWrtr, errFn, err := cb.Uncached() - if err != nil { - cb.Error(err) - return - } - cr := contextio.NewReader(ctx, resp.Body) - _, err = io.Copy(copyWrtr, cr) - if err != nil { - errFn(err) - cb.Error(err) - return - } - if err = copyWrtr.Close(); err != nil { - cb.Error(err) - return - } - - return -} - -func (t *Transport) getClient() *http.Client { - return ioz.NewHTTPClient(t.InsecureSkipVerify) -} - -// Delete deletes the cache. -func (t *Transport) Delete(ctx context.Context) error { - if t.respCache != nil { - return t.respCache.Delete(ctx) - } - return nil -} - -// IsCached returns true if there is a cache entry for req. This does not -// guarantee that the cache entry is fresh. See also: [Transport.IsFresh]. -func (t *Transport) IsCached(req *http.Request) bool { - if t.disableCaching { - return false - } - return t.respCache.Exists(req) -} - -// IsFresh returns true if there is a fresh cache entry for req. 
-func (t *Transport) IsFresh(req *http.Request) bool { - ctx := req.Context() - log := lg.FromContext(ctx) - - if !t.isCacheable(req) { - return false - } - - if !t.respCache.Exists(req) { - return false - } - - fpHeader, _ := t.respCache.Paths(req) - f, err := os.Open(fpHeader) - if err != nil { - log.Error("Failed to open cached response header file", lga.File, fpHeader, lga.Err, err) - return false - } - - defer lg.WarnIfCloseError(log, "Close cached response header", f) - - cachedResp, err := readResponseHeader(bufio.NewReader(f), nil) - if err != nil { - log.Error("Failed to read cached response", lga.Err, err) - return false - } - - freshness := getFreshness(cachedResp.Header, req.Header) - return freshness == fresh -} - -func (t *Transport) isCacheable(req *http.Request) bool { - if t.disableCaching { - return false - } - return (req.Method == http.MethodGet || req.Method == http.MethodHead) && req.Header.Get("range") == "" -} diff --git a/libsq/source/dl.go b/libsq/source/dl.go index 200c4df85..4c815371f 100644 --- a/libsq/source/dl.go +++ b/libsq/source/dl.go @@ -5,7 +5,7 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/contextio" - "github.com/neilotoole/sq/libsq/core/ioz/httpcache" + "github.com/neilotoole/sq/libsq/core/ioz/download" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" @@ -27,10 +27,10 @@ func newDownloader2(cacheDir, userAgent, dlURL string) (*downloader2, error) { } //dc := diskcache.NewWithDiskv(dv) - rc := httpcache.NewRespCache(cacheDir) - tp := httpcache.NewTransport(rc) + rc := download.NewRespCache(cacheDir) + tp := download.New(rc) - //respCache := httpcache.NewRespCache(cacheDir) + //respCache := download.NewRespCache(cacheDir) //tp.RespCache = respCache //tp.BodyFilepath = filepath.Join(cacheDir, "body.data") @@ -53,7 +53,7 @@ type 
downloader2 struct { userAgent string cacheDir string url string - tp *httpcache.Transport + tp *download.Download } func (d *downloader2) log(log *slog.Logger) *slog.Logger { @@ -65,7 +65,7 @@ func (d *downloader2) ClearCache(ctx context.Context) error { d.mu.Lock() defer d.mu.Unlock() - if err := d.tp.Delete(ctx); err != nil { + if err := d.tp.Clear(ctx); err != nil { return errz.Wrapf(err, "failed to clear cache dir: %s", d.cacheDir) } @@ -83,12 +83,6 @@ func (d *downloader2) Download(ctx context.Context, dest io.Writer) (written int req.Header.Set("User-Agent", d.userAgent) } - isCached := d.tp.IsCached(req) - _ = isCached - - isFresh := d.tp.IsFresh(req) - _ = isFresh - resp, err := d.c.Do(req) if err != nil { return written, "", errz.Wrapf(err, "download failed for: %s", d.url) @@ -126,16 +120,10 @@ func (d *downloader2) Download2(ctx context.Context, dest io.Writer) (written in req.Header.Set("User-Agent", d.userAgent) } - isCached := d.tp.IsCached(req) - _ = isCached - - isFresh := d.tp.IsFresh(req) - _ = isFresh - var gotFp string var gotErr error //buf := &bytes.Buffer{} - cb := httpcache.Handler{ + cb := download.Handler{ Cached: func(cachedFilepath string) error { gotFp = cachedFilepath return nil @@ -150,7 +138,7 @@ func (d *downloader2) Download2(ctx context.Context, dest io.Writer) (written in }, } - d.tp.FetchWith(req, cb) + d.tp.fetchWith(req, cb) _ = gotFp _ = gotErr From 11dc12de5cd679f7f91dc1e5222a50af1c3d5721 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 13 Dec 2023 05:48:15 -0700 Subject: [PATCH 101/195] wip: refactoring download --- libsq/core/ioz/download/download.go | 49 +++++++++--------------- libsq/core/ioz/download/download_test.go | 14 +++---- 2 files changed, 25 insertions(+), 38 deletions(-) diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index 23274ed88..238cee723 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -110,8 +110,9 @@ type Download 
struct { } // New returns a new Download that uses cacheDir as the cache directory. -func New(cacheDir string, opts ...Opt) *Download { +func New(url, cacheDir string, opts ...Opt) *Download { t := &Download{ + url: url, markCachedResponses: true, disableCaching: false, InsecureSkipVerify: false, @@ -144,9 +145,9 @@ type Handler struct { Error func(err error) } -// Get gets the download at url, invoking h as appropriate. -func (dl *Download) Get(ctx context.Context, url string, h Handler) { - req, err := dl.newRequest(ctx, url) +// Get gets the download, invoking Handler as appropriate. +func (dl *Download) Get(ctx context.Context, h Handler) { + req, err := dl.newRequest(ctx, dl.url) if err != nil { h.Error(err) return @@ -179,10 +180,6 @@ func (dl *Download) get(req *http.Request, cb Handler) { } } - transport := dl.transport - if transport == nil { - transport = http.DefaultTransport - } var resp *http.Response if cacheable && cachedResp != nil && err == nil { if dl.markCachedResponses { @@ -218,8 +215,7 @@ func (dl *Download) get(req *http.Request, cb Handler) { } } - // FIXME: Use an http client here - resp, err = transport.RoundTrip(req) + resp, err = dl.execRequest(req) if err == nil && req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified { // Replace the 304 response with the one from cache, but update with some new headers endToEndHeaders := getEndToEndHeaders(resp.Header) @@ -248,7 +244,7 @@ func (dl *Download) get(req *http.Request, cb Handler) { if _, ok := reqCacheControl["only-if-cached"]; ok { resp = newGatewayTimeoutResponse(req) } else { - resp, err = transport.RoundTrip(req) + resp, err = dl.execRequest(req) if err != nil { cb.Error(err) return @@ -313,6 +309,14 @@ func (dl *Download) Close() error { return nil } +// execRequest executes the request. 
+func (dl *Download) execRequest(req *http.Request) (*http.Response, error) { + if dl.transport == nil { + return http.DefaultTransport.RoundTrip(req) + } + return dl.transport.RoundTrip(req) +} + func (dl *Download) newRequest(ctx context.Context, url string) (*http.Request, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -337,26 +341,9 @@ func (dl *Download) Clear(ctx context.Context) error { return nil } -// IsCached returns true if there is a cache entry for url. Success does not -// guarantee that the cache entry is fresh. See also: [Download.IsFresh]. -func (dl *Download) IsCached(ctx context.Context, url string) bool { - req, err := dl.newRequest(ctx, url) - if err != nil { - return false - } - return dl.isCached(req) -} - -func (dl *Download) isCached(req *http.Request) bool { - if dl.disableCaching { - return false - } - return dl.respCache.Exists(req) -} - -// State returns the cache state of url. -func (dl *Download) State(ctx context.Context, url string) State { - req, err := dl.newRequest(ctx, url) +// State returns the Download's cache state. 
+func (dl *Download) State(ctx context.Context) State { + req, err := dl.newRequest(ctx, dl.url) if err != nil { return Uncached } diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index f1700df9e..388349b89 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -31,7 +31,7 @@ func TestDownload(t *testing.T) { require.NoError(t, err) t.Logf("cacheDir: %s", cacheDir) - dl := download.New(cacheDir, download.OptUserAgent("sq/dev")) + dl := download.New(dlURL, cacheDir, download.OptUserAgent("sq/dev")) require.NoError(t, dl.Clear(ctx)) var ( @@ -60,16 +60,16 @@ func TestDownload(t *testing.T) { }, } - require.Equal(t, download.Uncached, dl.State(ctx, dlURL)) - dl.Get(ctx, dlURL, h) + require.Equal(t, download.Uncached, dl.State(ctx)) + dl.Get(ctx, h) require.NoError(t, gotErr) require.Empty(t, gotFp) require.Equal(t, sizeActorCSV, int64(destBuf.Len())) - require.Equal(t, download.Fresh, dl.State(ctx, dlURL)) + require.Equal(t, download.Fresh, dl.State(ctx)) reset() - dl.Get(ctx, dlURL, h) + dl.Get(ctx, h) require.NoError(t, gotErr) require.Equal(t, 0, destBuf.Len()) require.NotEmpty(t, gotFp) @@ -77,8 +77,8 @@ func TestDownload(t *testing.T) { require.NoError(t, err) require.Equal(t, sizeActorCSV, int64(len(gotFileBytes))) - require.Equal(t, download.Fresh, dl.State(ctx, dlURL)) + require.Equal(t, download.Fresh, dl.State(ctx)) require.NoError(t, dl.Clear(ctx)) - require.Equal(t, download.Uncached, dl.State(ctx, dlURL)) + require.Equal(t, download.Uncached, dl.State(ctx)) } From b71e332354a35ef15a73b4457d72a6c04301583f Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 13 Dec 2023 08:00:21 -0700 Subject: [PATCH 102/195] refactor ioz.NewHTTPClient --- libsq/core/ioz/download/download.go | 9 +- libsq/core/ioz/download/download_test.go | 9 +- libsq/core/ioz/download/respcache.go | 17 +-- libsq/core/ioz/httpz.go | 124 +++++++++++++++++++ libsq/core/ioz/httpz_test.go | 150 
+++++++++++++++++++++++ libsq/core/ioz/ioz.go | 46 +++---- libsq/source/dl.go | 23 ++-- libsq/source/dl_test.go | 19 +-- libsq/source/download.go | 4 +- libsq/source/download_test.go | 3 +- 10 files changed, 337 insertions(+), 67 deletions(-) create mode 100644 libsq/core/ioz/httpz.go create mode 100644 libsq/core/ioz/httpz_test.go diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index 238cee723..a5b2361ae 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -11,13 +11,14 @@ package download import ( "bufio" "context" + "io" + "net/http" + "os" + "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" - "io" - "net/http" - "os" ) // State is an enumeration of caching states based on the cache-control @@ -330,7 +331,7 @@ func (dl *Download) newRequest(ctx context.Context, url string) (*http.Request, } func (dl *Download) getClient() *http.Client { - return ioz.NewHTTPClient(dl.InsecureSkipVerify) + return ioz.NewHTTPClient("", dl.InsecureSkipVerify, 0, 0) } // Clear deletes the cache. 
diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index 388349b89..117eca01a 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -3,15 +3,16 @@ package download_test import ( "bytes" "context" + "io" + "os" + "path/filepath" + "testing" + "github.com/neilotoole/slogt" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/download" "github.com/neilotoole/sq/libsq/core/lg" "github.com/stretchr/testify/require" - "io" - "os" - "path/filepath" - "testing" ) const ( diff --git a/libsq/core/ioz/download/respcache.go b/libsq/core/ioz/download/respcache.go index 2aa88eab1..fd79b2be7 100644 --- a/libsq/core/ioz/download/respcache.go +++ b/libsq/core/ioz/download/respcache.go @@ -4,18 +4,19 @@ import ( "bufio" "bytes" "context" - "github.com/neilotoole/sq/libsq/core/cleanup" - "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/neilotoole/sq/libsq/core/ioz/contextio" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lga" "io" "net/http" "net/http/httputil" "os" "path/filepath" "sync" + + "github.com/neilotoole/sq/libsq/core/cleanup" + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" ) // NewRespCache returns a new instance that stores responses in cacheDir. 
@@ -23,8 +24,8 @@ import ( func NewRespCache(cacheDir string) *RespCache { c := &RespCache{ Dir: cacheDir, - //Header: filepath.Join(cacheDir, "header"), - //Body: filepath.Join(cacheDir, "body"), + // Header: filepath.Join(cacheDir, "header"), + // Body: filepath.Join(cacheDir, "body"), clnup: cleanup.New(), } return c diff --git a/libsq/core/ioz/httpz.go b/libsq/core/ioz/httpz.go new file mode 100644 index 000000000..f66736f83 --- /dev/null +++ b/libsq/core/ioz/httpz.go @@ -0,0 +1,124 @@ +package ioz + +import ( + "context" + "crypto/tls" + "net/http" + "time" + + "github.com/neilotoole/sq/libsq/core/errz" +) + +// NewHTTPClient returns a new HTTP client. If userAgent is non-empty, the +// "User-Agent" header is applied to each request. If insecureSkipVerify is +// true, the client will skip TLS verification. If headerTimeout > 0, a +// timeout is applied to receiving the HTTP response, but that timeout is +// not applied to reading the response body. This is useful if you expect +// a response within, say, 5 seconds, but you expect the body to take longer +// to read. If bodyTimeout > 0, it is applied to the total lifecycle of +// the request and response, including reading the response body. +func NewHTTPClient(userAgent string, insecureSkipVerify bool, + headerTimeout, bodyTimeout time.Duration, +) *http.Client { + c := *http.DefaultClient + var tr *http.Transport + if c.Transport == nil { + tr = (http.DefaultTransport.(*http.Transport)).Clone() + } else { + tr = (c.Transport.(*http.Transport)).Clone() + } + + if tr.TLSClientConfig == nil { + // We allow tls.VersionTLS10, even though it's not considered + // secure these days. Ultimately this could become a config + // option. 
+ tr.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS10} //nolint:gosec + } else { + tr.TLSClientConfig = tr.TLSClientConfig.Clone() + tr.TLSClientConfig.MinVersion = tls.VersionTLS10 //nolint:gosec + } + + tr.TLSClientConfig.InsecureSkipVerify = insecureSkipVerify + c.Transport = tr + if userAgent != "" { + c.Transport = &userAgentRoundTripper{ + userAgent: userAgent, + rt: c.Transport, + } + } + + c.Timeout = bodyTimeout + if headerTimeout > 0 { + c.Transport = &headerTimeoutRoundTripper{ + headerTimeout: headerTimeout, + rt: c.Transport, + } + } + + return &c +} + +// userAgentRoundTripper applies a User-Agent header to each request. +type userAgentRoundTripper struct { + userAgent string + rt http.RoundTripper +} + +// RoundTrip implements http.RoundTripper. +func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("User-Agent", rt.userAgent) + return rt.rt.RoundTrip(req) +} + +// headerTimeoutRoundTripper applies headerTimeout to the return of the http +// response, but headerTimeout is not applied to reading the body of the +// response. This is useful if you expect a response within, say, 5 seconds, +// but you expect the body to take longer to read. +type headerTimeoutRoundTripper struct { + headerTimeout time.Duration + rt http.RoundTripper +} + +// RoundTrip implements http.RoundTripper. +func (rt *headerTimeoutRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if rt.headerTimeout <= 0 { + return rt.rt.RoundTrip(req) + } + + timerCancelCh := make(chan struct{}) + ctx, cancelFn := context.WithCancelCause(req.Context()) + go func() { + t := time.NewTimer(rt.headerTimeout) + defer t.Stop() + select { + case <-ctx.Done(): + case <-t.C: + cancelFn(errz.Errorf("http response not received by %d timeout", + rt.headerTimeout)) + case <-timerCancelCh: + // Stop the timer goroutine. 
+ } + }() + + resp, err := rt.rt.RoundTrip(req.WithContext(ctx)) + close(timerCancelCh) + + // Don't leak resources; ensure that cancelFn is eventually called. + switch { + case err != nil: + + // It's possible that cancelFn has already been called by the + // timer goroutine, but we call it again just in case. + cancelFn(err) + case resp != nil && resp.Body != nil: + + // Wrap resp.Body with a ReadCloserNotifier, so that cancelFn + // is called when the body is closed. + resp.Body = ReadCloserNotifier(resp.Body, cancelFn) + default: + // Not sure if this can actually happen, but just in case. + cancelFn(context.Canceled) + } + + return resp, err +} diff --git a/libsq/core/ioz/httpz_test.go b/libsq/core/ioz/httpz_test.go new file mode 100644 index 000000000..cf4c6d24a --- /dev/null +++ b/libsq/core/ioz/httpz_test.go @@ -0,0 +1,150 @@ +package ioz_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/stretchr/testify/assert" + + "github.com/stretchr/testify/require" +) + +func TestNewHTTPClient_headerTimeout(t *testing.T) { + t.Parallel() + const ( + headerTimeout = time.Second * 2 + numLines = 7 + ) + + testCases := []struct { + name string + ctxFn func(t *testing.T) context.Context + c *http.Client + wantErr bool + }{ + { + name: "http.DefaultClient", + ctxFn: func(t *testing.T) context.Context { + ctx, cancelFn := context.WithTimeout(context.Background(), headerTimeout) + t.Cleanup(cancelFn) + return ctx + }, + c: http.DefaultClient, + wantErr: true, + }, + { + name: "headerTimeout", + ctxFn: func(t *testing.T) context.Context { + return context.Background() + }, + c: ioz.NewHTTPClient("", false, headerTimeout, 0), + wantErr: false, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + slowServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for i := 0; i < 
numLines; i++ { + select { + case <-r.Context().Done(): + t.Logf("Server exiting due to: %v", r.Context().Err()) + return + default: + } + if _, err := io.WriteString(w, string(rune('A'+i))+"\n"); err != nil { + t.Logf("Server write err: %v", err) + return + } + w.(http.Flusher).Flush() + time.Sleep(time.Second) + } + })) + t.Cleanup(slowServer.Close) + + ctx := tc.ctxFn(t) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, slowServer.URL, nil) + require.NoError(t, err) + + resp, err := tc.c.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Sleep long enough to trigger the header timeout. + time.Sleep(headerTimeout + time.Second) + b, err := io.ReadAll(resp.Body) + if tc.wantErr { + require.Error(t, err) + t.Logf("err: %T: %v", err, err) + return + } + + require.NoError(t, err) + require.Len(t, b, numLines*2) // *2 because of the newlines. + }) + } +} + +func TestTimeout1(t *testing.T) { + const urlPaymentLargeCSV = "https://sqio-public.s3.amazonaws.com/testdata/payment-large.gen.csv" + + const urlActorCSV = "https://sq.io/testdata/actor.csv" + const respTimeout = time.Second * 2 + const lines = 10 + const wantLen = lines * 2 + slowServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for i := 0; i < lines; i++ { + select { + case <-r.Context().Done(): + t.Logf("Server exiting due to: %v", r.Context().Err()) + return + default: + } + _, _ = io.WriteString(w, string(rune('A'+i))+"\n") + w.(http.Flusher).Flush() + time.Sleep(time.Second) + } + })) + t.Cleanup(slowServer.Close) + + ctx, cancelFn := context.WithTimeout(context.Background(), respTimeout) + defer cancelFn() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, slowServer.URL, nil) + require.NoError(t, err) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // cancelFn() + time.Sleep(time.Second * 3) + + select { + case 
<-ctx.Done(): + t.Logf("ctx is done: %v", ctx.Err()) + default: + t.Logf("ctx is not done") + cancelFn() + } + + // cancelFn() + b, err := io.ReadAll(resp.Body) + t.Logf("err: %T: %v", err, err) + t.Logf("len(b): %d", len(b)) + t.Logf("b:\n\n%s\n\n", b) + assert.Error(t, err) + // require.Nil(t, b) + _ = b + // require.Len(t, b, 0) + // require.Len(t, b, 7641) +} diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 118aa399a..02998342c 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -5,17 +5,16 @@ import ( "bytes" "context" crand "crypto/rand" - "crypto/tls" - "github.com/neilotoole/sq/libsq/core/ioz/contextio" "io" mrand "math/rand" - "net/http" "os" "path/filepath" "strings" "sync" "time" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/a8m/tree" "github.com/a8m/tree/ostree" yaml "github.com/goccy/go-yaml" @@ -390,34 +389,25 @@ func PrintTree(w io.Writer, loc string, showSize, colorize bool) error { return nil } -// NewHTTPClient returns a new HTTP client with no client-wide timeout. -// If a timeout is needed, use a [context.WithTimeout] for each request. -// If insecureSkipVerify is true, the client will skip TLS verification. -func NewHTTPClient(insecureSkipVerify bool) *http.Client { - client := *http.DefaultClient - - var tr *http.Transport - if client.Transport == nil { - tr = (http.DefaultTransport.(*http.Transport)).Clone() - } else { - tr = (client.Transport.(*http.Transport)).Clone() - } - - if tr.TLSClientConfig == nil { - // We allow tls.VersionTLS10, even though it's not considered - // secure these days. Ultimately this could become a config - // option. - tr.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS10} //nolint:gosec - } else { - tr.TLSClientConfig = tr.TLSClientConfig.Clone() +// ReadCloserNotifier returns a new io.ReadCloser that invokes fn +// after Close is called, passing along any error from Close. +// If rc or fn is nil, rc is returned. 
+func ReadCloserNotifier(rc io.ReadCloser, fn func(closeErr error)) io.ReadCloser { + if rc == nil || fn == nil { + return rc } + return &readCloseNotifier{ReadCloser: rc, fn: fn} +} - tr.TLSClientConfig.InsecureSkipVerify = insecureSkipVerify - - client.Timeout = 0 - client.Transport = tr +type readCloseNotifier struct { + fn func(error) + io.ReadCloser +} - return &client +func (c *readCloseNotifier) Close() error { + err := c.Close() + c.fn(err) + return err } // WriteToFile writes the contents of r to fp. If fp doesn't exist, diff --git a/libsq/source/dl.go b/libsq/source/dl.go index 4c815371f..067e622e1 100644 --- a/libsq/source/dl.go +++ b/libsq/source/dl.go @@ -2,6 +2,11 @@ package source import ( "context" + "io" + "log/slog" + "net/http" + "sync" + "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/contextio" @@ -9,10 +14,6 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" - "io" - "log/slog" - "net/http" - "sync" ) // newDownloader creates a new downloader using cacheDir for the given url. 
@@ -26,20 +27,20 @@ func newDownloader2(cacheDir, userAgent, dlURL string) (*downloader2, error) { return nil, err } - //dc := diskcache.NewWithDiskv(dv) + // dc := diskcache.NewWithDiskv(dv) rc := download.NewRespCache(cacheDir) tp := download.New(rc) - //respCache := download.NewRespCache(cacheDir) - //tp.RespCache = respCache - //tp.BodyFilepath = filepath.Join(cacheDir, "body.data") + // respCache := download.NewRespCache(cacheDir) + // tp.RespCache = respCache + // tp.BodyFilepath = filepath.Join(cacheDir, "body.data") c := &http.Client{Transport: tp} return &downloader2{ c: c, - //dc: dc, - //dv: dv, + // dc: dc, + // dv: dv, cacheDir: cacheDir, url: dlURL, userAgent: userAgent, @@ -122,7 +123,7 @@ func (d *downloader2) Download2(ctx context.Context, dest io.Writer) (written in var gotFp string var gotErr error - //buf := &bytes.Buffer{} + // buf := &bytes.Buffer{} cb := download.Handler{ Cached: func(cachedFilepath string) error { gotFp = cachedFilepath diff --git a/libsq/source/dl_test.go b/libsq/source/dl_test.go index e2c8adb61..7326c0780 100644 --- a/libsq/source/dl_test.go +++ b/libsq/source/dl_test.go @@ -3,13 +3,14 @@ package source import ( "bytes" "context" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/stretchr/testify/require" "net/url" "path" "path/filepath" "testing" + + "github.com/neilotoole/slogt" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/stretchr/testify/require" ) func TestDownloader2_Download(t *testing.T) { @@ -29,21 +30,21 @@ func TestDownloader2_Download(t *testing.T) { log.Debug("huzzah") dl, err := newDownloader2(cacheDir, "sq/dev", dlURL) require.NoError(t, err) - //require.NoError(t, dl.ClearCache(ctx)) + // require.NoError(t, dl.ClearCache(ctx)) buf := &bytes.Buffer{} written, cachedFp, err := dl.Download2(ctx, buf) _ = written _ = cachedFp require.NoError(t, err) - //require.Equal(t, wantContentLength, written) - //require.Equal(t, wantContentLength, 
int64(buf.Len())) + // require.Equal(t, wantContentLength, written) + // require.Equal(t, wantContentLength, int64(buf.Len())) buf.Reset() written, cachedFp, err = dl.Download2(ctx, buf) require.NoError(t, err) - //require.Equal(t, wantContentLength, written) - //require.Equal(t, wantContentLength, int64(buf.Len())) + // require.Equal(t, wantContentLength, written) + // require.Equal(t, wantContentLength, int64(buf.Len())) } func TestDownloader2_Download_Legacy(t *testing.T) { @@ -61,7 +62,7 @@ func TestDownloader2_Download_Legacy(t *testing.T) { dl, err := newDownloader2(cacheDir, "sq/dev", dlURL) require.NoError(t, err) - //require.NoError(t, dl.ClearCache(ctx)) + // require.NoError(t, dl.ClearCache(ctx)) buf := &bytes.Buffer{} written, cachedFp, err := dl.Download2(ctx, buf) diff --git a/libsq/source/download.go b/libsq/source/download.go index 64d377f84..27ea099f7 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -156,7 +156,7 @@ func (d *downloader) Download(ctx context.Context, dest io.Writer) (written int6 if err != nil { return written, "", errz.Wrapf(err, "download new request failed for: %s", d.url) } - //setDefaultHTTPRequestHeaders(req) + // setDefaultHTTPRequestHeaders(req) resp, err := d.c.Do(req) if err != nil { @@ -364,7 +364,7 @@ func fetchHTTPResponse(ctx context.Context, c *http.Client, u string) (resp *htt if err != nil { return nil, errz.Err(err) } - //setDefaultHTTPRequestHeaders(req) + // setDefaultHTTPRequestHeaders(req) resp, err = http.DefaultClient.Do(req) if err != nil { diff --git a/libsq/source/download_test.go b/libsq/source/download_test.go index a9093478e..bbba5da49 100644 --- a/libsq/source/download_test.go +++ b/libsq/source/download_test.go @@ -3,7 +3,6 @@ package source import ( "bytes" "context" - "github.com/neilotoole/sq/libsq/core/ioz/checksum" "net/http" "net/http/httptest" "net/url" @@ -12,6 +11,8 @@ import ( "strconv" "testing" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" + 
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" From b61560b7f9b09a43dc61cc15c1c3787ac3b62ff7 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 13 Dec 2023 08:12:46 -0700 Subject: [PATCH 103/195] more refactoring --- libsq/core/ioz/download/download.go | 27 +++++++++++++----------- libsq/core/ioz/download/download_test.go | 6 +++++- libsq/source/dl.go | 2 +- 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index a5b2361ae..fc12c3d9a 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -15,7 +15,6 @@ import ( "net/http" "os" - "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" @@ -88,10 +87,11 @@ func OptUserAgent(userAgent string) Opt { type Download struct { // FIXME: Does Download need a sync.Mutex? - // FIXME: implement url mechanism // url is the URL of the download. url string + c *http.Client + // The RoundTripper interface actually used to make requests // If nil, http.DefaultTransport is used. transport http.RoundTripper @@ -110,9 +110,15 @@ type Download struct { disableCaching bool } -// New returns a new Download that uses cacheDir as the cache directory. -func New(url, cacheDir string, opts ...Opt) *Download { +// New returns a new Download for url that writes to cacheDir. +// If c is nil, http.DefaultClient is used. +func New(c *http.Client, url, cacheDir string, opts ...Opt) *Download { + if c == nil { + c = http.DefaultClient + } + t := &Download{ + c: c, url: url, markCachedResponses: true, disableCaching: false, @@ -312,10 +318,11 @@ func (dl *Download) Close() error { // execRequest executes the request. 
func (dl *Download) execRequest(req *http.Request) (*http.Response, error) { - if dl.transport == nil { - return http.DefaultTransport.RoundTrip(req) - } - return dl.transport.RoundTrip(req) + //if dl.transport == nil { + // return http.DefaultTransport.RoundTrip(req) + //} + //return dl.transport.RoundTrip(req) + return dl.c.Do(req) } func (dl *Download) newRequest(ctx context.Context, url string) (*http.Request, error) { @@ -330,10 +337,6 @@ func (dl *Download) newRequest(ctx context.Context, url string) (*http.Request, return req, nil } -func (dl *Download) getClient() *http.Client { - return ioz.NewHTTPClient("", dl.InsecureSkipVerify, 0, 0) -} - // Clear deletes the cache. func (dl *Download) Clear(ctx context.Context) error { if dl.respCache != nil { diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index 117eca01a..b80cdfe92 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -22,6 +22,10 @@ const ( sizeGzipActorCSV = int64(1968) ) +func TestDownload_redirect(t *testing.T) { + +} + func TestDownload(t *testing.T) { log := slogt.New(t) ctx := lg.NewContext(context.Background(), log) @@ -32,7 +36,7 @@ func TestDownload(t *testing.T) { require.NoError(t, err) t.Logf("cacheDir: %s", cacheDir) - dl := download.New(dlURL, cacheDir, download.OptUserAgent("sq/dev")) + dl := download.New(nil, dlURL, cacheDir, download.OptUserAgent("sq/dev")) require.NoError(t, dl.Clear(ctx)) var ( diff --git a/libsq/source/dl.go b/libsq/source/dl.go index 067e622e1..a8b9ab18a 100644 --- a/libsq/source/dl.go +++ b/libsq/source/dl.go @@ -29,7 +29,7 @@ func newDownloader2(cacheDir, userAgent, dlURL string) (*downloader2, error) { // dc := diskcache.NewWithDiskv(dv) rc := download.NewRespCache(cacheDir) - tp := download.New(rc) + tp := download.New(nil, rc) // respCache := download.NewRespCache(cacheDir) // tp.RespCache = respCache From ee0281124aacc77ee94ef19609a1b813cde421d5 Mon Sep 17 
00:00:00 2001 From: neilotoole Date: Wed, 13 Dec 2023 11:17:26 -0700 Subject: [PATCH 104/195] download seems closer --- go.mod | 2 - go.sum | 4 - libsq/core/ioz/download/download.go | 64 +++++--- libsq/core/ioz/download/download_test.go | 167 +++++++++++++++++++- libsq/core/ioz/download/respcache.go | 28 +++- libsq/core/ioz/download/testdata/.gitignore | 1 + libsq/core/ioz/ioz.go | 6 +- 7 files changed, 235 insertions(+), 37 deletions(-) create mode 100644 libsq/core/ioz/download/testdata/.gitignore diff --git a/go.mod b/go.mod index 0f011723e..d69ab6dc7 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,6 @@ require ( github.com/neilotoole/slogt v1.1.0 github.com/nightlyone/lockfile v1.0.0 github.com/otiai10/copy v1.14.0 - github.com/peterbourgon/diskv v2.0.1+incompatible github.com/ryboe/q v1.0.20 github.com/samber/lo v1.39.0 github.com/segmentio/encoding v0.3.7 @@ -71,7 +70,6 @@ require ( github.com/djherbis/atime v1.1.0 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect - github.com/google/btree v1.0.1 // indirect github.com/huandu/xstrings v1.3.3 // indirect github.com/imdario/mergo v0.3.11 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/go.sum b/go.sum index 11ea035fe..6866c715d 100644 --- a/go.sum +++ b/go.sum @@ -63,8 +63,6 @@ github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0kt github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.1.0 h1:ZCD6MBpcuOVfGVqsEmY5/4FtYiKz6tSyUv9LPEDei6A= github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EOqtpKwwwHI= -github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= -github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/google/go-cmp 
v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -139,8 +137,6 @@ github.com/otiai10/copy v1.14.0 h1:dCI/t1iTdYGtkvCuBG2BgR6KZa83PTclw4U5n2wAllU= github.com/otiai10/copy v1.14.0/go.mod h1:ECfuL02W+/FkTWZWgQqXPWZgW9oeKCSQ5qVfSc4qc4w= github.com/otiai10/mint v1.5.1 h1:XaPLeE+9vGbuyEHem1JNk3bYc7KKqyI/na0/mLd/Kks= github.com/otiai10/mint v1.5.1/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM= -github.com/peterbourgon/diskv v2.0.1+incompatible h1:UBdAOUP5p4RWqPBg048CAvpKN+vxiaj6gdUUzhl4XmI= -github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index fc12c3d9a..74f2f95f9 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -11,9 +11,12 @@ package download import ( "bufio" "context" + "github.com/neilotoole/sq/libsq/core/errz" "io" "net/http" + "net/url" "os" + "path/filepath" "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/lg" @@ -112,14 +115,22 @@ type Download struct { // New returns a new Download for url that writes to cacheDir. // If c is nil, http.DefaultClient is used. 
-func New(c *http.Client, url, cacheDir string, opts ...Opt) *Download { +func New(c *http.Client, dlURL, cacheDir string, opts ...Opt) (*Download, error) { + _, err := url.ParseRequestURI(dlURL) + if err != nil { + return nil, err + } if c == nil { c = http.DefaultClient } + if cacheDir, err = filepath.Abs(cacheDir); err != nil { + return nil, errz.Err(err) + } + t := &Download{ c: c, - url: url, + url: dlURL, markCachedResponses: true, disableCaching: false, InsecureSkipVerify: false, @@ -131,7 +142,8 @@ func New(c *http.Client, url, cacheDir string, opts ...Opt) *Download { if !t.disableCaching { t.respCache = NewRespCache(cacheDir) } - return t + + return t, nil } // Handler is a callback invoked by Download.Get. Exactly one of the @@ -163,14 +175,15 @@ func (dl *Download) Get(ctx context.Context, h Handler) { dl.get(req, h) } -func (dl *Download) get(req *http.Request, cb Handler) { +func (dl *Download) get(req *http.Request, h Handler) { ctx := req.Context() log := lg.FromContext(ctx) log.Info("Fetching download", lga.URL, req.URL.String()) _, fpBody := dl.respCache.Paths(req) - if dl.state(req) == Fresh { - cb.Cached(fpBody) + state := dl.state(req) + if state == Fresh { + h.Cached(fpBody) return } @@ -182,7 +195,7 @@ func (dl *Download) get(req *http.Request, cb Handler) { } else { // Need to invalidate an existing value if err = dl.respCache.Clear(req.Context()); err != nil { - cb.Error(err) + h.Error(err) return } } @@ -197,7 +210,7 @@ func (dl *Download) get(req *http.Request, cb Handler) { // Can only use cached value if the new request doesn't Vary significantly freshness := getFreshness(cachedResp.Header, req.Header) if freshness == Fresh { - cb.Cached(fpBody) + h.Cached(fpBody) return } @@ -235,14 +248,14 @@ func (dl *Download) get(req *http.Request, cb Handler) { // In case of transport failure and stale-if-error activated, returns cached content // when available log.Warn("Returning cached response due to transport failure", lga.Err, err) - 
cb.Cached(fpBody) + h.Cached(fpBody) return } else { if err != nil || resp.StatusCode != http.StatusOK { lg.WarnIfError(log, msgDeleteCache, dl.respCache.Clear(req.Context())) } if err != nil { - cb.Error(err) + h.Error(err) return } } @@ -253,7 +266,7 @@ func (dl *Download) get(req *http.Request, cb Handler) { } else { resp, err = dl.execRequest(req) if err != nil { - cb.Error(err) + h.Error(err) return } } @@ -269,16 +282,30 @@ func (dl *Download) get(req *http.Request, cb Handler) { } } - copyWrtr, errFn := cb.Uncached() + if resp == cachedResp { + lg.WarnIfCloseError(log, "Close response body", resp.Body) + if err = dl.respCache.Write(ctx, resp, true, nil); err != nil { + log.Error("Failed to update cache header", lga.Dir, dl.respCache.Dir, lga.Err, err) + // FIXME: Should we error here, or just return the cached file? + h.Error(err) + return + } + h.Cached(fpBody) + return + } + + // I'm not sure if this logic is even reachable? + copyWrtr, errFn := h.Uncached() if copyWrtr == nil { log.Warn("nil copy writer from download handler; returning") return } - if err = dl.respCache.Write(req.Context(), resp, copyWrtr); err != nil { - log.Error("failed to write download cache", lga.Err, err) + defer lg.WarnIfCloseError(log, "Close response body", resp.Body) + if err = dl.respCache.Write(req.Context(), resp, false, copyWrtr); err != nil { + log.Error("failed to write download cache", lga.Dir, dl.respCache.Dir, lga.Err, err) errFn(err) - cb.Error(err) + h.Error(err) } return } else { @@ -286,21 +313,22 @@ func (dl *Download) get(req *http.Request, cb Handler) { } // It's not cacheable, so we need to write it to the copyWrtr. 
- copyWrtr, errFn := cb.Uncached() + copyWrtr, errFn := h.Uncached() if copyWrtr == nil { log.Warn("nil copy writer from download handler; returning") return } cr := contextio.NewReader(ctx, resp.Body) + defer lg.WarnIfCloseError(log, "Close response body", resp.Body) _, err = io.Copy(copyWrtr, cr) if err != nil { errFn(err) - cb.Error(err) + h.Error(err) return } if err = copyWrtr.Close(); err != nil { - cb.Error(err) + h.Error(err) return } diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index b80cdfe92..d1d2946f7 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -3,10 +3,18 @@ package download_test import ( "bytes" "context" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/testh/tu" + "github.com/stretchr/testify/assert" "io" + "log/slog" + "net/http" + "net/http/httptest" "os" "path/filepath" + "sync" "testing" + "time" "github.com/neilotoole/slogt" "github.com/neilotoole/sq/libsq/core/ioz" @@ -23,20 +31,125 @@ const ( ) func TestDownload_redirect(t *testing.T) { + const hello = `Hello World!` + var serveBody = hello + lastModified := time.Now().UTC() + //cacheDir := t.TempDir() + cacheDir := filepath.Join("testdata", "download", tu.Name(t.Name())) + log := slogt.New(t) + var srvr *httptest.Server + srvr = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log := log.With("origin", "server") + log.Info("Serving actual") + switch r.URL.Path { + case "/redirect": + loc := srvr.URL + "/actual" + log.Info("Redirecting to", lga.Loc, loc) + http.Redirect(w, r, loc, http.StatusFound) + case "/actual": + if ifm := r.Header.Get("If-Modified-Since"); ifm != "" { + tm, err := time.Parse(http.TimeFormat, ifm) + if err != nil { + log.Error("Failed to parse If-Modified-Since", lga.Err, err) + w.WriteHeader(http.StatusBadRequest) + return + } + + ifModifiedSinceUnix := tm.Unix() + lastModifiedUnix := 
lastModified.Unix() + + if lastModifiedUnix <= ifModifiedSinceUnix { + log.Info("Serving http.StatusNotModified") + w.WriteHeader(http.StatusNotModified) + return + } + } + + log.Info("Serving actual: writing bytes") + b := []byte(serveBody) + w.Header().Set("Last-Modified", lastModified.Format(http.TimeFormat)) + _, err := w.Write(b) + assert.NoError(t, err) + default: + log.Info("Serving http.StatusNotFound") + w.WriteHeader(http.StatusNotFound) + } + })) + t.Cleanup(srvr.Close) + + ctx := lg.NewContext(context.Background(), log.With("origin", "downloader")) + loc := srvr.URL + "/redirect" + + dl, err := download.New(nil, loc, cacheDir) + require.NoError(t, err) + require.NoError(t, dl.Clear(ctx)) + h := newTestHandler(log.With("origin", "handler")) + + dl.Get(ctx, h.Handler) + require.Empty(t, h.errors) + gotBody := h.bufs[0].String() + require.Equal(t, hello, gotBody) + + h.reset() + dl.Get(ctx, h.Handler) + require.Empty(t, h.errors) + require.Empty(t, h.bufs) + gotFile := h.cacheFiles[0] + t.Logf("got fp: %s", gotFile) + gotBody = tu.ReadFileToString(t, gotFile) + t.Logf("got body: \n\n%s\n\n", gotBody) + require.Equal(t, serveBody, gotBody) + + h.reset() + dl.Get(ctx, h.Handler) + require.Empty(t, h.errors) + require.Empty(t, h.bufs) + gotFile = h.cacheFiles[0] + t.Logf("got fp: %s", gotFile) + gotBody = tu.ReadFileToString(t, gotFile) + t.Logf("got body: \n\n%s\n\n", gotBody) + require.Equal(t, serveBody, gotBody) } +//tr := httpcache.NewTransport(diskcache.New(cacheDir)) +//req, err := http.NewRequestWithContext(ctx, http.MethodGet, loc, nil) +//require.NoError(t, err) +// +//resp, err := tr.RoundTrip(req) +//require.NoError(t, err) +//require.Equal(t, http.StatusOK, resp.StatusCode) +//b, err := io.ReadAll(resp.Body) +//require.NoError(t, err) +//require.Equal(t, serveBody, string(b)) +//t.Logf("b: \n\n%s\n\n", b) +// +//resp2, err := tr.RoundTrip(req) +//require.NoError(t, err) +//require.Equal(t, http.StatusOK, resp2.StatusCode) +// +//b, err = 
io.ReadAll(resp.Body) +//require.NoError(t, err) +//require.Equal(t, serveBody, string(b)) +//t.Logf("b: \n\n%s\n\n", b) + +// +//ctx := lg.NewContext(context.Background(), log.With("origin", "downloader")) +//loc := srvr.URL + "/redirect" +//loc := srvr.URL + "/actual" + func TestDownload(t *testing.T) { log := slogt.New(t) ctx := lg.NewContext(context.Background(), log) const dlURL = urlActorCSV // FIXME: switch to temp dir - cacheDir, err := filepath.Abs(filepath.Join("testdata", "downloader", "cache-dir-2")) + cacheDir, err := filepath.Abs(filepath.Join("testdata", "download", tu.Name(t.Name()))) require.NoError(t, err) t.Logf("cacheDir: %s", cacheDir) - dl := download.New(nil, dlURL, cacheDir, download.OptUserAgent("sq/dev")) + dl, err := download.New(nil, dlURL, cacheDir, download.OptUserAgent("sq/dev")) + require.NoError(t, err) require.NoError(t, dl.Clear(ctx)) var ( @@ -87,3 +200,53 @@ func TestDownload(t *testing.T) { require.NoError(t, dl.Clear(ctx)) require.Equal(t, download.Uncached, dl.State(ctx)) } + +type testHandler struct { + download.Handler + mu sync.Mutex + log *slog.Logger + errors []error + cacheFiles []string + bufs []*bytes.Buffer + writeErrs []error +} + +func (th *testHandler) reset() { + th.mu.Lock() + defer th.mu.Unlock() + th.errors = nil + th.cacheFiles = nil + th.bufs = nil + th.writeErrs = nil +} + +func newTestHandler(log *slog.Logger) *testHandler { + th := &testHandler{log: log} + th.Cached = func(fp string) { + log.Info("Cached", lga.File, fp) + th.mu.Lock() + defer th.mu.Unlock() + th.cacheFiles = append(th.cacheFiles, fp) + } + + th.Uncached = func() (io.WriteCloser, func(error)) { + log.Info("Uncached") + th.mu.Lock() + defer th.mu.Unlock() + buf := &bytes.Buffer{} + th.bufs = append(th.bufs, buf) + return ioz.WriteCloser(buf), func(err error) { + th.mu.Lock() + defer th.mu.Unlock() + th.writeErrs = append(th.writeErrs, err) + } + } + + th.Error = func(err error) { + log.Info("Error", lga.Err, err) + th.mu.Lock() + defer 
th.mu.Unlock() + th.errors = append(th.errors, err) + } + return th +} diff --git a/libsq/core/ioz/download/respcache.go b/libsq/core/ioz/download/respcache.go index fd79b2be7..c0c8f39d8 100644 --- a/libsq/core/ioz/download/respcache.go +++ b/libsq/core/ioz/download/respcache.go @@ -4,6 +4,8 @@ import ( "bufio" "bytes" "context" + "fmt" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" "io" "net/http" "net/http/httputil" @@ -14,7 +16,6 @@ import ( "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" ) @@ -123,7 +124,8 @@ func (rc *RespCache) doClear(ctx context.Context) error { cleanErr := rc.clnup.Run() rc.clnup = cleanup.New() deleteErr := errz.Wrap(os.RemoveAll(rc.Dir), "delete cache dir") - err := errz.Combine(cleanErr, deleteErr) + recreateErr := ioz.RequireDir(rc.Dir) + err := errz.Combine(cleanErr, deleteErr, recreateErr) if err != nil { lg.FromContext(ctx).Error(msgDeleteCache, lga.Dir, rc.Dir, lga.Err, err) @@ -136,26 +138,31 @@ func (rc *RespCache) doClear(ctx context.Context) error { const msgDeleteCache = "Delete HTTP response cache" -// Write writes resp to the cache. If copyWrtr is non-nil, the response -// bytes are copied to that destination also. -func (rc *RespCache) Write(ctx context.Context, resp *http.Response, copyWrtr io.WriteCloser) error { +// Write writes resp to the cache. If headerOnly is true, only the header +// cache file is updated. If headerOnly is false and copyWrtr is non-nil, the +// response body bytes are copied to that destination, as well as being +// written to the cache. 
+func (rc *RespCache) Write(ctx context.Context, resp *http.Response, headerOnly bool, copyWrtr io.WriteCloser) error { rc.mu.Lock() defer rc.mu.Unlock() - err := rc.doWrite(ctx, resp, copyWrtr) + err := rc.doWrite(ctx, resp, headerOnly, copyWrtr) if err != nil { - lg.WarnIfError(lg.FromContext(ctx), msgDeleteCache, rc.doClear(ctx)) + lg.FromContext(ctx).Error("Failed to write HTTP response to cache", lga.Dir, rc.Dir, lga.Err, err) + + //lg.WarnIfError(lg.FromContext(ctx), msgDeleteCache, rc.doClear(ctx)) } return err } -func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response, copyWrtr io.WriteCloser) error { +func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response, headerOnly bool, copyWrtr io.WriteCloser) error { log := lg.FromContext(ctx) if err := ioz.RequireDir(rc.Dir); err != nil { return err } + log.Info("Writing HTTP response to cache", lga.Dir, rc.Dir, "resp", fmt.Sprintf("%v", *resp)) fpHeader, fpBody := rc.Paths(resp.Request) headerBytes, err := httputil.DumpResponse(resp, false) @@ -167,6 +174,10 @@ func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response, copyWrtr return err } + if headerOnly { + return nil + } + cacheFile, err := os.OpenFile(fpBody, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) if err != nil { return err @@ -186,6 +197,7 @@ func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response, copyWrtr var written int64 written, err = io.Copy(cacheFile, cr) if err != nil { + log.Error("Cache write: io.Copy failed", lga.Err, err) lg.WarnIfCloseError(log, "Close cache body file", cacheFile) return err } diff --git a/libsq/core/ioz/download/testdata/.gitignore b/libsq/core/ioz/download/testdata/.gitignore new file mode 100644 index 000000000..1fc3ed4c1 --- /dev/null +++ b/libsq/core/ioz/download/testdata/.gitignore @@ -0,0 +1 @@ +Test* diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 02998342c..ccc4ccc8c 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -420,15 
+420,15 @@ func WriteToFile(ctx context.Context, fp string, r io.Reader) (written int64, er f, err := os.Create(fp) if err != nil { - return 0, err + return 0, errz.Err(err) } cr := contextio.NewReader(ctx, r) written, err = io.Copy(f, cr) closeErr := f.Close() if err == nil { - return written, closeErr + return written, errz.Err(closeErr) } - return written, err + return written, errz.Err(err) } From 7a9df888d18fdb70a777b9ad1872362858c04326 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 13 Dec 2023 12:18:33 -0700 Subject: [PATCH 105/195] tidying --- libsq/core/ioz/download/download.go | 28 +---- libsq/core/ioz/download/download_test.go | 8 +- libsq/core/ioz/download/httpz.go | 93 ++++++++++++++ libsq/core/ioz/download/respcache.go | 46 +++---- libsq/core/ioz/httpz.go | 13 ++ libsq/source/dl.go | 147 ----------------------- libsq/source/dl_test.go | 79 ------------ libsq/source/download.go | 33 +---- 8 files changed, 138 insertions(+), 309 deletions(-) delete mode 100644 libsq/source/dl.go delete mode 100644 libsq/source/dl_test.go diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index 74f2f95f9..c1ff18d19 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -63,13 +63,6 @@ func OptMarkCacheResponses(markCachedResponses bool) Opt { } } -// OptInsecureSkipVerify configures a Download to skip TLS verification. -func OptInsecureSkipVerify(insecureSkipVerify bool) Opt { - return func(t *Download) { - t.InsecureSkipVerify = insecureSkipVerify - } -} - // OptDisableCaching disables the cache. func OptDisableCaching(disable bool) Opt { return func(t *Download) { @@ -77,16 +70,8 @@ func OptDisableCaching(disable bool) Opt { } } -// OptUserAgent sets the User-Agent header on requests. 
-func OptUserAgent(userAgent string) Opt { - return func(t *Download) { - t.userAgent = userAgent - } -} - -// Download is aan implementation of http.RoundTripper that will return values from a cache -// where possible (avoiding a network request) and will additionally add validators (etag/if-modified-since) -// to repeated requests allowing servers to return 304 / Not Modified +// Download encapsulates downloading a file from a URL, using a local +// disk cache if possible. type Download struct { // FIXME: Does Download need a sync.Mutex? @@ -95,19 +80,12 @@ type Download struct { c *http.Client - // The RoundTripper interface actually used to make requests - // If nil, http.DefaultTransport is used. - transport http.RoundTripper - - // respCache is the cache used to store responses. respCache *RespCache // markCachedResponses, if true, indicates that responses returned from the // cache will be given an extra header, X-From-Cache markCachedResponses bool - InsecureSkipVerify bool - userAgent string disableCaching bool @@ -133,7 +111,6 @@ func New(c *http.Client, dlURL, cacheDir string, opts ...Opt) (*Download, error) url: dlURL, markCachedResponses: true, disableCaching: false, - InsecureSkipVerify: false, } for _, opt := range opts { opt(t) @@ -286,7 +263,6 @@ func (dl *Download) get(req *http.Request, h Handler) { lg.WarnIfCloseError(log, "Close response body", resp.Body) if err = dl.respCache.Write(ctx, resp, true, nil); err != nil { log.Error("Failed to update cache header", lga.Dir, dl.respCache.Dir, lga.Err, err) - // FIXME: Should we error here, or just return the cached file? 
h.Error(err) return } diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index d1d2946f7..d124da9ce 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -27,7 +27,6 @@ const ( urlPaymentLargeCSV = "https://sqio-public.s3.amazonaws.com/testdata/payment-large.gen.csv" urlActorCSV = "https://sq.io/testdata/actor.csv" sizeActorCSV = int64(7641) - sizeGzipActorCSV = int64(1968) ) func TestDownload_redirect(t *testing.T) { @@ -35,13 +34,14 @@ func TestDownload_redirect(t *testing.T) { var serveBody = hello lastModified := time.Now().UTC() //cacheDir := t.TempDir() + // FIXME: switch back to temp dir cacheDir := filepath.Join("testdata", "download", tu.Name(t.Name())) log := slogt.New(t) var srvr *httptest.Server srvr = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log := log.With("origin", "server") - log.Info("Serving actual") + log.Info("Request on /actual", "req", download.RequestLogValue(r)) switch r.URL.Path { case "/redirect": loc := srvr.URL + "/actual" @@ -81,7 +81,7 @@ func TestDownload_redirect(t *testing.T) { ctx := lg.NewContext(context.Background(), log.With("origin", "downloader")) loc := srvr.URL + "/redirect" - dl, err := download.New(nil, loc, cacheDir) + dl, err := download.New(ioz.NewDefaultHTTPClient(), loc, cacheDir) require.NoError(t, err) require.NoError(t, dl.Clear(ctx)) h := newTestHandler(log.With("origin", "handler")) @@ -148,7 +148,7 @@ func TestDownload(t *testing.T) { require.NoError(t, err) t.Logf("cacheDir: %s", cacheDir) - dl, err := download.New(nil, dlURL, cacheDir, download.OptUserAgent("sq/dev")) + dl, err := download.New(nil, dlURL, cacheDir) require.NoError(t, err) require.NoError(t, dl.Clear(ctx)) diff --git a/libsq/core/ioz/download/httpz.go b/libsq/core/ioz/download/httpz.go index 84c4134bd..ed18d5c57 100644 --- a/libsq/core/ioz/download/httpz.go +++ b/libsq/core/ioz/download/httpz.go @@ -5,9 +5,14 
@@ import ( "bytes" "errors" "fmt" + "github.com/neilotoole/sq/libsq/core/stringz" "io" + "log/slog" + "mime" "net/http" "net/textproto" + "path" + "path/filepath" "strconv" "strings" "time" @@ -383,3 +388,91 @@ func varyMatches(cachedResp *http.Response, req *http.Request) bool { } return true } + +// ResponseLogValue implements slog.Valuer for resp. +func ResponseLogValue(resp *http.Response) slog.Value { + if resp == nil { + return slog.Value{} + } + + attrs := []slog.Attr{ + slog.String("proto", resp.Proto), + slog.String("status", resp.Status), + } + + h := resp.Header + for k, _ := range h { + vals := h.Values(k) + if len(vals) == 1 { + attrs = append(attrs, slog.String(k, vals[0])) + continue + } + + attrs = append(attrs, slog.Any(k, h.Get(k))) + } + + if resp.Request != nil { + attrs = append(attrs, slog.Any("req", RequestLogValue(resp.Request))) + } + + return slog.GroupValue(attrs...) +} + +// RequestLogValue implements slog.Valuer for req. +func RequestLogValue(req *http.Request) slog.Value { + if req == nil { + return slog.Value{} + } + + attrs := []slog.Attr{ + slog.String("method", req.Method), + slog.String("path", req.URL.RawPath), + } + + if req.Proto != "" { + attrs = append(attrs, slog.String("proto", req.Proto)) + } + if req.Host != "" { + attrs = append(attrs, slog.String("host", req.Host)) + } + + h := req.Header + for k, _ := range h { + vals := h.Values(k) + if len(vals) == 1 { + attrs = append(attrs, slog.String(k, vals[0])) + continue + } + + attrs = append(attrs, slog.Any(k, h.Get(k))) + } + + return slog.GroupValue(attrs...) +} + +// Filename returns the filename to use for a download. +// It first checks the Content-Disposition header, and if that's +// not present, it uses the last path segment of the URL. The +// filename is sanitized. +// It's possible that the returned value will be empty string; the +// caller should handle that situation themselves. 
+func Filename(resp *http.Response) string { + var filename string + if resp == nil || resp.Header == nil { + return "" + } + dispHeader := resp.Header.Get("Content-Disposition") + if dispHeader != "" { + if _, params, err := mime.ParseMediaType(dispHeader); err == nil { + filename = params["filename"] + } + } + + if filename == "" { + filename = path.Base(resp.Request.URL.Path) + } else { + filename = filepath.Base(filename) + } + + return stringz.SanitizeFilename(filename) +} diff --git a/libsq/core/ioz/download/respcache.go b/libsq/core/ioz/download/respcache.go index c0c8f39d8..ce1abbee3 100644 --- a/libsq/core/ioz/download/respcache.go +++ b/libsq/core/ioz/download/respcache.go @@ -4,8 +4,9 @@ import ( "bufio" "bytes" "context" - "fmt" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/neilotoole/sq/libsq/core/lg/lgm" "io" "net/http" "net/http/httputil" @@ -138,31 +139,27 @@ func (rc *RespCache) doClear(ctx context.Context) error { const msgDeleteCache = "Delete HTTP response cache" -// Write writes resp to the cache. If headerOnly is true, only the header -// cache file is updated. If headerOnly is false and copyWrtr is non-nil, the -// response body bytes are copied to that destination, as well as being -// written to the cache. +// Write writes resp header and body to the cache. If headerOnly is true, only +// the header cache file is updated. If headerOnly is false and copyWrtr is +// non-nil, the response body bytes are copied to that destination, as well as +// being written to the cache. The response body is always closed. 
func (rc *RespCache) Write(ctx context.Context, resp *http.Response, headerOnly bool, copyWrtr io.WriteCloser) error { rc.mu.Lock() defer rc.mu.Unlock() - err := rc.doWrite(ctx, resp, headerOnly, copyWrtr) - if err != nil { - lg.FromContext(ctx).Error("Failed to write HTTP response to cache", lga.Dir, rc.Dir, lga.Err, err) - - //lg.WarnIfError(lg.FromContext(ctx), msgDeleteCache, rc.doClear(ctx)) - } - return err + return rc.doWrite(ctx, resp, headerOnly, copyWrtr) } -func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response, headerOnly bool, copyWrtr io.WriteCloser) error { +func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response, + headerOnly bool, copyWrtr io.WriteCloser) error { log := lg.FromContext(ctx) + defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) if err := ioz.RequireDir(rc.Dir); err != nil { return err } - log.Info("Writing HTTP response to cache", lga.Dir, rc.Dir, "resp", fmt.Sprintf("%v", *resp)) + log.Debug("Writing HTTP response to cache", lga.Dir, rc.Dir, "resp", ResponseLogValue(resp)) fpHeader, fpBody := rc.Paths(resp.Request) headerBytes, err := httputil.DumpResponse(resp, false) @@ -191,9 +188,6 @@ func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response, headerOnl cr = contextio.NewReader(ctx, tr) } - //if copyWrtr != nil { - // cr = io.TeeReader(cr, copyWrtr) - //} var written int64 written, err = io.Copy(cacheFile, cr) if err != nil { @@ -206,15 +200,23 @@ func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response, headerOnl } if err = cacheFile.Close(); err != nil { - return err + return errz.Err(err) } - log.Info("Wrote HTTP response to cache", lga.File, fpBody, lga.Size, written) - cacheFile, err = os.Open(fpBody) + sum, err := checksum.ForFile(fpBody) if err != nil { - return err + return errz.Wrap(err, "failed to compute checksum for cache body file") + } + + if err = checksum.WriteFile(filepath.Join(rc.Dir, "checksum.txt"), sum, "body"); err != nil { + return 
errz.Wrap(err, "failed to write checksum file for cache body") + } + + if resp.Body == nil { + resp.Body = http.NoBody + return nil } - resp.Body = cacheFile + log.Info("Wrote HTTP response body to cache", lga.Size, written, lga.File, fpBody) return nil } diff --git a/libsq/core/ioz/httpz.go b/libsq/core/ioz/httpz.go index f66736f83..173d66f87 100644 --- a/libsq/core/ioz/httpz.go +++ b/libsq/core/ioz/httpz.go @@ -3,12 +3,25 @@ package ioz import ( "context" "crypto/tls" + "github.com/neilotoole/sq/cli/buildinfo" "net/http" + "strings" "time" "github.com/neilotoole/sq/libsq/core/errz" ) +// NewDefaultHTTPClient returns a new HTTP client with default settings. +func NewDefaultHTTPClient() *http.Client { + v := buildinfo.Get().Version + v = strings.TrimPrefix(v, "v") + if v != "" { + v = "sq/" + v + } + + return NewHTTPClient(v, true, 0, 0) +} + // NewHTTPClient returns a new HTTP client. If userAgent is non-empty, the // "User-Agent" header is applied to each request. If insecureSkipVerify is // true, the client will skip TLS verification. If headerTimeout > 0, a diff --git a/libsq/source/dl.go b/libsq/source/dl.go deleted file mode 100644 index a8b9ab18a..000000000 --- a/libsq/source/dl.go +++ /dev/null @@ -1,147 +0,0 @@ -package source - -import ( - "context" - "io" - "log/slog" - "net/http" - "sync" - - "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/neilotoole/sq/libsq/core/ioz/contextio" - "github.com/neilotoole/sq/libsq/core/ioz/download" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lga" - "github.com/neilotoole/sq/libsq/core/lg/lgm" -) - -// newDownloader creates a new downloader using cacheDir for the given url. 
-func newDownloader2(cacheDir, userAgent, dlURL string) (*downloader2, error) { - //dv := diskv.New(diskv.Options{ - // BasePath: filepath.Join(cacheDir, "cache"), - // TempDir: filepath.Join(cacheDir, "working"), - // CacheSizeMax: 10000 * 1024 * 1024, // 10000MB - //}) - if err := ioz.RequireDir(cacheDir); err != nil { - return nil, err - } - - // dc := diskcache.NewWithDiskv(dv) - rc := download.NewRespCache(cacheDir) - tp := download.New(nil, rc) - - // respCache := download.NewRespCache(cacheDir) - // tp.RespCache = respCache - // tp.BodyFilepath = filepath.Join(cacheDir, "body.data") - - c := &http.Client{Transport: tp} - - return &downloader2{ - c: c, - // dc: dc, - // dv: dv, - cacheDir: cacheDir, - url: dlURL, - userAgent: userAgent, - tp: tp, - }, nil -} - -type downloader2 struct { - c *http.Client - mu sync.Mutex - userAgent string - cacheDir string - url string - tp *download.Download -} - -func (d *downloader2) log(log *slog.Logger) *slog.Logger { - return log.With(lga.URL, d.url, lga.Dir, d.cacheDir) -} - -// ClearCache clears the cache dir. 
-func (d *downloader2) ClearCache(ctx context.Context) error { - d.mu.Lock() - defer d.mu.Unlock() - - if err := d.tp.Clear(ctx); err != nil { - return errz.Wrapf(err, "failed to clear cache dir: %s", d.cacheDir) - } - - return ioz.RequireDir(d.cacheDir) -} - -func (d *downloader2) Download(ctx context.Context, dest io.Writer) (written int64, fp string, err error) { - d.mu.Lock() - defer d.mu.Unlock() - - log := d.log(lg.FromContext(ctx)) - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, nil) - if d.userAgent != "" { - req.Header.Set("User-Agent", d.userAgent) - } - - resp, err := d.c.Do(req) - if err != nil { - return written, "", errz.Wrapf(err, "download failed for: %s", d.url) - } - defer func() { - if resp != nil && resp.Body != nil { - lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) - } - }() - - written, err = io.Copy( - contextio.NewWriter(ctx, dest), - contextio.NewReader(ctx, resp.Body), - ) - - return written, "", err -} - -func (d *downloader2) Download2(ctx context.Context, dest io.Writer) (written int64, fp string, err error) { - d.mu.Lock() - defer d.mu.Unlock() - - log := d.log(lg.FromContext(ctx)) - _ = log - - var destWrtr io.WriteCloser - var ok bool - if destWrtr, ok = dest.(io.WriteCloser); !ok { - destWrtr = ioz.WriteCloser(dest) - } - - log.Debug("huzzah Download2") - req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, nil) - if d.userAgent != "" { - req.Header.Set("User-Agent", d.userAgent) - } - - var gotFp string - var gotErr error - // buf := &bytes.Buffer{} - cb := download.Handler{ - Cached: func(cachedFilepath string) error { - gotFp = cachedFilepath - return nil - }, - Uncached: func() (wc io.WriteCloser, errFn func(error), err error) { - return destWrtr, func(err error) { - gotErr = err - }, nil - }, - Error: func(err error) { - gotErr = err - }, - } - - d.tp.fetchWith(req, cb) - _ = gotFp - _ = gotErr - - return written, "", err -} diff --git a/libsq/source/dl_test.go 
b/libsq/source/dl_test.go deleted file mode 100644 index 7326c0780..000000000 --- a/libsq/source/dl_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package source - -import ( - "bytes" - "context" - "net/url" - "path" - "path/filepath" - "testing" - - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/stretchr/testify/require" -) - -func TestDownloader2_Download(t *testing.T) { - log := slogt.New(t) - ctx := lg.NewContext(context.Background(), log) - const dlURL = urlActorCSV - const wantContentLength = sizeActorCSV - u, err := url.Parse(dlURL) - require.NoError(t, err) - wantFilename := path.Base(u.Path) - require.Equal(t, "actor.csv", wantFilename) - - cacheDir, err := filepath.Abs(filepath.Join("testdata", "downloader", "cache-dir-2")) - require.NoError(t, err) - t.Logf("cacheDir: %s", cacheDir) - - log.Debug("huzzah") - dl, err := newDownloader2(cacheDir, "sq/dev", dlURL) - require.NoError(t, err) - // require.NoError(t, dl.ClearCache(ctx)) - - buf := &bytes.Buffer{} - written, cachedFp, err := dl.Download2(ctx, buf) - _ = written - _ = cachedFp - require.NoError(t, err) - // require.Equal(t, wantContentLength, written) - // require.Equal(t, wantContentLength, int64(buf.Len())) - - buf.Reset() - written, cachedFp, err = dl.Download2(ctx, buf) - require.NoError(t, err) - // require.Equal(t, wantContentLength, written) - // require.Equal(t, wantContentLength, int64(buf.Len())) -} - -func TestDownloader2_Download_Legacy(t *testing.T) { - ctx := lg.NewContext(context.Background(), slogt.New(t)) - const dlURL = urlActorCSV - const wantContentLength = sizeActorCSV - u, err := url.Parse(dlURL) - require.NoError(t, err) - wantFilename := path.Base(u.Path) - require.Equal(t, "actor.csv", wantFilename) - - cacheDir, err := filepath.Abs(filepath.Join("testdata", "downloader", "cache-dir-2")) - require.NoError(t, err) - t.Logf("cacheDir: %s", cacheDir) - - dl, err := newDownloader2(cacheDir, "sq/dev", dlURL) - require.NoError(t, err) - 
// require.NoError(t, dl.ClearCache(ctx)) - - buf := &bytes.Buffer{} - written, cachedFp, err := dl.Download2(ctx, buf) - _ = cachedFp - require.NoError(t, err) - require.Equal(t, wantContentLength, written) - require.Equal(t, wantContentLength, int64(buf.Len())) - - buf.Reset() - written, cachedFp, err = dl.Download(ctx, buf) - require.NoError(t, err) - require.Equal(t, wantContentLength, written) - require.Equal(t, wantContentLength, int64(buf.Len())) -} diff --git a/libsq/source/download.go b/libsq/source/download.go index 27ea099f7..a1b64aaed 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -3,14 +3,13 @@ package source import ( "bytes" "context" + "github.com/neilotoole/sq/libsq/core/ioz/download" "io" "log/slog" - "mime" "net/http" "net/http/httputil" "net/url" "os" - "path" "path/filepath" "sync" "time" @@ -25,7 +24,6 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" - "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/source/fetcher" ) @@ -176,7 +174,7 @@ func (d *downloader) Download(ctx context.Context, dest io.Writer) (written int6 return written, "", errz.Errorf("download failed with %s for %s", resp.Status, d.url) } - filename := getDownloadFilename(resp) + filename := download.Filename(resp) if filename == "" { filename = "download" } @@ -420,30 +418,3 @@ func (fs *Files) fetch(ctx context.Context, loc string) (fpath string, err error return dlFile.Name(), nil } - -// getDownloadFilename returns the filename to use for a download. -// It first checks the Content-Disposition header, and if that's -// not present, it uses the last path segment of the URL. The -// filename is sanitized. -// It's possible that the returned value will be empty string; the -// caller should handle that situation themselves. 
-func getDownloadFilename(resp *http.Response) string { - var filename string - if resp == nil || resp.Header == nil { - return "" - } - dispHeader := resp.Header.Get("Content-Disposition") - if dispHeader != "" { - if _, params, err := mime.ParseMediaType(dispHeader); err == nil { - filename = params["filename"] - } - } - - if filename == "" { - filename = path.Base(resp.Request.URL.Path) - } else { - filename = filepath.Base(filename) - } - - return stringz.SanitizeFilename(filename) -} From 389719b0e3876ba28297770018b08d0a001b53d2 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 13 Dec 2023 12:42:37 -0700 Subject: [PATCH 106/195] Created httpz pkg --- cli/buildinfo/buildinfo.go | 42 +++- cli/output/jsonw/versionwriter.go | 2 +- cli/output/tablew/versionwriter.go | 2 +- cli/output/writers.go | 2 +- cli/output/yamlw/versionwriter.go | 2 +- libsq/core/ioz/download/download.go | 9 +- libsq/core/ioz/download/download_test.go | 5 +- libsq/core/ioz/download/httpz.go | 192 --------------- libsq/core/ioz/download/respcache.go | 3 +- libsq/core/ioz/httpz.go | 137 ----------- libsq/core/ioz/httpz/httpz.go | 298 +++++++++++++++++++++++ libsq/core/ioz/{ => httpz}/httpz_test.go | 6 +- libsq/source/download.go | 4 +- 13 files changed, 347 insertions(+), 357 deletions(-) delete mode 100644 libsq/core/ioz/httpz.go create mode 100644 libsq/core/ioz/httpz/httpz.go rename libsq/core/ioz/{ => httpz}/httpz_test.go (96%) diff --git a/cli/buildinfo/buildinfo.go b/cli/buildinfo/buildinfo.go index 29a4730c9..28613d1ba 100644 --- a/cli/buildinfo/buildinfo.go +++ b/cli/buildinfo/buildinfo.go @@ -33,15 +33,15 @@ var ( Timestamp string ) -// BuildInfo encapsulates Version, Commit and Timestamp. -type BuildInfo struct { +// Info encapsulates Version, Commit and Timestamp. 
+type Info struct { Version string `json:"version" yaml:"version"` Commit string `json:"commit,omitempty" yaml:"commit,omitempty"` Timestamp time.Time `json:"timestamp,omitempty" yaml:"timestamp,omitempty"` } -// String returns a string representation of BuildInfo. -func (bi BuildInfo) String() string { +// String returns a string representation of Info. +func (bi Info) String() string { s := bi.Version if bi.Commit != "" { s += " " + bi.Commit @@ -52,8 +52,30 @@ func (bi BuildInfo) String() string { return s } +// UserAgent returns a string suitable for use in an HTTP User-Agent header. +func (bi Info) UserAgent() string { + if bi.Version == "" { + return "sq/0.0.0-dev" + } + + ua := "sq/" + strings.TrimPrefix(bi.Version, "v") + return ua +} + +// ShortCommit returns the short commit hash. +func (bi Info) ShortCommit() string { + switch { + case bi.Commit == "": + return "" + case len(bi.Commit) > 7: + return bi.Commit[:7] + default: + return bi.Commit + } +} + // LogValue implements slog.LogValuer. -func (bi BuildInfo) LogValue() slog.Value { +func (bi Info) LogValue() slog.Value { gv := slog.GroupValue( slog.String(lga.Version, bi.Version), slog.String(lga.Commit, bi.Commit), @@ -62,9 +84,9 @@ func (bi BuildInfo) LogValue() slog.Value { return gv } -// Get returns BuildInfo. If buildinfo.Timestamp cannot be parsed, -// the returned BuildInfo.Timestamp will be the zero value. -func Get() BuildInfo { +// Get returns Info. If buildinfo.Timestamp cannot be parsed, +// the returned Info.Timestamp will be the zero value. +func Get() Info { var t time.Time if Timestamp != "" { got, err := timez.ParseTimestampUTC(Timestamp) @@ -73,7 +95,7 @@ func Get() BuildInfo { } } - return BuildInfo{ + return Info{ Version: Version, Commit: Commit, Timestamp: t, @@ -88,7 +110,7 @@ func init() { //nolint:gochecknoinits if Version != "" && !semver.IsValid(Version) { // We want to panic here because it is a pipeline/build failure // to have an invalid non-empty Version. 
- panic(fmt.Sprintf("Invalid BuildInfo.Version value: %s", Version)) + panic(fmt.Sprintf("Invalid Info.Version value: %s", Version)) } if Timestamp != "" { diff --git a/cli/output/jsonw/versionwriter.go b/cli/output/jsonw/versionwriter.go index ac92a7325..dcd976d54 100644 --- a/cli/output/jsonw/versionwriter.go +++ b/cli/output/jsonw/versionwriter.go @@ -23,7 +23,7 @@ func NewVersionWriter(out io.Writer, pr *output.Printing) output.VersionWriter { } // Version implements output.VersionWriter. -func (w *versionWriter) Version(bi buildinfo.BuildInfo, latestVersion string, hi hostinfo.Info) error { +func (w *versionWriter) Version(bi buildinfo.Info, latestVersion string, hi hostinfo.Info) error { // We use a custom struct so that we can // control the timestamp format. type cliBuildInfo struct { diff --git a/cli/output/tablew/versionwriter.go b/cli/output/tablew/versionwriter.go index 98a75323c..7513154b2 100644 --- a/cli/output/tablew/versionwriter.go +++ b/cli/output/tablew/versionwriter.go @@ -26,7 +26,7 @@ func NewVersionWriter(out io.Writer, pr *output.Printing) output.VersionWriter { } // Version implements output.VersionWriter. -func (w *versionWriter) Version(bi buildinfo.BuildInfo, latestVersion string, hi hostinfo.Info) error { +func (w *versionWriter) Version(bi buildinfo.Info, latestVersion string, hi hostinfo.Info) error { var newerAvailable bool if latestVersion != "" { diff --git a/cli/output/writers.go b/cli/output/writers.go index 345b1690c..7c715f508 100644 --- a/cli/output/writers.go +++ b/cli/output/writers.go @@ -120,7 +120,7 @@ type VersionWriter interface { // Version prints version info. Arg latestVersion is the latest // version available from the homebrew repository. The value // may be empty. - Version(bi buildinfo.BuildInfo, latestVersion string, si hostinfo.Info) error + Version(bi buildinfo.Info, latestVersion string, si hostinfo.Info) error } // ConfigWriter prints config. 
diff --git a/cli/output/yamlw/versionwriter.go b/cli/output/yamlw/versionwriter.go index 26fe2fffb..ed1eaf561 100644 --- a/cli/output/yamlw/versionwriter.go +++ b/cli/output/yamlw/versionwriter.go @@ -26,7 +26,7 @@ func NewVersionWriter(out io.Writer, pr *output.Printing) output.VersionWriter { } // Version implements output.VersionWriter. -func (w *versionWriter) Version(bi buildinfo.BuildInfo, latestVersion string, hi hostinfo.Info) error { +func (w *versionWriter) Version(bi buildinfo.Info, latestVersion string, hi hostinfo.Info) error { type cliBuildInfo struct { Version string `json:"version" yaml:"version"` Commit string `json:"commit,omitempty" yaml:"commit,omitempty"` diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index c1ff18d19..ca38b2863 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -12,6 +12,7 @@ import ( "bufio" "context" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" "io" "net/http" "net/url" @@ -86,8 +87,6 @@ type Download struct { // cache will be given an extra header, X-From-Cache markCachedResponses bool - userAgent string - disableCaching bool } @@ -335,9 +334,7 @@ func (dl *Download) newRequest(ctx context.Context, url string) (*http.Request, lg.FromContext(ctx).Error("Failed to create request", lga.URL, url, lga.Err, err) return nil, err } - if dl.userAgent != "" { - req.Header.Set("User-Agent", dl.userAgent) - } + return req, nil } @@ -379,7 +376,7 @@ func (dl *Download) state(req *http.Request) State { defer lg.WarnIfCloseError(log, "Close cached response header file", f) - cachedResp, err := readResponseHeader(bufio.NewReader(f), nil) + cachedResp, err := httpz.ReadResponseHeader(bufio.NewReader(f), nil) if err != nil { log.Error("Failed to read cached response header", lga.Err, err) return Uncached diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index 
d124da9ce..42dac35f3 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -3,6 +3,7 @@ package download_test import ( "bytes" "context" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/testh/tu" "github.com/stretchr/testify/assert" @@ -41,7 +42,7 @@ func TestDownload_redirect(t *testing.T) { var srvr *httptest.Server srvr = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log := log.With("origin", "server") - log.Info("Request on /actual", "req", download.RequestLogValue(r)) + log.Info("Request on /actual", "req", httpz.RequestLogValue(r)) switch r.URL.Path { case "/redirect": loc := srvr.URL + "/actual" @@ -81,7 +82,7 @@ func TestDownload_redirect(t *testing.T) { ctx := lg.NewContext(context.Background(), log.With("origin", "downloader")) loc := srvr.URL + "/redirect" - dl, err := download.New(ioz.NewDefaultHTTPClient(), loc, cacheDir) + dl, err := download.New(httpz.NewDefaultClient(), loc, cacheDir) require.NoError(t, err) require.NoError(t, dl.Clear(ctx)) h := newTestHandler(log.With("origin", "handler")) diff --git a/libsq/core/ioz/download/httpz.go b/libsq/core/ioz/download/httpz.go index ed18d5c57..b7fc01f5e 100644 --- a/libsq/core/ioz/download/httpz.go +++ b/libsq/core/ioz/download/httpz.go @@ -4,86 +4,11 @@ import ( "bufio" "bytes" "errors" - "fmt" - "github.com/neilotoole/sq/libsq/core/stringz" - "io" - "log/slog" - "mime" "net/http" - "net/textproto" - "path" - "path/filepath" - "strconv" "strings" "time" ) -// readResponseHeader is a fork of http.ReadResponse that reads only the -// header from req and not the body. Note that resp.Body will be nil, and -// that the resp object is borked for general use. 
-func readResponseHeader(r *bufio.Reader, req *http.Request) (resp *http.Response, err error) { - tp := textproto.NewReader(r) - resp = &http.Response{Request: req} - - // Parse the first line of the response. - line, err := tp.ReadLine() - if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return nil, err - } - proto, status, ok := strings.Cut(line, " ") - if !ok { - return nil, badStringError("malformed HTTP response", line) - } - resp.Proto = proto - resp.Status = strings.TrimLeft(status, " ") - - statusCode, _, _ := strings.Cut(resp.Status, " ") - if len(statusCode) != 3 { - return nil, badStringError("malformed HTTP status code", statusCode) - } - resp.StatusCode, err = strconv.Atoi(statusCode) - if err != nil || resp.StatusCode < 0 { - return nil, badStringError("malformed HTTP status code", statusCode) - } - if resp.ProtoMajor, resp.ProtoMinor, ok = http.ParseHTTPVersion(resp.Proto); !ok { - return nil, badStringError("malformed HTTP version", resp.Proto) - } - - // Parse the response headers. - mimeHeader, err := tp.ReadMIMEHeader() - if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return nil, err - } - resp.Header = http.Header(mimeHeader) - - fixPragmaCacheControl(resp.Header) - - return resp, nil -} - -// RFC 7234, section 5.4: Should treat -// -// Pragma: no-cache -// -// like -// -// Cache-Control: no-cache -func fixPragmaCacheControl(header http.Header) { - if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" { - if _, presentcc := header["Cache-Control"]; !presentcc { - header["Cache-Control"] = []string{"no-cache"} - } - } -} - -func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } - // errNoDateHeader indicates that the HTTP headers contained no Date header. 
var errNoDateHeader = errors.New("no Date header") @@ -348,35 +273,6 @@ func headerAllCommaSepValues(headers http.Header, name string) []string { return vals } -// cachingReadCloser is a wrapper around ReadCloser R that calls OnEOF -// handler with a full copy of the content read from R when EOF is -// reached. -type cachingReadCloser struct { - // Underlying ReadCloser. - R io.ReadCloser - // OnEOF is called with a copy of the content of R when EOF is reached. - OnEOF func(io.Reader) - - buf bytes.Buffer // buf stores a copy of the content of R. -} - -// Read reads the next len(p) bytes from R or until R is drained. The -// return value n is the number of bytes read. If R has no data to -// return, err is io.EOF and OnEOF is called with a full copy of what -// has been read so far. -func (r *cachingReadCloser) Read(p []byte) (n int, err error) { - n, err = r.R.Read(p) - r.buf.Write(p[:n]) - if err == io.EOF { - r.OnEOF(bytes.NewReader(r.buf.Bytes())) - } - return n, err -} - -func (r *cachingReadCloser) Close() error { - return r.R.Close() -} - // varyMatches will return false unless all the cached values for the // headers listed in Vary match the new request func varyMatches(cachedResp *http.Response, req *http.Request) bool { @@ -388,91 +284,3 @@ func varyMatches(cachedResp *http.Response, req *http.Request) bool { } return true } - -// ResponseLogValue implements slog.Valuer for resp. -func ResponseLogValue(resp *http.Response) slog.Value { - if resp == nil { - return slog.Value{} - } - - attrs := []slog.Attr{ - slog.String("proto", resp.Proto), - slog.String("status", resp.Status), - } - - h := resp.Header - for k, _ := range h { - vals := h.Values(k) - if len(vals) == 1 { - attrs = append(attrs, slog.String(k, vals[0])) - continue - } - - attrs = append(attrs, slog.Any(k, h.Get(k))) - } - - if resp.Request != nil { - attrs = append(attrs, slog.Any("req", RequestLogValue(resp.Request))) - } - - return slog.GroupValue(attrs...) 
-} - -// RequestLogValue implements slog.Valuer for req. -func RequestLogValue(req *http.Request) slog.Value { - if req == nil { - return slog.Value{} - } - - attrs := []slog.Attr{ - slog.String("method", req.Method), - slog.String("path", req.URL.RawPath), - } - - if req.Proto != "" { - attrs = append(attrs, slog.String("proto", req.Proto)) - } - if req.Host != "" { - attrs = append(attrs, slog.String("host", req.Host)) - } - - h := req.Header - for k, _ := range h { - vals := h.Values(k) - if len(vals) == 1 { - attrs = append(attrs, slog.String(k, vals[0])) - continue - } - - attrs = append(attrs, slog.Any(k, h.Get(k))) - } - - return slog.GroupValue(attrs...) -} - -// Filename returns the filename to use for a download. -// It first checks the Content-Disposition header, and if that's -// not present, it uses the last path segment of the URL. The -// filename is sanitized. -// It's possible that the returned value will be empty string; the -// caller should handle that situation themselves. 
-func Filename(resp *http.Response) string { - var filename string - if resp == nil || resp.Header == nil { - return "" - } - dispHeader := resp.Header.Get("Content-Disposition") - if dispHeader != "" { - if _, params, err := mime.ParseMediaType(dispHeader); err == nil { - filename = params["filename"] - } - } - - if filename == "" { - filename = path.Base(resp.Request.URL.Path) - } else { - filename = filepath.Base(filename) - } - - return stringz.SanitizeFilename(filename) -} diff --git a/libsq/core/ioz/download/respcache.go b/libsq/core/ioz/download/respcache.go index ce1abbee3..52b9ee0b0 100644 --- a/libsq/core/ioz/download/respcache.go +++ b/libsq/core/ioz/download/respcache.go @@ -6,6 +6,7 @@ import ( "context" "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/sq/libsq/core/lg/lgm" "io" "net/http" @@ -159,7 +160,7 @@ func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response, return err } - log.Debug("Writing HTTP response to cache", lga.Dir, rc.Dir, "resp", ResponseLogValue(resp)) + log.Debug("Writing HTTP response to cache", lga.Dir, rc.Dir, "resp", httpz.ResponseLogValue(resp)) fpHeader, fpBody := rc.Paths(resp.Request) headerBytes, err := httputil.DumpResponse(resp, false) diff --git a/libsq/core/ioz/httpz.go b/libsq/core/ioz/httpz.go deleted file mode 100644 index 173d66f87..000000000 --- a/libsq/core/ioz/httpz.go +++ /dev/null @@ -1,137 +0,0 @@ -package ioz - -import ( - "context" - "crypto/tls" - "github.com/neilotoole/sq/cli/buildinfo" - "net/http" - "strings" - "time" - - "github.com/neilotoole/sq/libsq/core/errz" -) - -// NewDefaultHTTPClient returns a new HTTP client with default settings. 
-func NewDefaultHTTPClient() *http.Client { - v := buildinfo.Get().Version - v = strings.TrimPrefix(v, "v") - if v != "" { - v = "sq/" + v - } - - return NewHTTPClient(v, true, 0, 0) -} - -// NewHTTPClient returns a new HTTP client. If userAgent is non-empty, the -// "User-Agent" header is applied to each request. If insecureSkipVerify is -// true, the client will skip TLS verification. If headerTimeout > 0, a -// timeout is applied to receiving the HTTP response, but that timeout is -// not applied to reading the response body. This is useful if you expect -// a response within, say, 5 seconds, but you expect the body to take longer -// to read. If bodyTimeout > 0, it is applied to the total lifecycle of -// the request and response, including reading the response body. -func NewHTTPClient(userAgent string, insecureSkipVerify bool, - headerTimeout, bodyTimeout time.Duration, -) *http.Client { - c := *http.DefaultClient - var tr *http.Transport - if c.Transport == nil { - tr = (http.DefaultTransport.(*http.Transport)).Clone() - } else { - tr = (c.Transport.(*http.Transport)).Clone() - } - - if tr.TLSClientConfig == nil { - // We allow tls.VersionTLS10, even though it's not considered - // secure these days. Ultimately this could become a config - // option. - tr.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS10} //nolint:gosec - } else { - tr.TLSClientConfig = tr.TLSClientConfig.Clone() - tr.TLSClientConfig.MinVersion = tls.VersionTLS10 //nolint:gosec - } - - tr.TLSClientConfig.InsecureSkipVerify = insecureSkipVerify - c.Transport = tr - if userAgent != "" { - c.Transport = &userAgentRoundTripper{ - userAgent: userAgent, - rt: c.Transport, - } - } - - c.Timeout = bodyTimeout - if headerTimeout > 0 { - c.Transport = &headerTimeoutRoundTripper{ - headerTimeout: headerTimeout, - rt: c.Transport, - } - } - - return &c -} - -// userAgentRoundTripper applies a User-Agent header to each request. 
-type userAgentRoundTripper struct { - userAgent string - rt http.RoundTripper -} - -// RoundTrip implements http.RoundTripper. -func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - req.Header.Set("User-Agent", rt.userAgent) - return rt.rt.RoundTrip(req) -} - -// headerTimeoutRoundTripper applies headerTimeout to the return of the http -// response, but headerTimeout is not applied to reading the body of the -// response. This is useful if you expect a response within, say, 5 seconds, -// but you expect the body to take longer to read. -type headerTimeoutRoundTripper struct { - headerTimeout time.Duration - rt http.RoundTripper -} - -// RoundTrip implements http.RoundTripper. -func (rt *headerTimeoutRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if rt.headerTimeout <= 0 { - return rt.rt.RoundTrip(req) - } - - timerCancelCh := make(chan struct{}) - ctx, cancelFn := context.WithCancelCause(req.Context()) - go func() { - t := time.NewTimer(rt.headerTimeout) - defer t.Stop() - select { - case <-ctx.Done(): - case <-t.C: - cancelFn(errz.Errorf("http response not received by %d timeout", - rt.headerTimeout)) - case <-timerCancelCh: - // Stop the timer goroutine. - } - }() - - resp, err := rt.rt.RoundTrip(req.WithContext(ctx)) - close(timerCancelCh) - - // Don't leak resources; ensure that cancelFn is eventually called. - switch { - case err != nil: - - // It's possible that cancelFn has already been called by the - // timer goroutine, but we call it again just in case. - cancelFn(err) - case resp != nil && resp.Body != nil: - - // Wrap resp.Body with a ReadCloserNotifier, so that cancelFn - // is called when the body is closed. - resp.Body = ReadCloserNotifier(resp.Body, cancelFn) - default: - // Not sure if this can actually happen, but just in case. 
- cancelFn(context.Canceled) - } - - return resp, err -} diff --git a/libsq/core/ioz/httpz/httpz.go b/libsq/core/ioz/httpz/httpz.go new file mode 100644 index 000000000..faff1a9c2 --- /dev/null +++ b/libsq/core/ioz/httpz/httpz.go @@ -0,0 +1,298 @@ +// Package httpz provides functionality supplemental to stdlib http. +// Indeed, some of the functions are copied verbatim from stdlib. +package httpz + +import ( + "bufio" + "context" + "crypto/tls" + "fmt" + "github.com/neilotoole/sq/cli/buildinfo" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/stringz" + "io" + "log/slog" + "mime" + "net/http" + "net/textproto" + "path" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/neilotoole/sq/libsq/core/errz" +) + +// NewDefaultClient returns a new HTTP client with default settings. +func NewDefaultClient() *http.Client { + return NewClient(buildinfo.Get().UserAgent(), true, 0, 0) +} + +// NewClient returns a new HTTP client. If userAgent is non-empty, the +// "User-Agent" header is applied to each request. If insecureSkipVerify is +// true, the client will skip TLS verification. If headerTimeout > 0, a +// timeout is applied to receiving the HTTP response, but that timeout is +// not applied to reading the response body. This is useful if you expect +// a response within, say, 5 seconds, but you expect the body to take longer +// to read. If bodyTimeout > 0, it is applied to the total lifecycle of +// the request and response, including reading the response body. +func NewClient(userAgent string, insecureSkipVerify bool, + headerTimeout, bodyTimeout time.Duration, +) *http.Client { + c := *http.DefaultClient + var tr *http.Transport + if c.Transport == nil { + tr = (http.DefaultTransport.(*http.Transport)).Clone() + } else { + tr = (c.Transport.(*http.Transport)).Clone() + } + + if tr.TLSClientConfig == nil { + // We allow tls.VersionTLS10, even though it's not considered + // secure these days. 
Ultimately this could become a config + // option. + tr.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS10} //nolint:gosec + } else { + tr.TLSClientConfig = tr.TLSClientConfig.Clone() + tr.TLSClientConfig.MinVersion = tls.VersionTLS10 //nolint:gosec + } + + tr.TLSClientConfig.InsecureSkipVerify = insecureSkipVerify + c.Transport = tr + if userAgent != "" { + c.Transport = &userAgentRoundTripper{ + userAgent: userAgent, + rt: c.Transport, + } + } + + c.Timeout = bodyTimeout + if headerTimeout > 0 { + c.Transport = &headerTimeoutRoundTripper{ + headerTimeout: headerTimeout, + rt: c.Transport, + } + } + + return &c +} + +// userAgentRoundTripper applies a User-Agent header to each request. +type userAgentRoundTripper struct { + userAgent string + rt http.RoundTripper +} + +// RoundTrip implements http.RoundTripper. +func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + req.Header.Set("User-Agent", rt.userAgent) + return rt.rt.RoundTrip(req) +} + +// headerTimeoutRoundTripper applies headerTimeout to the return of the http +// response, but headerTimeout is not applied to reading the body of the +// response. This is useful if you expect a response within, say, 5 seconds, +// but you expect the body to take longer to read. +type headerTimeoutRoundTripper struct { + headerTimeout time.Duration + rt http.RoundTripper +} + +// RoundTrip implements http.RoundTripper. +func (rt *headerTimeoutRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if rt.headerTimeout <= 0 { + return rt.rt.RoundTrip(req) + } + + timerCancelCh := make(chan struct{}) + ctx, cancelFn := context.WithCancelCause(req.Context()) + go func() { + t := time.NewTimer(rt.headerTimeout) + defer t.Stop() + select { + case <-ctx.Done(): + case <-t.C: + cancelFn(errz.Errorf("http response not received by %s timeout", + rt.headerTimeout)) + case <-timerCancelCh: + // Stop the timer goroutine. 
+ } + }() + + resp, err := rt.rt.RoundTrip(req.WithContext(ctx)) + close(timerCancelCh) + + // Don't leak resources; ensure that cancelFn is eventually called. + switch { + case err != nil: + + // It's possible that cancelFn has already been called by the + // timer goroutine, but we call it again just in case. + cancelFn(err) + case resp != nil && resp.Body != nil: + + // Wrap resp.Body with a ReadCloserNotifier, so that cancelFn + // is called when the body is closed. + resp.Body = ioz.ReadCloserNotifier(resp.Body, cancelFn) + default: + // Not sure if this can actually happen, but just in case. + cancelFn(context.Canceled) + } + + return resp, err +} + +// ResponseLogValue implements slog.Valuer for resp. +func ResponseLogValue(resp *http.Response) slog.Value { + if resp == nil { + return slog.Value{} + } + + attrs := []slog.Attr{ + slog.String("proto", resp.Proto), + slog.String("status", resp.Status), + } + + h := resp.Header + for k, _ := range h { + vals := h.Values(k) + if len(vals) == 1 { + attrs = append(attrs, slog.String(k, vals[0])) + continue + } + + attrs = append(attrs, slog.Any(k, h.Get(k))) + } + + if resp.Request != nil { + attrs = append(attrs, slog.Any("req", RequestLogValue(resp.Request))) + } + + return slog.GroupValue(attrs...) +} + +// RequestLogValue implements slog.Valuer for req. +func RequestLogValue(req *http.Request) slog.Value { + if req == nil { + return slog.Value{} + } + + attrs := []slog.Attr{ + slog.String("method", req.Method), + slog.String("path", req.URL.RawPath), + } + + if req.Proto != "" { + attrs = append(attrs, slog.String("proto", req.Proto)) + } + if req.Host != "" { + attrs = append(attrs, slog.String("host", req.Host)) + } + + h := req.Header + for k, _ := range h { + vals := h.Values(k) + if len(vals) == 1 { + attrs = append(attrs, slog.String(k, vals[0])) + continue + } + + attrs = append(attrs, slog.Any(k, h.Get(k))) + } + + return slog.GroupValue(attrs...) 
+} + +// Filename returns the filename to use for a download. +// It first checks the Content-Disposition header, and if that's +// not present, it uses the last path segment of the URL. The +// filename is sanitized. +// It's possible that the returned value will be empty string; the +// caller should handle that situation themselves. +func Filename(resp *http.Response) string { + var filename string + if resp == nil || resp.Header == nil { + return "" + } + dispHeader := resp.Header.Get("Content-Disposition") + if dispHeader != "" { + if _, params, err := mime.ParseMediaType(dispHeader); err == nil { + filename = params["filename"] + } + } + + if filename == "" { + filename = path.Base(resp.Request.URL.Path) + } else { + filename = filepath.Base(filename) + } + + return stringz.SanitizeFilename(filename) +} + +// ReadResponseHeader is a fork of http.ReadResponse that reads only the +// header from req and not the body. Note that resp.Body will be nil, and +// that the resp object is borked for general use. +func ReadResponseHeader(r *bufio.Reader, req *http.Request) (resp *http.Response, err error) { + tp := textproto.NewReader(r) + resp = &http.Response{Request: req} + + // Parse the first line of the response. 
+ line, err := tp.ReadLine() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + proto, status, ok := strings.Cut(line, " ") + if !ok { + return nil, badStringError("malformed HTTP response", line) + } + resp.Proto = proto + resp.Status = strings.TrimLeft(status, " ") + + statusCode, _, _ := strings.Cut(resp.Status, " ") + if len(statusCode) != 3 { + return nil, badStringError("malformed HTTP status code", statusCode) + } + resp.StatusCode, err = strconv.Atoi(statusCode) + if err != nil || resp.StatusCode < 0 { + return nil, badStringError("malformed HTTP status code", statusCode) + } + if resp.ProtoMajor, resp.ProtoMinor, ok = http.ParseHTTPVersion(resp.Proto); !ok { + return nil, badStringError("malformed HTTP version", resp.Proto) + } + + // Parse the response headers. + mimeHeader, err := tp.ReadMIMEHeader() + if err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + return nil, err + } + resp.Header = http.Header(mimeHeader) + + fixPragmaCacheControl(resp.Header) + + return resp, nil +} + +// RFC 7234, section 5.4: Should treat +// +// Pragma: no-cache +// +// like +// +// Cache-Control: no-cache +func fixPragmaCacheControl(header http.Header) { + if hp, ok := header["Pragma"]; ok && len(hp) > 0 && hp[0] == "no-cache" { + if _, presentcc := header["Cache-Control"]; !presentcc { + header["Cache-Control"] = []string{"no-cache"} + } + } +} + +func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } diff --git a/libsq/core/ioz/httpz_test.go b/libsq/core/ioz/httpz/httpz_test.go similarity index 96% rename from libsq/core/ioz/httpz_test.go rename to libsq/core/ioz/httpz/httpz_test.go index cf4c6d24a..1e0cd2df2 100644 --- a/libsq/core/ioz/httpz_test.go +++ b/libsq/core/ioz/httpz/httpz_test.go @@ -1,14 +1,14 @@ -package ioz_test +package httpz_test import ( "context" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" "io" "net/http" "net/http/httptest" "testing" "time" - 
"github.com/neilotoole/sq/libsq/core/ioz" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -42,7 +42,7 @@ func TestNewHTTPClient_headerTimeout(t *testing.T) { ctxFn: func(t *testing.T) context.Context { return context.Background() }, - c: ioz.NewHTTPClient("", false, headerTimeout, 0), + c: httpz.NewClient("", false, headerTimeout, 0), wantErr: false, }, } diff --git a/libsq/source/download.go b/libsq/source/download.go index a1b64aaed..95a6730bf 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -3,7 +3,7 @@ package source import ( "bytes" "context" - "github.com/neilotoole/sq/libsq/core/ioz/download" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" "io" "log/slog" "net/http" @@ -174,7 +174,7 @@ func (d *downloader) Download(ctx context.Context, dest io.Writer) (written int6 return written, "", errz.Errorf("download failed with %s for %s", resp.Status, d.url) } - filename := download.Filename(resp) + filename := httpz.Filename(resp) if filename == "" { filename = "download" } From e5f63f38322f33af1234541e03a5e44ac47200c1 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 13 Dec 2023 13:30:09 -0700 Subject: [PATCH 107/195] wip: tidy --- libsq/core/ioz/checksum/checksum.go | 16 +-- libsq/core/ioz/checksum/checksum_test.go | 8 +- libsq/core/ioz/download/download.go | 122 ++++++++++++++--------- libsq/core/ioz/download/download_test.go | 63 +++++------- libsq/core/ioz/download/respcache.go | 58 +++++++++-- libsq/source/cache.go | 2 +- 6 files changed, 162 insertions(+), 107 deletions(-) diff --git a/libsq/core/ioz/checksum/checksum.go b/libsq/core/ioz/checksum/checksum.go index 6f5136df4..961b03218 100644 --- a/libsq/core/ioz/checksum/checksum.go +++ b/libsq/core/ioz/checksum/checksum.go @@ -15,8 +15,12 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" ) -// Hash returns the hash of b as a hex string. -func Hash(b []byte) string { +// Sum returns the hash of b as a hex string. 
+// If b is empty, empty string is returned. +func Sum(b []byte) string { + if len(b) == 0 { + return "" + } sum := crc32.ChecksumIEEE(b) return fmt.Sprintf("%x", sum) } @@ -25,7 +29,7 @@ func Hash(b []byte) string { func Rand() string { b := make([]byte, 128) _, _ = rand.Read(b) - return Hash(b) + return Sum(b) } // Checksum is a checksum of a file. @@ -117,7 +121,7 @@ func ForFile(path string) (Checksum, error) { buf.WriteString(strconv.FormatUint(uint64(fi.Mode()), 10)) buf.WriteString(strconv.FormatBool(fi.IsDir())) - return Checksum(Hash(buf.Bytes())), nil + return Checksum(Sum(buf.Bytes())), nil } // ForHTTPHeader returns a checksum generated from URL u and @@ -142,7 +146,7 @@ func ForHTTPHeader(u string, header http.Header) Checksum { } } - return Checksum(Hash(buf.Bytes())) + return Checksum(Sum(buf.Bytes())) } // ForHTTPResponse returns a checksum generated from the response's @@ -202,5 +206,5 @@ func ForHTTPResponse(resp *http.Response) Checksum { fmt.Printf("\n\n%s\n\n", s) - return Checksum(Hash(buf.Bytes())) + return Checksum(Sum(buf.Bytes())) } diff --git a/libsq/core/ioz/checksum/checksum_test.go b/libsq/core/ioz/checksum/checksum_test.go index 09bf59b65..f5eaf193c 100644 --- a/libsq/core/ioz/checksum/checksum_test.go +++ b/libsq/core/ioz/checksum/checksum_test.go @@ -1,6 +1,7 @@ package checksum_test import ( + "github.com/stretchr/testify/require" "testing" "github.com/stretchr/testify/assert" @@ -9,7 +10,10 @@ import ( ) func TestHash(t *testing.T) { - got := checksum.Hash([]byte("hello world")) - t.Log(got) + got := checksum.Sum(nil) + require.Equal(t, "", got) + got = checksum.Sum([]byte{}) + require.Equal(t, "", got) + got = checksum.Sum([]byte("hello world")) assert.Equal(t, "d4a1185", got) } diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index ca38b2863..a78a75b20 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -1,5 +1,5 @@ -// Package download 
provides a http.RoundTripper implementation that -// works as a mostly RFC-compliant cache for http responses. +// Package download provides a mechanism for getting files from +// HTTP URLs, making use of a mostly RFC-compliant cache. // // FIXME: move download to internal/download, because its use // is so specialized? @@ -12,7 +12,10 @@ import ( "bufio" "context" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/ioz/httpz" + "github.com/neilotoole/sq/libsq/core/lg/lgm" "io" "net/http" "net/url" @@ -53,6 +56,8 @@ const ( // XFromCache is the header added to responses that are returned from the cache const XFromCache = "X-From-Cache" +const msgNilDestWriter = "nil dest writer from download handler; returning" + // Opt is a configuration option for creating a new Download. type Opt func(t *Download) @@ -76,7 +81,8 @@ func OptDisableCaching(disable bool) Opt { type Download struct { // FIXME: Does Download need a sync.Mutex? - // url is the URL of the download. + // url is the URL of the download. It is parsed in download.New, + // thus is guaranteed to be valid. url string c *http.Client @@ -84,7 +90,7 @@ type Download struct { respCache *RespCache // markCachedResponses, if true, indicates that responses returned from the - // cache will be given an extra header, X-From-Cache + // cache will be given an extra header, X-From-Cache. markCachedResponses bool disableCaching bool @@ -95,7 +101,7 @@ type Download struct { func New(c *http.Client, dlURL, cacheDir string, opts ...Opt) (*Download, error) { _, err := url.ParseRequestURI(dlURL) if err != nil { - return nil, err + return nil, errz.Wrap(err, "invalid download URL") } if c == nil { c = http.DefaultClient @@ -123,39 +129,37 @@ func New(c *http.Client, dlURL, cacheDir string, opts ...Opt) (*Download, error) } // Handler is a callback invoked by Download.Get. 
Exactly one of the -// handler functions will be invoked, one time. +// handler functions will be invoked, exactly one time. type Handler struct { // Cached is invoked when the download is already cached on disk. The // fp arg is the path to the downloaded file. Cached func(fp string) - // Uncached is invoked when the download is not cached. The handler must + // Uncached is invoked when the download is not cached. The handler should // return an io.WriterCloser, which the download contents will be written // to (as well as being written to the disk cache). On success, the dest // io.WriteCloser is closed. If an error occurs during download or - // writing, errFn is invoked, and dest is not closed. + // writing, errFn is invoked, and dest is not closed. If the handler returns + // a nil dest io.WriteCloser, the Download will log a warning and return. Uncached func() (dest io.WriteCloser, errFn func(error)) - // Error is invoked if an + // Error is invoked on any error, other than writing to the destination + // io.WriteCloser returned by Handler.Uncached, which has its own error + // handling mechanism. Error func(err error) } // Get gets the download, invoking Handler as appropriate. 
func (dl *Download) Get(ctx context.Context, h Handler) { - req, err := dl.newRequest(ctx, dl.url) - if err != nil { - h.Error(err) - return - } - + req := dl.mustRequest(ctx) dl.get(req, h) } func (dl *Download) get(req *http.Request, h Handler) { ctx := req.Context() log := lg.FromContext(ctx) - log.Info("Fetching download", lga.URL, req.URL.String()) - _, fpBody := dl.respCache.Paths(req) + log.Debug("Get download", lga.URL, dl.url) + _, fpBody, _ := dl.respCache.Paths(req) state := dl.state(req) if state == Fresh { @@ -259,7 +263,7 @@ func (dl *Download) get(req *http.Request, h Handler) { } if resp == cachedResp { - lg.WarnIfCloseError(log, "Close response body", resp.Body) + lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) if err = dl.respCache.Write(ctx, resp, true, nil); err != nil { log.Error("Failed to update cache header", lga.Dir, dl.respCache.Dir, lga.Err, err) h.Error(err) @@ -270,17 +274,18 @@ func (dl *Download) get(req *http.Request, h Handler) { } // I'm not sure if this logic is even reachable? - copyWrtr, errFn := h.Uncached() - if copyWrtr == nil { - log.Warn("nil copy writer from download handler; returning") + destWrtr, errFn := h.Uncached() + if destWrtr == nil { + log.Warn(msgNilDestWriter) return } - defer lg.WarnIfCloseError(log, "Close response body", resp.Body) - if err = dl.respCache.Write(req.Context(), resp, false, copyWrtr); err != nil { - log.Error("failed to write download cache", lga.Dir, dl.respCache.Dir, lga.Err, err) - errFn(err) - h.Error(err) + defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) + if err = dl.respCache.Write(req.Context(), resp, false, destWrtr); err != nil { + log.Error("Failed to write download cache", lga.Dir, dl.respCache.Dir, lga.Err, err) + if errFn != nil { + errFn(err) + } } return } else { @@ -290,21 +295,20 @@ func (dl *Download) get(req *http.Request, h Handler) { // It's not cacheable, so we need to write it to the copyWrtr. 
copyWrtr, errFn := h.Uncached() if copyWrtr == nil { - log.Warn("nil copy writer from download handler; returning") + log.Warn(msgNilDestWriter) return } cr := contextio.NewReader(ctx, resp.Body) - defer lg.WarnIfCloseError(log, "Close response body", resp.Body) + defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) _, err = io.Copy(copyWrtr, cr) if err != nil { + log.Error("Failed to copy download to dest writer", lga.Err, err) errFn(err) - h.Error(err) return } if err = copyWrtr.Close(); err != nil { - h.Error(err) - return + log.Error("Failed to close dest writer", lga.Err, err) } return @@ -321,21 +325,20 @@ func (dl *Download) Close() error { // execRequest executes the request. func (dl *Download) execRequest(req *http.Request) (*http.Response, error) { - //if dl.transport == nil { - // return http.DefaultTransport.RoundTrip(req) - //} - //return dl.transport.RoundTrip(req) return dl.c.Do(req) } -func (dl *Download) newRequest(ctx context.Context, url string) (*http.Request, error) { - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) +// mustRequest creates a new request from dl.url. The url has already been +// parsed in download.New, so it's safe to ignore the error. +func (dl *Download) mustRequest(ctx context.Context) *http.Request { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, dl.url, nil) if err != nil { - lg.FromContext(ctx).Error("Failed to create request", lga.URL, url, lga.Err, err) - return nil, err + lg.FromContext(ctx).Error("Failed to create request", lga.URL, dl.url, lga.Err, err) + panic(err) + return nil } - return req, nil + return req } // Clear deletes the cache. @@ -348,11 +351,7 @@ func (dl *Download) Clear(ctx context.Context) error { // State returns the Download's cache state. 
func (dl *Download) State(ctx context.Context) State { - req, err := dl.newRequest(ctx, dl.url) - if err != nil { - return Uncached - } - return dl.state(req) + return dl.state(dl.mustRequest(ctx)) } func (dl *Download) state(req *http.Request) State { @@ -367,14 +366,14 @@ func (dl *Download) state(req *http.Request) State { return Uncached } - fpHeader, _ := dl.respCache.Paths(req) + fpHeader, _, _ := dl.respCache.Paths(req) f, err := os.Open(fpHeader) if err != nil { - log.Error("Failed to open cached response header file", lga.File, fpHeader, lga.Err, err) + log.Error(msgCloseCacheHeaderBody, lga.File, fpHeader, lga.Err, err) return Uncached } - defer lg.WarnIfCloseError(log, "Close cached response header file", f) + defer lg.WarnIfCloseError(log, msgCloseCacheHeaderBody, f) cachedResp, err := httpz.ReadResponseHeader(bufio.NewReader(f), nil) if err != nil { @@ -385,6 +384,33 @@ func (dl *Download) state(req *http.Request) State { return getFreshness(cachedResp.Header, req.Header) } +// Checksum returns the checksum of the cached download, if available. 
+func (dl *Download) Checksum(ctx context.Context) (sum checksum.Checksum, ok bool) { + if dl.respCache == nil { + return "", false + } + + req := dl.mustRequest(ctx) + + _, _, fp := dl.respCache.Paths(req) + if !ioz.FileAccessible(fp) { + return "", false + } + + sums, err := checksum.ReadFile(fp) + if err != nil { + lg.FromContext(ctx).Warn("Failed to read checksum file", lga.File, fp, lga.Err, err) + return "", false + } + + if len(sums) != 1 { + return "", false + } + + sum, ok = sums["body"] + return sum, ok +} + func (dl *Download) isCacheable(req *http.Request) bool { if dl.disableCaching { return false diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index 42dac35f3..e736b04b4 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -139,7 +139,7 @@ func TestDownload_redirect(t *testing.T) { //loc := srvr.URL + "/redirect" //loc := srvr.URL + "/actual" -func TestDownload(t *testing.T) { +func TestDownload_New(t *testing.T) { log := slogt.New(t) ctx := lg.NewContext(context.Background(), log) const dlURL = urlActorCSV @@ -152,54 +152,39 @@ func TestDownload(t *testing.T) { dl, err := download.New(nil, dlURL, cacheDir) require.NoError(t, err) require.NoError(t, dl.Clear(ctx)) - - var ( - destBuf = &bytes.Buffer{} - gotFp string - gotErr error - ) - reset := func() { - destBuf.Reset() - gotFp = "" - gotErr = nil - } - - h := download.Handler{ - Cached: func(cachedFilepath string) { - gotFp = cachedFilepath - }, - Uncached: func() (wc io.WriteCloser, errFn func(error)) { - return ioz.WriteCloser(destBuf), - func(err error) { - gotErr = err - } - }, - Error: func(err error) { - gotErr = err - }, - } - require.Equal(t, download.Uncached, dl.State(ctx)) - dl.Get(ctx, h) - require.NoError(t, gotErr) - require.Empty(t, gotFp) - require.Equal(t, sizeActorCSV, int64(destBuf.Len())) + sum, ok := dl.Checksum(ctx) + require.False(t, ok) + require.Empty(t, sum) + h := 
newTestHandler(log.With("origin", "handler")) + dl.Get(ctx, h.Handler) + require.Empty(t, h.errors) + require.Empty(t, h.cacheFiles) + require.Equal(t, sizeActorCSV, int64(h.bufs[0].Len())) require.Equal(t, download.Fresh, dl.State(ctx)) + sum, ok = dl.Checksum(ctx) + require.True(t, ok) + require.NotEmpty(t, sum) - reset() - dl.Get(ctx, h) - require.NoError(t, gotErr) - require.Equal(t, 0, destBuf.Len()) - require.NotEmpty(t, gotFp) - gotFileBytes, err := os.ReadFile(gotFp) + h.reset() + dl.Get(ctx, h.Handler) + require.Empty(t, h.errors) + require.Empty(t, h.bufs) + require.NotEmpty(t, h.cacheFiles) + gotFileBytes, err := os.ReadFile(h.cacheFiles[0]) require.NoError(t, err) require.Equal(t, sizeActorCSV, int64(len(gotFileBytes))) - require.Equal(t, download.Fresh, dl.State(ctx)) + sum, ok = dl.Checksum(ctx) + require.True(t, ok) + require.NotEmpty(t, sum) require.NoError(t, dl.Clear(ctx)) require.Equal(t, download.Uncached, dl.State(ctx)) + sum, ok = dl.Checksum(ctx) + require.False(t, ok) + require.Empty(t, sum) } type testHandler struct { diff --git a/libsq/core/ioz/download/respcache.go b/libsq/core/ioz/download/respcache.go index 52b9ee0b0..a0acddce2 100644 --- a/libsq/core/ioz/download/respcache.go +++ b/libsq/core/ioz/download/respcache.go @@ -22,13 +22,13 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" ) +const msgCloseCacheHeaderBody = "Close cached response header file" + // NewRespCache returns a new instance that stores responses in cacheDir. // The caller should call RespCache.Close when finished with the cache. func NewRespCache(cacheDir string) *RespCache { c := &RespCache{ - Dir: cacheDir, - // Header: filepath.Join(cacheDir, "header"), - // Body: filepath.Join(cacheDir, "body"), + Dir: cacheDir, clnup: cleanup.New(), } return c @@ -44,15 +44,18 @@ type RespCache struct { Dir string } -// Paths returns the paths to the header and body files for req. +// Paths returns the paths to the header, body, and checksum files for req. 
// It is not guaranteed that they exist. -func (rc *RespCache) Paths(req *http.Request) (header, body string) { +func (rc *RespCache) Paths(req *http.Request) (header, body, checksum string) { if req == nil || req.Method == http.MethodGet { - return filepath.Join(rc.Dir, "header"), filepath.Join(rc.Dir, "body") + return filepath.Join(rc.Dir, "header"), + filepath.Join(rc.Dir, "body"), + filepath.Join(rc.Dir, "checksum.txt") } return filepath.Join(rc.Dir, req.Method+"_header"), - filepath.Join(rc.Dir, req.Method+"_body") + filepath.Join(rc.Dir, req.Method+"_body"), + filepath.Join(rc.Dir, req.Method+"_checksum.txt") } // Exists returns true if the cache contains a response for req. @@ -60,7 +63,7 @@ func (rc *RespCache) Exists(req *http.Request) bool { rc.mu.Lock() defer rc.mu.Unlock() - fpHeader, _ := rc.Paths(req) + fpHeader, _, _ := rc.Paths(req) fi, err := os.Stat(fpHeader) if err != nil { return false @@ -74,7 +77,7 @@ func (rc *RespCache) Get(ctx context.Context, req *http.Request) (*http.Response rc.mu.Lock() defer rc.mu.Unlock() - fpHeader, fpBody := rc.Paths(req) + fpHeader, fpBody, _ := rc.Paths(req) if !ioz.FileAccessible(fpHeader) { return nil, nil @@ -97,7 +100,40 @@ func (rc *RespCache) Get(ctx context.Context, req *http.Request) (*http.Response rc.clnup.AddC(bodyFile) // TODO: consider adding contextio.NewReader? concatRdr := io.MultiReader(bytes.NewReader(headerBytes), bodyFile) - return http.ReadResponse(bufio.NewReader(concatRdr), req) + resp, err := http.ReadResponse(bufio.NewReader(concatRdr), req) + if err != nil { + lg.WarnIfCloseError(lg.FromContext(ctx), "Close cached response body", bodyFile) + return nil, errz.Err(err) + } + return resp, nil + +} + +// Checksum returns the checksum of the cached body file, if available. 
+func (rc *RespCache) Checksum(req *http.Request) (sum checksum.Checksum, ok bool) { + if rc == nil || req == nil { + return "", false + } + + _, _, fp := rc.Paths(req) + if !ioz.FileAccessible(fp) { + return "", false + } + + sums, err := checksum.ReadFile(fp) + if err != nil { + lg.FromContext(req.Context()).Warn("Failed to read checksum file", + lga.File, fp, lga.Err, err) + return "", false + } + + if len(sums) != 1 { + // Shouldn't happen. + return "", false + } + + sum, ok = sums["body"] + return sum, ok } // Close closes the cache, freeing any resources it holds. Note that @@ -161,7 +197,7 @@ func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response, } log.Debug("Writing HTTP response to cache", lga.Dir, rc.Dir, "resp", httpz.ResponseLogValue(resp)) - fpHeader, fpBody := rc.Paths(resp.Request) + fpHeader, fpBody, _ := rc.Paths(resp.Request) headerBytes, err := httputil.DumpResponse(resp, false) if err != nil { diff --git a/libsq/source/cache.go b/libsq/source/cache.go index 5d4954579..2be8254f5 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -91,7 +91,7 @@ func (fs *Files) sourceHash(src *Source) string { } } - sum := checksum.Hash(buf.Bytes()) + sum := checksum.Sum(buf.Bytes()) return sum } From 93751be51e0c94639e267b470190280395390641 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 13 Dec 2023 14:37:07 -0700 Subject: [PATCH 108/195] wip: tidy --- libsq/core/ioz/download/README.md | 6 +++ libsq/core/ioz/download/download.go | 12 ++--- libsq/core/ioz/download/{httpz.go => http.go} | 2 +- libsq/core/ioz/download/respcache.go | 45 +++++++++++++------ libsq/core/ioz/ioz.go | 8 ++-- testh/tu/tutil.go | 19 ++++++++ 6 files changed, 68 insertions(+), 24 deletions(-) rename libsq/core/ioz/download/{httpz.go => http.go} (99%) diff --git a/libsq/core/ioz/download/README.md b/libsq/core/ioz/download/README.md index 58cda222f..07cd41149 100644 --- a/libsq/core/ioz/download/README.md +++ b/libsq/core/ioz/download/README.md @@ -1,3 +1,9 
@@ +# ACKNOWLEDGEMENT + +This package a heavily-modified fork +of [`gregjones/httpcache`](https://github.com/gregjones/httpcache). + + # httpcache [![GoDoc](https://godoc.org/github.com/bitcomplete/httpcache?status.svg)](https://godoc.org/github.com/bitcomplete/httpcache) diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index a78a75b20..f75bf7fb5 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -215,7 +215,7 @@ func (dl *Download) get(req *http.Request, h Handler) { } } - resp, err = dl.execRequest(req) + resp, err = dl.do(req) if err == nil && req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified { // Replace the 304 response with the one from cache, but update with some new headers endToEndHeaders := getEndToEndHeaders(resp.Header) @@ -244,7 +244,7 @@ func (dl *Download) get(req *http.Request, h Handler) { if _, ok := reqCacheControl["only-if-cached"]; ok { resp = newGatewayTimeoutResponse(req) } else { - resp, err = dl.execRequest(req) + resp, err = dl.do(req) if err != nil { h.Error(err) return @@ -323,8 +323,8 @@ func (dl *Download) Close() error { return nil } -// execRequest executes the request. -func (dl *Download) execRequest(req *http.Request) (*http.Response, error) { +// do executes the request. 
+func (dl *Download) do(req *http.Request) (*http.Response, error) { return dl.c.Do(req) } @@ -369,11 +369,11 @@ func (dl *Download) state(req *http.Request) State { fpHeader, _, _ := dl.respCache.Paths(req) f, err := os.Open(fpHeader) if err != nil { - log.Error(msgCloseCacheHeaderBody, lga.File, fpHeader, lga.Err, err) + log.Error(msgCloseCacheHeaderFile, lga.File, fpHeader, lga.Err, err) return Uncached } - defer lg.WarnIfCloseError(log, msgCloseCacheHeaderBody, f) + defer lg.WarnIfCloseError(log, msgCloseCacheHeaderFile, f) cachedResp, err := httpz.ReadResponseHeader(bufio.NewReader(f), nil) if err != nil { diff --git a/libsq/core/ioz/download/httpz.go b/libsq/core/ioz/download/http.go similarity index 99% rename from libsq/core/ioz/download/httpz.go rename to libsq/core/ioz/download/http.go index b7fc01f5e..899e055cd 100644 --- a/libsq/core/ioz/download/httpz.go +++ b/libsq/core/ioz/download/http.go @@ -219,7 +219,7 @@ func newGatewayTimeoutResponse(req *http.Request) *http.Response { // cloneRequest returns a clone of the provided *http.Request. // The clone is a shallow copy of the struct and its Header map. -// (This function copyright goauth2 authors: https://code.google.com/p/goauth2) +// (This function copyright goauth2 authors: https://code.google.com/p/goauth2). func cloneRequest(r *http.Request) *http.Request { // shallow copy of the struct r2 := new(http.Request) diff --git a/libsq/core/ioz/download/respcache.go b/libsq/core/ioz/download/respcache.go index a0acddce2..b98a2ff22 100644 --- a/libsq/core/ioz/download/respcache.go +++ b/libsq/core/ioz/download/respcache.go @@ -22,7 +22,8 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" ) -const msgCloseCacheHeaderBody = "Close cached response header file" +const msgCloseCacheHeaderFile = "Close cached response header file" +const msgCloseCacheBodyFile = "Close cached response body file" // NewRespCache returns a new instance that stores responses in cacheDir. 
// The caller should call RespCache.Close when finished with the cache. @@ -34,13 +35,18 @@ func NewRespCache(cacheDir string) *RespCache { return c } -// RespCache is a cache for a single http.Response. The response is -// stored in two files, one for the header and one for the body. -// The caller should call RespCache.Close when finished with the cache. +// RespCache is a cache a download. The cached response is +// stored in two files, one for the header and one for the body, with +// a checksum (of the body file) stored in a third file. +// Use RespCache.Paths to access the cache files. type RespCache struct { - mu sync.Mutex + // FIXME: move the mutex to the Download struct? + mu sync.Mutex + + // Deprecated: any cleanup should happen via resp.Body.Close(). clnup *cleanup.Cleanup + // Dir is the directory in which the cache files are stored. Dir string } @@ -53,6 +59,9 @@ func (rc *RespCache) Paths(req *http.Request) (header, body, checksum string) { filepath.Join(rc.Dir, "checksum.txt") } + // This is probably not strictly necessary because we're always + // using GET, but in an earlier incarnation of the code, it was relevant. + // Can probably delete. return filepath.Join(rc.Dir, req.Method+"_header"), filepath.Join(rc.Dir, req.Method+"_body"), filepath.Join(rc.Dir, req.Method+"_checksum.txt") @@ -72,41 +81,51 @@ func (rc *RespCache) Exists(req *http.Request) bool { } // Get returns the cached http.Response for req if present, and nil -// otherwise. +// otherwise. The caller MUST close the returned response body. func (rc *RespCache) Get(ctx context.Context, req *http.Request) (*http.Response, error) { rc.mu.Lock() defer rc.mu.Unlock() + log := lg.FromContext(ctx) fpHeader, fpBody, _ := rc.Paths(req) - if !ioz.FileAccessible(fpHeader) { + // If the header file doesn't exist, it's a nil, nil situation. 
return nil, nil } headerBytes, err := os.ReadFile(fpHeader) if err != nil { - return nil, err + return nil, errz.Wrap(err, "failed to read cached response header file") } bodyFile, err := os.Open(fpBody) if err != nil { - lg.FromContext(ctx).Error("failed to open cached response body", + log.Error("Failed to open cached response body file", lga.File, fpBody, lga.Err, err) - return nil, err + return nil, errz.Wrap(err, "failed to open cached response body file") } - // We need to explicitly close bodyFile at some later point. It won't be + // It won't be // closed via a call to http.Response.Body.Close(). rc.clnup.AddC(bodyFile) + // TODO: consider adding contextio.NewReader? concatRdr := io.MultiReader(bytes.NewReader(headerBytes), bodyFile) resp, err := http.ReadResponse(bufio.NewReader(concatRdr), req) if err != nil { - lg.WarnIfCloseError(lg.FromContext(ctx), "Close cached response body", bodyFile) + lg.WarnIfCloseError(log, msgCloseCacheBodyFile, bodyFile) return nil, errz.Err(err) } - return resp, nil + // We need to explicitly close bodyFile. To do this (on the happy path), + // we wrap bodyFile in a ReadCloserNotifier, which will close bodyFile + // when resp.Body is closed. Thus, it's critical that the caller + // close the returned resp. + respBody := resp.Body + resp.Body = ioz.ReadCloserNotifier(respBody, func(error) { + lg.WarnIfCloseError(log, msgCloseCacheBodyFile, bodyFile) + }) + return resp, nil } // Checksum returns the checksum of the cached body file, if available. 
diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index ccc4ccc8c..4a7c7278c 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -396,16 +396,16 @@ func ReadCloserNotifier(rc io.ReadCloser, fn func(closeErr error)) io.ReadCloser if rc == nil || fn == nil { return rc } - return &readCloseNotifier{ReadCloser: rc, fn: fn} + return &readCloserNotifier{ReadCloser: rc, fn: fn} } -type readCloseNotifier struct { +type readCloserNotifier struct { fn func(error) io.ReadCloser } -func (c *readCloseNotifier) Close() error { - err := c.Close() +func (c *readCloserNotifier) Close() error { + err := c.ReadCloser.Close() c.fn(err) return err } diff --git a/testh/tu/tutil.go b/testh/tu/tutil.go index 26d7e50c3..15551c32c 100644 --- a/testh/tu/tutil.go +++ b/testh/tu/tutil.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "os" + "os/exec" "path/filepath" "reflect" "runtime" @@ -398,3 +399,21 @@ func ReadFileToString(t testing.TB, name string) string { require.NoError(t, err) return s } + +// OpenFileCount is a debugging function that returns the count +// of open file handles for the current process via shelling out +// to lsof. This function is skipped on Windows. +// If log is true, the output of lsof is logged. 
+func OpenFileCount(t *testing.T, log bool) int { + SkipWindows(t, "OpenFileCount not implemented on Windows") + out, err := exec.Command("/bin/sh", "-c", fmt.Sprintf("lsof -p %v", os.Getpid())).Output() + require.NoError(t, err) + lines := strings.Split(string(out), "\n") + count := len(lines) - 1 + msg := fmt.Sprintf("Open files for [%d]: %d", os.Getpid(), count) + if log { + msg += "\n\n" + string(out) + } + t.Log(msg) + return count +} From 8df8d5d44c1a7fd980f5c0b06850ec7685c900cf Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 13 Dec 2023 14:51:13 -0700 Subject: [PATCH 109/195] cleanup --- libsq/core/ioz/download/README.md | 2 +- libsq/core/ioz/download/respcache.go | 2 +- testh/tu/tutil.go | 36 ++++++++++++++++++++++++---- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/libsq/core/ioz/download/README.md b/libsq/core/ioz/download/README.md index 07cd41149..dac847ca4 100644 --- a/libsq/core/ioz/download/README.md +++ b/libsq/core/ioz/download/README.md @@ -1,6 +1,6 @@ # ACKNOWLEDGEMENT -This package a heavily-modified fork +This `download` package is a heavily-modified fork of [`gregjones/httpcache`](https://github.com/gregjones/httpcache). diff --git a/libsq/core/ioz/download/respcache.go b/libsq/core/ioz/download/respcache.go index b98a2ff22..29bc0e1e7 100644 --- a/libsq/core/ioz/download/respcache.go +++ b/libsq/core/ioz/download/respcache.go @@ -107,7 +107,7 @@ func (rc *RespCache) Get(ctx context.Context, req *http.Request) (*http.Response // It won't be // closed via a call to http.Response.Body.Close(). - rc.clnup.AddC(bodyFile) + //rc.clnup.AddC(bodyFile) // TODO: consider adding contextio.NewReader? concatRdr := io.MultiReader(bytes.NewReader(headerBytes), bodyFile) diff --git a/testh/tu/tutil.go b/testh/tu/tutil.go index 15551c32c..f241a0d43 100644 --- a/testh/tu/tutil.go +++ b/testh/tu/tutil.go @@ -405,11 +405,7 @@ func ReadFileToString(t testing.TB, name string) string { // to lsof. This function is skipped on Windows. 
// If log is true, the output of lsof is logged. func OpenFileCount(t *testing.T, log bool) int { - SkipWindows(t, "OpenFileCount not implemented on Windows") - out, err := exec.Command("/bin/sh", "-c", fmt.Sprintf("lsof -p %v", os.Getpid())).Output() - require.NoError(t, err) - lines := strings.Split(string(out), "\n") - count := len(lines) - 1 + count, out := doOpenFileCount(t) msg := fmt.Sprintf("Open files for [%d]: %d", os.Getpid(), count) if log { msg += "\n\n" + string(out) @@ -417,3 +413,33 @@ func OpenFileCount(t *testing.T, log bool) int { t.Log(msg) return count } + +func doOpenFileCount(t *testing.T) (count int, out string) { + SkipWindows(t, "OpenFileCount not implemented on Windows") + b, err := exec.Command("/bin/sh", "-c", fmt.Sprintf("lsof -p %v", os.Getpid())).Output() + require.NoError(t, err) + lines := strings.Split(string(b), "\n") + count = len(lines) - 1 + return count, string(b) +} + +// DiffOpenFileCount is a debugging function that compares the +// open file count at the start of the test with the count at +// the end of the test (via t.Cleanup). This function is skipped on Windows. 
+func DiffOpenFileCount(t *testing.T, log bool) { + openingCount, openingOut := doOpenFileCount(t) + if log { + t.Logf("START: Open files for [%d]: %d\n\n%s", os.Getpid(), openingCount, openingOut) + } + t.Cleanup(func() { + closingCount, closingOut := doOpenFileCount(t) + if log { + t.Logf("END: Open files for [%d]: %d\n\n%s", os.Getpid(), closingCount, closingOut) + } + if openingCount != closingCount { + t.Logf("Open file count changed from %d to %d", openingCount, closingCount) + } else { + t.Logf("Open file count unchanged: %d", openingCount) + } + }) +} From 9359f4d4c5234f732a2609f45d9c950327761f50 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 13 Dec 2023 16:45:15 -0700 Subject: [PATCH 110/195] cleanup --- .../ioz/download/{respcache.go => cache.go} | 122 +++++++----------- libsq/core/ioz/download/download.go | 17 +-- libsq/core/ioz/download/download_test.go | 20 +++ libsq/core/ioz/download/httpcache_test.go | 4 +- 4 files changed, 71 insertions(+), 92 deletions(-) rename libsq/core/ioz/download/{respcache.go => cache.go} (62%) diff --git a/libsq/core/ioz/download/respcache.go b/libsq/core/ioz/download/cache.go similarity index 62% rename from libsq/core/ioz/download/respcache.go rename to libsq/core/ioz/download/cache.go index 29bc0e1e7..fb3ba746d 100644 --- a/libsq/core/ioz/download/respcache.go +++ b/libsq/core/ioz/download/cache.go @@ -15,7 +15,6 @@ import ( "path/filepath" "sync" - "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" @@ -25,54 +24,41 @@ import ( const msgCloseCacheHeaderFile = "Close cached response header file" const msgCloseCacheBodyFile = "Close cached response body file" -// NewRespCache returns a new instance that stores responses in cacheDir. -// The caller should call RespCache.Close when finished with the cache. 
-func NewRespCache(cacheDir string) *RespCache { - c := &RespCache{ - Dir: cacheDir, - clnup: cleanup.New(), - } - return c -} - -// RespCache is a cache a download. The cached response is +// Cache is a cache for a individual download. The cached response is // stored in two files, one for the header and one for the body, with // a checksum (of the body file) stored in a third file. -// Use RespCache.Paths to access the cache files. -type RespCache struct { +// Use Cache.Paths to access the cache files. +type Cache struct { // FIXME: move the mutex to the Download struct? mu sync.Mutex - // Deprecated: any cleanup should happen via resp.Body.Close(). - clnup *cleanup.Cleanup - - // Dir is the directory in which the cache files are stored. - Dir string + // dir is the directory in which the cache files are stored. + dir string } // Paths returns the paths to the header, body, and checksum files for req. // It is not guaranteed that they exist. -func (rc *RespCache) Paths(req *http.Request) (header, body, checksum string) { +func (c *Cache) Paths(req *http.Request) (header, body, checksum string) { if req == nil || req.Method == http.MethodGet { - return filepath.Join(rc.Dir, "header"), - filepath.Join(rc.Dir, "body"), - filepath.Join(rc.Dir, "checksum.txt") + return filepath.Join(c.dir, "header"), + filepath.Join(c.dir, "body"), + filepath.Join(c.dir, "checksum.txt") } // This is probably not strictly necessary because we're always // using GET, but in an earlier incarnation of the code, it was relevant. // Can probably delete. - return filepath.Join(rc.Dir, req.Method+"_header"), - filepath.Join(rc.Dir, req.Method+"_body"), - filepath.Join(rc.Dir, req.Method+"_checksum.txt") + return filepath.Join(c.dir, req.Method+"_header"), + filepath.Join(c.dir, req.Method+"_body"), + filepath.Join(c.dir, req.Method+"_checksum.txt") } // Exists returns true if the cache contains a response for req. 
-func (rc *RespCache) Exists(req *http.Request) bool { - rc.mu.Lock() - defer rc.mu.Unlock() +func (c *Cache) Exists(req *http.Request) bool { + c.mu.Lock() + defer c.mu.Unlock() - fpHeader, _, _ := rc.Paths(req) + fpHeader, _, _ := c.Paths(req) fi, err := os.Stat(fpHeader) if err != nil { return false @@ -82,12 +68,12 @@ func (rc *RespCache) Exists(req *http.Request) bool { // Get returns the cached http.Response for req if present, and nil // otherwise. The caller MUST close the returned response body. -func (rc *RespCache) Get(ctx context.Context, req *http.Request) (*http.Response, error) { - rc.mu.Lock() - defer rc.mu.Unlock() +func (c *Cache) Get(ctx context.Context, req *http.Request) (*http.Response, error) { + c.mu.Lock() + defer c.mu.Unlock() log := lg.FromContext(ctx) - fpHeader, fpBody, _ := rc.Paths(req) + fpHeader, fpBody, _ := c.Paths(req) if !ioz.FileAccessible(fpHeader) { // If the header file doesn't exist, it's a nil, nil situation. return nil, nil @@ -105,11 +91,7 @@ func (rc *RespCache) Get(ctx context.Context, req *http.Request) (*http.Response return nil, errz.Wrap(err, "failed to open cached response body file") } - // It won't be - // closed via a call to http.Response.Body.Close(). - //rc.clnup.AddC(bodyFile) - - // TODO: consider adding contextio.NewReader? + // FIXME: consider adding contextio.NewReader? concatRdr := io.MultiReader(bytes.NewReader(headerBytes), bodyFile) resp, err := http.ReadResponse(bufio.NewReader(concatRdr), req) if err != nil { @@ -121,20 +103,19 @@ func (rc *RespCache) Get(ctx context.Context, req *http.Request) (*http.Response // we wrap bodyFile in a ReadCloserNotifier, which will close bodyFile // when resp.Body is closed. Thus, it's critical that the caller // close the returned resp. 
- respBody := resp.Body - resp.Body = ioz.ReadCloserNotifier(respBody, func(error) { + resp.Body = ioz.ReadCloserNotifier(resp.Body, func(error) { lg.WarnIfCloseError(log, msgCloseCacheBodyFile, bodyFile) }) return resp, nil } // Checksum returns the checksum of the cached body file, if available. -func (rc *RespCache) Checksum(req *http.Request) (sum checksum.Checksum, ok bool) { - if rc == nil || req == nil { +func (c *Cache) Checksum(req *http.Request) (sum checksum.Checksum, ok bool) { + if c == nil || req == nil { return "", false } - _, _, fp := rc.Paths(req) + _, _, fp := c.Paths(req) if !ioz.FileAccessible(fp) { return "", false } @@ -155,41 +136,28 @@ func (rc *RespCache) Checksum(req *http.Request) (sum checksum.Checksum, ok bool return sum, ok } -// Close closes the cache, freeing any resources it holds. Note that -// it does not delete the cache: for that, see RespCache.Delete. -func (rc *RespCache) Close() error { - rc.mu.Lock() - defer rc.mu.Unlock() - - err := rc.clnup.Run() - rc.clnup = cleanup.New() - return err -} - // Clear deletes the cache entries from disk. 
-func (rc *RespCache) Clear(ctx context.Context) error { - if rc == nil { +func (c *Cache) Clear(ctx context.Context) error { + if c == nil { return nil } - rc.mu.Lock() - defer rc.mu.Unlock() + c.mu.Lock() + defer c.mu.Unlock() - return rc.doClear(ctx) + return c.doClear(ctx) } -func (rc *RespCache) doClear(ctx context.Context) error { - cleanErr := rc.clnup.Run() - rc.clnup = cleanup.New() - deleteErr := errz.Wrap(os.RemoveAll(rc.Dir), "delete cache dir") - recreateErr := ioz.RequireDir(rc.Dir) - err := errz.Combine(cleanErr, deleteErr, recreateErr) +func (c *Cache) doClear(ctx context.Context) error { + deleteErr := errz.Wrap(os.RemoveAll(c.dir), "delete cache dir") + recreateErr := ioz.RequireDir(c.dir) + err := errz.Combine(deleteErr, recreateErr) if err != nil { lg.FromContext(ctx).Error(msgDeleteCache, - lga.Dir, rc.Dir, lga.Err, err) + lga.Dir, c.dir, lga.Err, err) return err } - lg.FromContext(ctx).Info("Deleted cache dir", lga.Dir, rc.Dir) + lg.FromContext(ctx).Info("Deleted cache dir", lga.Dir, c.dir) return nil } @@ -199,24 +167,24 @@ const msgDeleteCache = "Delete HTTP response cache" // the header cache file is updated. If headerOnly is false and copyWrtr is // non-nil, the response body bytes are copied to that destination, as well as // being written to the cache. The response body is always closed. 
-func (rc *RespCache) Write(ctx context.Context, resp *http.Response, headerOnly bool, copyWrtr io.WriteCloser) error { - rc.mu.Lock() - defer rc.mu.Unlock() +func (c *Cache) Write(ctx context.Context, resp *http.Response, headerOnly bool, copyWrtr io.WriteCloser) error { + c.mu.Lock() + defer c.mu.Unlock() - return rc.doWrite(ctx, resp, headerOnly, copyWrtr) + return c.doWrite(ctx, resp, headerOnly, copyWrtr) } -func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response, +func (c *Cache) doWrite(ctx context.Context, resp *http.Response, headerOnly bool, copyWrtr io.WriteCloser) error { log := lg.FromContext(ctx) defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) - if err := ioz.RequireDir(rc.Dir); err != nil { + if err := ioz.RequireDir(c.dir); err != nil { return err } - log.Debug("Writing HTTP response to cache", lga.Dir, rc.Dir, "resp", httpz.ResponseLogValue(resp)) - fpHeader, fpBody, _ := rc.Paths(resp.Request) + log.Debug("Writing HTTP response to cache", lga.Dir, c.dir, "resp", httpz.ResponseLogValue(resp)) + fpHeader, fpBody, _ := c.Paths(resp.Request) headerBytes, err := httputil.DumpResponse(resp, false) if err != nil { @@ -264,7 +232,7 @@ func (rc *RespCache) doWrite(ctx context.Context, resp *http.Response, return errz.Wrap(err, "failed to compute checksum for cache body file") } - if err = checksum.WriteFile(filepath.Join(rc.Dir, "checksum.txt"), sum, "body"); err != nil { + if err = checksum.WriteFile(filepath.Join(c.dir, "checksum.txt"), sum, "body"); err != nil { return errz.Wrap(err, "failed to write checksum file for cache body") } diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index f75bf7fb5..477e9f87f 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -87,7 +87,7 @@ type Download struct { c *http.Client - respCache *RespCache + respCache *Cache // markCachedResponses, if true, indicates that responses returned from the // cache will 
be given an extra header, X-From-Cache. @@ -122,7 +122,7 @@ func New(c *http.Client, dlURL, cacheDir string, opts ...Opt) (*Download, error) } if !t.disableCaching { - t.respCache = NewRespCache(cacheDir) + t.respCache = &Cache{dir: cacheDir} } return t, nil @@ -265,7 +265,7 @@ func (dl *Download) get(req *http.Request, h Handler) { if resp == cachedResp { lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) if err = dl.respCache.Write(ctx, resp, true, nil); err != nil { - log.Error("Failed to update cache header", lga.Dir, dl.respCache.Dir, lga.Err, err) + log.Error("Failed to update cache header", lga.Dir, dl.respCache.dir, lga.Err, err) h.Error(err) return } @@ -282,7 +282,7 @@ func (dl *Download) get(req *http.Request, h Handler) { defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) if err = dl.respCache.Write(req.Context(), resp, false, destWrtr); err != nil { - log.Error("Failed to write download cache", lga.Dir, dl.respCache.Dir, lga.Err, err) + log.Error("Failed to write download cache", lga.Dir, dl.respCache.dir, lga.Err, err) if errFn != nil { errFn(err) } @@ -314,15 +314,6 @@ func (dl *Download) get(req *http.Request, h Handler) { return } -// Close frees any resources held by the Download. It does not delete -// the cache from disk. For that, see Download.Clear. -func (dl *Download) Close() error { - if dl.respCache != nil { - return dl.respCache.Close() - } - return nil -} - // do executes the request. 
func (dl *Download) do(req *http.Request) (*http.Response, error) { return dl.c.Do(req) diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index e736b04b4..db1166caa 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -1,8 +1,10 @@ package download_test import ( + "bufio" "bytes" "context" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/testh/tu" @@ -13,6 +15,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "strings" "sync" "testing" "time" @@ -236,3 +239,20 @@ func newTestHandler(log *slog.Logger) *testHandler { } return th } + +func TestMisc(t *testing.T) { + br := bufio.NewReader(strings.NewReader("huzzah")) + + cr := contextio.NewReader(context.TODO(), br) + + t.Logf("cr: %T", cr) + + var wt io.WriterTo + wt = br + _ = wt + + var ok bool + wt, ok = cr.(io.WriterTo) + require.True(t, ok) + +} diff --git a/libsq/core/ioz/download/httpcache_test.go b/libsq/core/ioz/download/httpcache_test.go index ebb081461..05f1ba944 100644 --- a/libsq/core/ioz/download/httpcache_test.go +++ b/libsq/core/ioz/download/httpcache_test.go @@ -159,8 +159,8 @@ package download //} // //func resetTest(t testing.TB) { -// s.transport.RespCache = NewRespCache(t.TempDir()) -// //s.transport.RespCache.Delete() +// s.transport.Cache = NewRespCache(t.TempDir()) +// //s.transport.Cache.Delete() // clock = &realClock{} //} // From b0cb87f0d3f8645c2841a36a695734cc617bddfd Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 13 Dec 2023 17:12:18 -0700 Subject: [PATCH 111/195] wip cleanup --- libsq/core/ioz/download/cache.go | 57 ++++++++++++++---------- libsq/core/ioz/download/download.go | 20 ++++----- libsq/core/ioz/download/download_test.go | 2 +- 3 files changed, 45 insertions(+), 34 deletions(-) diff --git a/libsq/core/ioz/download/cache.go 
b/libsq/core/ioz/download/cache.go index fb3ba746d..6f4d9b687 100644 --- a/libsq/core/ioz/download/cache.go +++ b/libsq/core/ioz/download/cache.go @@ -24,11 +24,11 @@ import ( const msgCloseCacheHeaderFile = "Close cached response header file" const msgCloseCacheBodyFile = "Close cached response body file" -// Cache is a cache for a individual download. The cached response is +// cache is a cache for a individual download. The cached response is // stored in two files, one for the header and one for the body, with // a checksum (of the body file) stored in a third file. -// Use Cache.Paths to access the cache files. -type Cache struct { +// Use cache.paths to access the cache files. +type cache struct { // FIXME: move the mutex to the Download struct? mu sync.Mutex @@ -36,9 +36,9 @@ type Cache struct { dir string } -// Paths returns the paths to the header, body, and checksum files for req. +// paths returns the paths to the header, body, and checksum files for req. // It is not guaranteed that they exist. -func (c *Cache) Paths(req *http.Request) (header, body, checksum string) { +func (c *cache) paths(req *http.Request) (header, body, checksum string) { if req == nil || req.Method == http.MethodGet { return filepath.Join(c.dir, "header"), filepath.Join(c.dir, "body"), @@ -53,12 +53,12 @@ func (c *Cache) Paths(req *http.Request) (header, body, checksum string) { filepath.Join(c.dir, req.Method+"_checksum.txt") } -// Exists returns true if the cache contains a response for req. -func (c *Cache) Exists(req *http.Request) bool { +// exists returns true if the cache contains a response for req. +func (c *cache) exists(req *http.Request) bool { c.mu.Lock() defer c.mu.Unlock() - fpHeader, _, _ := c.Paths(req) + fpHeader, _, _ := c.paths(req) fi, err := os.Stat(fpHeader) if err != nil { return false @@ -68,12 +68,12 @@ func (c *Cache) Exists(req *http.Request) bool { // Get returns the cached http.Response for req if present, and nil // otherwise. 
The caller MUST close the returned response body. -func (c *Cache) Get(ctx context.Context, req *http.Request) (*http.Response, error) { +func (c *cache) get(ctx context.Context, req *http.Request) (*http.Response, error) { c.mu.Lock() defer c.mu.Unlock() log := lg.FromContext(ctx) - fpHeader, fpBody, _ := c.Paths(req) + fpHeader, fpBody, _ := c.paths(req) if !ioz.FileAccessible(fpHeader) { // If the header file doesn't exist, it's a nil, nil situation. return nil, nil @@ -91,9 +91,13 @@ func (c *Cache) Get(ctx context.Context, req *http.Request) (*http.Response, err return nil, errz.Wrap(err, "failed to open cached response body file") } - // FIXME: consider adding contextio.NewReader? - concatRdr := io.MultiReader(bytes.NewReader(headerBytes), bodyFile) - resp, err := http.ReadResponse(bufio.NewReader(concatRdr), req) + // Now it's time for the Matroyshka readers. First we concatenate the + // header and body via io.MultiReader. Then, we wrap that in + // a contextio.NewReader, for context-awareness. Finally, + // http.ReadResponse requires a bufio.Reader, so we wrap the + // context reader via bufio.NewReader, and then we're ready to go. + r := contextio.NewReader(ctx, io.MultiReader(bytes.NewReader(headerBytes), bodyFile)) + resp, err := http.ReadResponse(bufio.NewReader(r), req) if err != nil { lg.WarnIfCloseError(log, msgCloseCacheBodyFile, bodyFile) return nil, errz.Err(err) @@ -109,13 +113,13 @@ func (c *Cache) Get(ctx context.Context, req *http.Request) (*http.Response, err return resp, nil } -// Checksum returns the checksum of the cached body file, if available. -func (c *Cache) Checksum(req *http.Request) (sum checksum.Checksum, ok bool) { +// checksum returns the checksum of the cached body file, if available. 
+func (c *cache) checksum(req *http.Request) (sum checksum.Checksum, ok bool) { if c == nil || req == nil { return "", false } - _, _, fp := c.Paths(req) + _, _, fp := c.paths(req) if !ioz.FileAccessible(fp) { return "", false } @@ -137,7 +141,7 @@ func (c *Cache) Checksum(req *http.Request) (sum checksum.Checksum, ok bool) { } // Clear deletes the cache entries from disk. -func (c *Cache) Clear(ctx context.Context) error { +func (c *cache) Clear(ctx context.Context) error { if c == nil { return nil } @@ -147,7 +151,7 @@ func (c *Cache) Clear(ctx context.Context) error { return c.doClear(ctx) } -func (c *Cache) doClear(ctx context.Context) error { +func (c *cache) doClear(ctx context.Context) error { deleteErr := errz.Wrap(os.RemoveAll(c.dir), "delete cache dir") recreateErr := ioz.RequireDir(c.dir) err := errz.Combine(deleteErr, recreateErr) @@ -163,18 +167,24 @@ func (c *Cache) doClear(ctx context.Context) error { const msgDeleteCache = "Delete HTTP response cache" -// Write writes resp header and body to the cache. If headerOnly is true, only +// write writes resp header and body to the cache. If headerOnly is true, only // the header cache file is updated. If headerOnly is false and copyWrtr is // non-nil, the response body bytes are copied to that destination, as well as -// being written to the cache. The response body is always closed. -func (c *Cache) Write(ctx context.Context, resp *http.Response, headerOnly bool, copyWrtr io.WriteCloser) error { +// being written to the cache. If writing to copyWrtr completes successfully, +// it is closed; if there's an error, copyWrtr is not closed. This is by design; +// the caller is responsible for propagating any write error (as returned by +// this method) to copyWrtr's owner. +// A checksum file, computed from the body file, is also written to disk. The +// response body is always closed. 
+func (c *cache) write(ctx context.Context, resp *http.Response, + headerOnly bool, copyWrtr io.WriteCloser) error { c.mu.Lock() defer c.mu.Unlock() return c.doWrite(ctx, resp, headerOnly, copyWrtr) } -func (c *Cache) doWrite(ctx context.Context, resp *http.Response, +func (c *cache) doWrite(ctx context.Context, resp *http.Response, headerOnly bool, copyWrtr io.WriteCloser) error { log := lg.FromContext(ctx) defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) @@ -184,7 +194,7 @@ func (c *Cache) doWrite(ctx context.Context, resp *http.Response, } log.Debug("Writing HTTP response to cache", lga.Dir, c.dir, "resp", httpz.ResponseLogValue(resp)) - fpHeader, fpBody, _ := c.Paths(resp.Request) + fpHeader, fpBody, _ := c.paths(resp.Request) headerBytes, err := httputil.DumpResponse(resp, false) if err != nil { @@ -215,6 +225,7 @@ func (c *Cache) doWrite(ctx context.Context, resp *http.Response, var written int64 written, err = io.Copy(cacheFile, cr) if err != nil { + log.Error("Cache write: io.Copy failed", lga.Err, err) lg.WarnIfCloseError(log, "Close cache body file", cacheFile) return err diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index 477e9f87f..bfc6431cd 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -87,10 +87,10 @@ type Download struct { c *http.Client - respCache *Cache + respCache *cache // markCachedResponses, if true, indicates that responses returned from the - // cache will be given an extra header, X-From-Cache. + // cache will be given an extra header, X-From-cache. 
markCachedResponses bool disableCaching bool @@ -122,7 +122,7 @@ func New(c *http.Client, dlURL, cacheDir string, opts ...Opt) (*Download, error) } if !t.disableCaching { - t.respCache = &Cache{dir: cacheDir} + t.respCache = &cache{dir: cacheDir} } return t, nil @@ -159,7 +159,7 @@ func (dl *Download) get(req *http.Request, h Handler) { ctx := req.Context() log := lg.FromContext(ctx) log.Debug("Get download", lga.URL, dl.url) - _, fpBody, _ := dl.respCache.Paths(req) + _, fpBody, _ := dl.respCache.paths(req) state := dl.state(req) if state == Fresh { @@ -171,7 +171,7 @@ func (dl *Download) get(req *http.Request, h Handler) { cacheable := dl.isCacheable(req) var cachedResp *http.Response if cacheable { - cachedResp, err = dl.respCache.Get(req.Context(), req) + cachedResp, err = dl.respCache.get(req.Context(), req) } else { // Need to invalidate an existing value if err = dl.respCache.Clear(req.Context()); err != nil { @@ -264,7 +264,7 @@ func (dl *Download) get(req *http.Request, h Handler) { if resp == cachedResp { lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) - if err = dl.respCache.Write(ctx, resp, true, nil); err != nil { + if err = dl.respCache.write(ctx, resp, true, nil); err != nil { log.Error("Failed to update cache header", lga.Dir, dl.respCache.dir, lga.Err, err) h.Error(err) return @@ -281,7 +281,7 @@ func (dl *Download) get(req *http.Request, h Handler) { } defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) - if err = dl.respCache.Write(req.Context(), resp, false, destWrtr); err != nil { + if err = dl.respCache.write(req.Context(), resp, false, destWrtr); err != nil { log.Error("Failed to write download cache", lga.Dir, dl.respCache.dir, lga.Err, err) if errFn != nil { errFn(err) @@ -353,11 +353,11 @@ func (dl *Download) state(req *http.Request) State { ctx := req.Context() log := lg.FromContext(ctx) - if !dl.respCache.Exists(req) { + if !dl.respCache.exists(req) { return Uncached } - fpHeader, _, _ := 
dl.respCache.Paths(req) + fpHeader, _, _ := dl.respCache.paths(req) f, err := os.Open(fpHeader) if err != nil { log.Error(msgCloseCacheHeaderFile, lga.File, fpHeader, lga.Err, err) @@ -383,7 +383,7 @@ func (dl *Download) Checksum(ctx context.Context) (sum checksum.Checksum, ok boo req := dl.mustRequest(ctx) - _, _, fp := dl.respCache.Paths(req) + _, _, fp := dl.respCache.paths(req) if !ioz.FileAccessible(fp) { return "", false } diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index db1166caa..bbca7f2e4 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -241,6 +241,7 @@ func newTestHandler(log *slog.Logger) *testHandler { } func TestMisc(t *testing.T) { + // FIXME: delete br := bufio.NewReader(strings.NewReader("huzzah")) cr := contextio.NewReader(context.TODO(), br) @@ -254,5 +255,4 @@ func TestMisc(t *testing.T) { var ok bool wt, ok = cr.(io.WriterTo) require.True(t, ok) - } From 00252644375c3f0540b303c783df2ea926196fa0 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 13 Dec 2023 17:49:12 -0700 Subject: [PATCH 112/195] Refactored with ioz.WriteErrorCloser --- libsq/core/ioz/download/cache.go | 43 +++++++++++++++--------- libsq/core/ioz/download/download.go | 28 +++++++-------- libsq/core/ioz/download/download_test.go | 15 ++++++--- libsq/core/ioz/ioz.go | 29 ++++++++++++++++ 4 files changed, 81 insertions(+), 34 deletions(-) diff --git a/libsq/core/ioz/download/cache.go b/libsq/core/ioz/download/cache.go index 6f4d9b687..49034d10c 100644 --- a/libsq/core/ioz/download/cache.go +++ b/libsq/core/ioz/download/cache.go @@ -171,13 +171,11 @@ const msgDeleteCache = "Delete HTTP response cache" // the header cache file is updated. If headerOnly is false and copyWrtr is // non-nil, the response body bytes are copied to that destination, as well as // being written to the cache. 
If writing to copyWrtr completes successfully, -// it is closed; if there's an error, copyWrtr is not closed. This is by design; -// the caller is responsible for propagating any write error (as returned by -// this method) to copyWrtr's owner. +// it is closed; if there's an error, copyWrtr.Error is invoked. // A checksum file, computed from the body file, is also written to disk. The // response body is always closed. func (c *cache) write(ctx context.Context, resp *http.Response, - headerOnly bool, copyWrtr io.WriteCloser) error { + headerOnly bool, copyWrtr ioz.WriteErrorCloser) error { c.mu.Lock() defer c.mu.Unlock() @@ -185,11 +183,20 @@ func (c *cache) write(ctx context.Context, resp *http.Response, } func (c *cache) doWrite(ctx context.Context, resp *http.Response, - headerOnly bool, copyWrtr io.WriteCloser) error { + headerOnly bool, copyWrtr ioz.WriteErrorCloser) (err error) { log := lg.FromContext(ctx) - defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) - if err := ioz.RequireDir(c.dir); err != nil { + defer func() { + lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) + if err == nil { + return + } + if err != nil && copyWrtr != nil { + copyWrtr.Error(err) + } + }() + + if err = ioz.RequireDir(c.dir); err != nil { return err } @@ -223,19 +230,25 @@ func (c *cache) doWrite(ctx context.Context, resp *http.Response, } var written int64 - written, err = io.Copy(cacheFile, cr) - if err != nil { - + if written, err = io.Copy(cacheFile, cr); err != nil { + err = errz.Err(err) log.Error("Cache write: io.Copy failed", lga.Err, err) - lg.WarnIfCloseError(log, "Close cache body file", cacheFile) + lg.WarnIfCloseError(log, msgCloseCacheBodyFile, cacheFile) return err } - if copyWrtr != nil { - lg.WarnIfCloseError(log, "Close copy writer", copyWrtr) - } if err = cacheFile.Close(); err != nil { - return errz.Err(err) + cacheFile = nil + err = errz.Err(err) + return err + } + + if copyWrtr != nil { + if err = copyWrtr.Close(); err != nil 
{ + err = errz.Err(err) + copyWrtr = nil + return err + } } sum, err := checksum.ForFile(fpBody) diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index bfc6431cd..43fa161ac 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -136,12 +136,12 @@ type Handler struct { Cached func(fp string) // Uncached is invoked when the download is not cached. The handler should - // return an io.WriterCloser, which the download contents will be written + // return an ioz.WriteErrorCloser, which the download contents will be written // to (as well as being written to the disk cache). On success, the dest - // io.WriteCloser is closed. If an error occurs during download or - // writing, errFn is invoked, and dest is not closed. If the handler returns - // a nil dest io.WriteCloser, the Download will log a warning and return. - Uncached func() (dest io.WriteCloser, errFn func(error)) + // io.WriteCloser is closed. If an error occurs during download or writing, + // WriteErrorCloser.Error is invoked (but Close is not invoked). If the + // handler returns a nil dest, the Download will log a warning and return. + Uncached func() (dest ioz.WriteErrorCloser) // Error is invoked on any error, other than writing to the destination // io.WriteCloser returned by Handler.Uncached, which has its own error @@ -274,7 +274,7 @@ func (dl *Download) get(req *http.Request, h Handler) { } // I'm not sure if this logic is even reachable? 
- destWrtr, errFn := h.Uncached() + destWrtr := h.Uncached() if destWrtr == nil { log.Warn(msgNilDestWriter) return @@ -283,31 +283,29 @@ func (dl *Download) get(req *http.Request, h Handler) { defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) if err = dl.respCache.write(req.Context(), resp, false, destWrtr); err != nil { log.Error("Failed to write download cache", lga.Dir, dl.respCache.dir, lga.Err, err) - if errFn != nil { - errFn(err) - } + //destWrtr.Error(err) } return } else { lg.WarnIfError(log, "Delete resp cache", dl.respCache.Clear(req.Context())) } - // It's not cacheable, so we need to write it to the copyWrtr. - copyWrtr, errFn := h.Uncached() - if copyWrtr == nil { + // It's not cacheable, so we need to write it to the destWrtr. + destWrtr := h.Uncached() + if destWrtr == nil { log.Warn(msgNilDestWriter) return } cr := contextio.NewReader(ctx, resp.Body) defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) - _, err = io.Copy(copyWrtr, cr) + _, err = io.Copy(destWrtr, cr) if err != nil { log.Error("Failed to copy download to dest writer", lga.Err, err) - errFn(err) + destWrtr.Error(err) return } - if err = copyWrtr.Close(); err != nil { + if err = destWrtr.Close(); err != nil { log.Error("Failed to close dest writer", lga.Err, err) } diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index bbca7f2e4..25fafbc53 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -152,7 +152,7 @@ func TestDownload_New(t *testing.T) { require.NoError(t, err) t.Logf("cacheDir: %s", cacheDir) - dl, err := download.New(nil, dlURL, cacheDir) + dl, err := download.New(httpz.NewDefaultClient(), dlURL, cacheDir) require.NoError(t, err) require.NoError(t, dl.Clear(ctx)) require.Equal(t, download.Uncached, dl.State(ctx)) @@ -163,6 +163,7 @@ func TestDownload_New(t *testing.T) { h := newTestHandler(log.With("origin", "handler")) dl.Get(ctx, h.Handler) 
require.Empty(t, h.errors) + require.Empty(t, h.writeErrs) require.Empty(t, h.cacheFiles) require.Equal(t, sizeActorCSV, int64(h.bufs[0].Len())) require.Equal(t, download.Fresh, dl.State(ctx)) @@ -173,6 +174,7 @@ func TestDownload_New(t *testing.T) { h.reset() dl.Get(ctx, h.Handler) require.Empty(t, h.errors) + require.Empty(t, h.writeErrs) require.Empty(t, h.bufs) require.NotEmpty(t, h.cacheFiles) gotFileBytes, err := os.ReadFile(h.cacheFiles[0]) @@ -188,6 +190,11 @@ func TestDownload_New(t *testing.T) { sum, ok = dl.Checksum(ctx) require.False(t, ok) require.Empty(t, sum) + + h.reset() + dl.Get(ctx, h.Handler) + require.Empty(t, h.errors) + require.Empty(t, h.writeErrs) } type testHandler struct { @@ -218,17 +225,17 @@ func newTestHandler(log *slog.Logger) *testHandler { th.cacheFiles = append(th.cacheFiles, fp) } - th.Uncached = func() (io.WriteCloser, func(error)) { + th.Uncached = func() ioz.WriteErrorCloser { log.Info("Uncached") th.mu.Lock() defer th.mu.Unlock() buf := &bytes.Buffer{} th.bufs = append(th.bufs, buf) - return ioz.WriteCloser(buf), func(err error) { + return ioz.NewFuncWriteErrorCloser(ioz.WriteCloser(buf), func(err error) { th.mu.Lock() defer th.mu.Unlock() th.writeErrs = append(th.writeErrs, err) - } + }) } th.Error = func(err error) { diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 4a7c7278c..4f38a2c84 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -432,3 +432,32 @@ func WriteToFile(ctx context.Context, fp string, r io.Reader) (written int64, er return written, errz.Err(err) } + +// WriteErrorCloser supplements io.WriteCloser with an Error method, indicating +// to the io.WriteCloser that an upstream error has interrupted the writing +// operation. Note that clients should invoke only one of Close or Error. +type WriteErrorCloser interface { + io.WriteCloser + + // Error indicates that an upstream error has interrupted the + // writing operation. 
+ Error(err error) +} + +type writeErrorCloser struct { + fn func(error) + io.WriteCloser +} + +// Error implements WriteErrorCloser.Error. +func (w *writeErrorCloser) Error(err error) { + if w.fn != nil { + w.fn(err) + } +} + +// NewFuncWriteErrorCloser returns a new WriteErrorCloser that wraps w, and +// invokes non-nil fn when WriteErrorCloser.Error is called. +func NewFuncWriteErrorCloser(w io.WriteCloser, fn func(error)) WriteErrorCloser { + return &writeErrorCloser{WriteCloser: w, fn: fn} +} From 227bb2bbb04e5cef9a79df056bbbee91017a71c2 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 13 Dec 2023 19:01:43 -0700 Subject: [PATCH 113/195] added sq x download cmd --- cli/cli.go | 3 +- cli/cmd_x.go | 73 +++++++++++-- libsq/core/ioz/download/download.go | 26 +---- libsq/core/ioz/download/download_test.go | 127 +++++------------------ libsq/core/ioz/download/handler.go | 93 +++++++++++++++++ libsq/source/files.go | 52 ---------- libsq/source/lock.go | 35 +++++++ testh/tu/{tutil.go => tu.go} | 12 ++- testh/tu/{tutil_test.go => tu_test.go} | 0 9 files changed, 235 insertions(+), 186 deletions(-) create mode 100644 libsq/core/ioz/download/handler.go create mode 100644 libsq/source/lock.go rename testh/tu/{tutil.go => tu.go} (97%) rename testh/tu/{tutil_test.go => tu_test.go} (100%) diff --git a/cli/cli.go b/cli/cli.go index 60609787d..854339531 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -257,7 +257,8 @@ func newCommandTree(ru *run.Run) (rootCmd *cobra.Command) { xCmd := addCmd(ru, rootCmd, newXCmd()) addCmd(ru, xCmd, newXLockSrcCmd()) - addCmd(ru, xCmd, newXDevTestCmd()) + addCmd(ru, xCmd, newXProgressCmd()) + addCmd(ru, xCmd, newXDownloadCmd()) return rootCmd } diff --git a/cli/cmd_x.go b/cli/cmd_x.go index a4e027f35..8f2cc8b54 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -3,6 +3,11 @@ package cli import ( "bufio" "fmt" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" + "github.com/neilotoole/sq/libsq/core/ioz/download" + 
"github.com/neilotoole/sq/libsq/core/ioz/httpz" + "github.com/neilotoole/sq/libsq/source" + "net/url" "os" "time" @@ -80,19 +85,19 @@ func execXLockSrcCmd(cmd *cobra.Command, args []string) error { return nil } -func newXDevTestCmd() *cobra.Command { +func newXProgressCmd() *cobra.Command { cmd := &cobra.Command{ - Use: "dev-test", - Short: "Execute some dev test code", + Use: "progress", + Short: "Execute progress test code", Hidden: true, - RunE: execXDevTestCmd, - Example: ` $ sq x dev-test`, + RunE: execXProgress, + Example: ` $ sq x progress`, } return cmd } -func execXDevTestCmd(cmd *cobra.Command, _ []string) error { +func execXProgress(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() log := lg.FromContext(ctx) ru := run.FromContext(ctx) @@ -124,6 +129,62 @@ func execXDevTestCmd(cmd *cobra.Command, _ []string) error { return ctx.Err() } +func newXDownloadCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "download URL", + Short: "Download a file", + Hidden: true, + Args: cobra.ExactArgs(1), + RunE: execXDownloadCmd, + Example: ` $ sq x download https://sq.io/testdata/actor.csv + + # Download a big-ass file + $ sq x download https://sqio-public.s3.amazonaws.com/testdata/payment-large.gen.csv +`, + } + + return cmd +} + +func execXDownloadCmd(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + log := lg.FromContext(ctx) + ru := run.FromContext(ctx) + + u, err := url.ParseRequestURI(args[0]) + if err != nil { + return err + } + + sum := checksum.Sum([]byte(u.String())) + cacheDir, err := ru.Files.CacheDirFor(&source.Source{Handle: "@download_" + sum}) + if err != nil { + return err + } + + dl, err := download.New(httpz.NewDefaultClient(), u.String(), cacheDir) + if err != nil { + return err + } + + h := download.NewSinkHandler(log.With("origin", "handler")) + dl.Get(ctx, h.Handler) + + switch { + case len(h.Errors) > 0: + return h.Errors[0] + case len(h.WriteErrors) > 0: + return h.WriteErrors[0] + case len(h.CachedFiles) > 
0: + fmt.Fprintf(ru.Out, "Cached: %s\n", h.CachedFiles[0]) + return nil + case len(h.UncachedBufs) > 0: + fmt.Fprintf(ru.Out, "Uncached: %d bytes\n", h.UncachedBufs[0].Len()) + } + + return nil +} + func pressEnter() <-chan struct{} { done := make(chan struct{}) go func() { diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index 43fa161ac..6ff6935a0 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -128,27 +128,6 @@ func New(c *http.Client, dlURL, cacheDir string, opts ...Opt) (*Download, error) return t, nil } -// Handler is a callback invoked by Download.Get. Exactly one of the -// handler functions will be invoked, exactly one time. -type Handler struct { - // Cached is invoked when the download is already cached on disk. The - // fp arg is the path to the downloaded file. - Cached func(fp string) - - // Uncached is invoked when the download is not cached. The handler should - // return an ioz.WriteErrorCloser, which the download contents will be written - // to (as well as being written to the disk cache). On success, the dest - // io.WriteCloser is closed. If an error occurs during download or writing, - // WriteErrorCloser.Error is invoked (but Close is not invoked). If the - // handler returns a nil dest, the Download will log a warning and return. - Uncached func() (dest ioz.WriteErrorCloser) - - // Error is invoked on any error, other than writing to the destination - // io.WriteCloser returned by Handler.Uncached, which has its own error - // handling mechanism. - Error func(err error) -} - // Get gets the download, invoking Handler as appropriate. func (dl *Download) Get(ctx context.Context, h Handler) { req := dl.mustRequest(ctx) @@ -290,7 +269,8 @@ func (dl *Download) get(req *http.Request, h Handler) { lg.WarnIfError(log, "Delete resp cache", dl.respCache.Clear(req.Context())) } - // It's not cacheable, so we need to write it to the destWrtr. 
+ // It's not cacheable, so we need to write it to the destWrtr, + // and skip the cache. destWrtr := h.Uncached() if destWrtr == nil { log.Warn(msgNilDestWriter) @@ -298,7 +278,7 @@ func (dl *Download) get(req *http.Request, h Handler) { } cr := contextio.NewReader(ctx, resp.Body) - defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) + defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, cr.(io.ReadCloser)) _, err = io.Copy(destWrtr, cr) if err != nil { log.Error("Failed to copy download to dest writer", lga.Err, err) diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index 25fafbc53..1a9a6a386 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -1,35 +1,27 @@ package download_test import ( - "bufio" - "bytes" "context" - "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/testh/tu" "github.com/stretchr/testify/assert" - "io" - "log/slog" "net/http" "net/http/httptest" "os" "path/filepath" - "strings" - "sync" "testing" "time" "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/download" "github.com/neilotoole/sq/libsq/core/lg" "github.com/stretchr/testify/require" ) const ( - urlPaymentLargeCSV = "https://sqio-public.s3.amazonaws.com/testdata/payment-large.gen.csv" urlActorCSV = "https://sq.io/testdata/actor.csv" + urlPaymentLargeCSV = "https://sqio-public.s3.amazonaws.com/testdata/payment-large.gen.csv" sizeActorCSV = int64(7641) ) @@ -88,28 +80,28 @@ func TestDownload_redirect(t *testing.T) { dl, err := download.New(httpz.NewDefaultClient(), loc, cacheDir) require.NoError(t, err) require.NoError(t, dl.Clear(ctx)) - h := newTestHandler(log.With("origin", "handler")) + h := download.NewSinkHandler(log.With("origin", "handler")) 
dl.Get(ctx, h.Handler) - require.Empty(t, h.errors) - gotBody := h.bufs[0].String() + require.Empty(t, h.Errors) + gotBody := h.UncachedBufs[0].String() require.Equal(t, hello, gotBody) - h.reset() + h.Reset() dl.Get(ctx, h.Handler) - require.Empty(t, h.errors) - require.Empty(t, h.bufs) - gotFile := h.cacheFiles[0] + require.Empty(t, h.Errors) + require.Empty(t, h.UncachedBufs) + gotFile := h.CachedFiles[0] t.Logf("got fp: %s", gotFile) gotBody = tu.ReadFileToString(t, gotFile) t.Logf("got body: \n\n%s\n\n", gotBody) require.Equal(t, serveBody, gotBody) - h.reset() + h.Reset() dl.Get(ctx, h.Handler) - require.Empty(t, h.errors) - require.Empty(t, h.bufs) - gotFile = h.cacheFiles[0] + require.Empty(t, h.Errors) + require.Empty(t, h.UncachedBufs) + gotFile = h.CachedFiles[0] t.Logf("got fp: %s", gotFile) gotBody = tu.ReadFileToString(t, gotFile) t.Logf("got body: \n\n%s\n\n", gotBody) @@ -160,24 +152,24 @@ func TestDownload_New(t *testing.T) { require.False(t, ok) require.Empty(t, sum) - h := newTestHandler(log.With("origin", "handler")) + h := download.NewSinkHandler(log.With("origin", "handler")) dl.Get(ctx, h.Handler) - require.Empty(t, h.errors) - require.Empty(t, h.writeErrs) - require.Empty(t, h.cacheFiles) - require.Equal(t, sizeActorCSV, int64(h.bufs[0].Len())) + require.Empty(t, h.Errors) + require.Empty(t, h.WriteErrors) + require.Empty(t, h.CachedFiles) + require.Equal(t, sizeActorCSV, int64(h.UncachedBufs[0].Len())) require.Equal(t, download.Fresh, dl.State(ctx)) sum, ok = dl.Checksum(ctx) require.True(t, ok) require.NotEmpty(t, sum) - h.reset() + h.Reset() dl.Get(ctx, h.Handler) - require.Empty(t, h.errors) - require.Empty(t, h.writeErrs) - require.Empty(t, h.bufs) - require.NotEmpty(t, h.cacheFiles) - gotFileBytes, err := os.ReadFile(h.cacheFiles[0]) + require.Empty(t, h.Errors) + require.Empty(t, h.WriteErrors) + require.Empty(t, h.UncachedBufs) + require.NotEmpty(t, h.CachedFiles) + gotFileBytes, err := os.ReadFile(h.CachedFiles[0]) 
require.NoError(t, err) require.Equal(t, sizeActorCSV, int64(len(gotFileBytes))) require.Equal(t, download.Fresh, dl.State(ctx)) @@ -191,75 +183,8 @@ func TestDownload_New(t *testing.T) { require.False(t, ok) require.Empty(t, sum) - h.reset() + h.Reset() dl.Get(ctx, h.Handler) - require.Empty(t, h.errors) - require.Empty(t, h.writeErrs) -} - -type testHandler struct { - download.Handler - mu sync.Mutex - log *slog.Logger - errors []error - cacheFiles []string - bufs []*bytes.Buffer - writeErrs []error -} - -func (th *testHandler) reset() { - th.mu.Lock() - defer th.mu.Unlock() - th.errors = nil - th.cacheFiles = nil - th.bufs = nil - th.writeErrs = nil -} - -func newTestHandler(log *slog.Logger) *testHandler { - th := &testHandler{log: log} - th.Cached = func(fp string) { - log.Info("Cached", lga.File, fp) - th.mu.Lock() - defer th.mu.Unlock() - th.cacheFiles = append(th.cacheFiles, fp) - } - - th.Uncached = func() ioz.WriteErrorCloser { - log.Info("Uncached") - th.mu.Lock() - defer th.mu.Unlock() - buf := &bytes.Buffer{} - th.bufs = append(th.bufs, buf) - return ioz.NewFuncWriteErrorCloser(ioz.WriteCloser(buf), func(err error) { - th.mu.Lock() - defer th.mu.Unlock() - th.writeErrs = append(th.writeErrs, err) - }) - } - - th.Error = func(err error) { - log.Info("Error", lga.Err, err) - th.mu.Lock() - defer th.mu.Unlock() - th.errors = append(th.errors, err) - } - return th -} - -func TestMisc(t *testing.T) { - // FIXME: delete - br := bufio.NewReader(strings.NewReader("huzzah")) - - cr := contextio.NewReader(context.TODO(), br) - - t.Logf("cr: %T", cr) - - var wt io.WriterTo - wt = br - _ = wt - - var ok bool - wt, ok = cr.(io.WriterTo) - require.True(t, ok) + require.Empty(t, h.Errors) + require.Empty(t, h.WriteErrors) } diff --git a/libsq/core/ioz/download/handler.go b/libsq/core/ioz/download/handler.go new file mode 100644 index 000000000..50444f76d --- /dev/null +++ b/libsq/core/ioz/download/handler.go @@ -0,0 +1,93 @@ +package download + +import ( + "bytes" + 
"github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "log/slog" + "sync" +) + +// Handler is a callback invoked by Download.Get. Exactly one of the +// handler functions will be invoked, exactly one time. +type Handler struct { + // Cached is invoked when the download is already cached on disk. The + // fp arg is the path to the downloaded file. + Cached func(fp string) + + // Uncached is invoked when the download is not cached. The handler should + // return an ioz.WriteErrorCloser, which the download contents will be written + // to (as well as being written to the disk cache). On success, the dest + // io.WriteCloser is closed. If an error occurs during download or writing, + // WriteErrorCloser.Error is invoked (but Close is not invoked). If the + // handler returns a nil dest, the Download will log a warning and return. + Uncached func() (dest ioz.WriteErrorCloser) + + // Error is invoked on any error, other than writing to the destination + // io.WriteCloser returned by Handler.Uncached, which has its own error + // handling mechanism. + Error func(err error) +} + +// SinkHandler is a download.Handler that records the results of the callbacks +// it receives. This is useful for testing. +type SinkHandler struct { + Handler + mu sync.Mutex + log *slog.Logger + + // Errors records the errors received via Handler.Error. + Errors []error + + // CachedFiles records the cached files received via Handler.Cached. + CachedFiles []string + + // UncachedBufs records in bytes.Buffer instances the data written + // via Handler.Uncached. + UncachedBufs []*bytes.Buffer + + // WriteErrors records the write errors received via Handler.Uncached. + WriteErrors []error +} + +// Reset resets the handler sinks. +func (sh *SinkHandler) Reset() { + sh.mu.Lock() + defer sh.mu.Unlock() + sh.Errors = nil + sh.CachedFiles = nil + sh.UncachedBufs = nil + sh.WriteErrors = nil +} + +// NewSinkHandler returns a new SinkHandler. 
+func NewSinkHandler(log *slog.Logger) *SinkHandler { + h := &SinkHandler{log: log} + h.Cached = func(fp string) { + log.Info("Cached", lga.File, fp) + h.mu.Lock() + defer h.mu.Unlock() + h.CachedFiles = append(h.CachedFiles, fp) + } + + h.Uncached = func() ioz.WriteErrorCloser { + log.Info("Uncached") + h.mu.Lock() + defer h.mu.Unlock() + buf := &bytes.Buffer{} + h.UncachedBufs = append(h.UncachedBufs, buf) + return ioz.NewFuncWriteErrorCloser(ioz.WriteCloser(buf), func(err error) { + h.mu.Lock() + defer h.mu.Unlock() + h.WriteErrors = append(h.WriteErrors, err) + }) + } + + h.Error = func(err error) { + log.Info("Error", lga.Err, err) + h.mu.Lock() + defer h.mu.Unlock() + h.Errors = append(h.Errors, err) + } + return h +} diff --git a/libsq/source/files.go b/libsq/source/files.go index 23639d533..a91b81ded 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -322,32 +322,6 @@ func (fs *Files) Open(ctx context.Context, src *Source) (io.ReadCloser, error) { return fs.newReader(ctx, src.Location) } -// NewLock returns a new source.Lock instance. -func NewLock(src *Source, pidfile string) (Lock, error) { - lf, err := lockfile.New(pidfile) - if err != nil { - return Lock{}, errz.Err(err) - } - - return Lock{ - Lockfile: lf, - src: src, - }, nil -} - -type Lock struct { - lockfile.Lockfile - src *Source -} - -func (l Lock) Source() *Source { - return l.src -} - -func (l Lock) String() string { - return l.src.Handle + ": " + string(l.Lockfile) -} - // CacheLockFor returns the lock file for src's cache. func (fs *Files) CacheLockFor(src *Source) (lockfile.Lockfile, error) { cacheDir, err := fs.CacheDirFor(src) @@ -370,32 +344,6 @@ func (fs *Files) OpenFunc(src *Source) FileOpenFunc { } } -// ReadAll is a convenience method to read the bytes of a source. -// -// FIXME: Delete Files.ReadAll? -// -// Deprecated: Files.ReadAll is not in use. We can probably delete it. 
-func (fs *Files) ReadAll(ctx context.Context, src *Source) ([]byte, error) { - // fs.mu.Lock() - r, err := fs.newReader(ctx, src.Location) - // fs.mu.Unlock() - if err != nil { - return nil, err - } - - var data []byte - data, err = io.ReadAll(r) - closeErr := r.Close() - if err != nil { - return nil, err - } - if closeErr != nil { - return nil, closeErr - } - - return data, nil -} - func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, error) { log := lg.FromContext(ctx).With(lga.Loc, loc) diff --git a/libsq/source/lock.go b/libsq/source/lock.go new file mode 100644 index 000000000..5ae34a42b --- /dev/null +++ b/libsq/source/lock.go @@ -0,0 +1,35 @@ +package source + +import ( + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz/lockfile" +) + +// NewLock returns a new source.Lock instance. +// +// REVISIT: We may not actually use source.Lock at all, and +// instead stick with ioz/lockfile.Lockfile. +func NewLock(src *Source, pidfile string) (Lock, error) { + lf, err := lockfile.New(pidfile) + if err != nil { + return Lock{}, errz.Err(err) + } + + return Lock{ + Lockfile: lf, + src: src, + }, nil +} + +type Lock struct { + lockfile.Lockfile + src *Source +} + +func (l Lock) Source() *Source { + return l.src +} + +func (l Lock) String() string { + return l.src.Handle + ": " + string(l.Lockfile) +} diff --git a/testh/tu/tutil.go b/testh/tu/tu.go similarity index 97% rename from testh/tu/tutil.go rename to testh/tu/tu.go index f241a0d43..d171ee0b7 100644 --- a/testh/tu/tutil.go +++ b/testh/tu/tu.go @@ -404,7 +404,7 @@ func ReadFileToString(t testing.TB, name string) string { // of open file handles for the current process via shelling out // to lsof. This function is skipped on Windows. // If log is true, the output of lsof is logged. 
-func OpenFileCount(t *testing.T, log bool) int { +func OpenFileCount(t testing.TB, log bool) int { count, out := doOpenFileCount(t) msg := fmt.Sprintf("Open files for [%d]: %d", os.Getpid(), count) if log { @@ -414,7 +414,7 @@ func OpenFileCount(t *testing.T, log bool) int { return count } -func doOpenFileCount(t *testing.T) (count int, out string) { +func doOpenFileCount(t testing.TB) (count int, out string) { SkipWindows(t, "OpenFileCount not implemented on Windows") b, err := exec.Command("/bin/sh", "-c", fmt.Sprintf("lsof -p %v", os.Getpid())).Output() require.NoError(t, err) @@ -426,7 +426,7 @@ func doOpenFileCount(t *testing.T) (count int, out string) { // DiffOpenFileCount is a debugging function that compares the // open file count at the start of the test with the count at // the end of the test (via t.Cleanup). This function is skipped on Windows. -func DiffOpenFileCount(t *testing.T, log bool) { +func DiffOpenFileCount(t testing.TB, log bool) { openingCount, openingOut := doOpenFileCount(t) if log { t.Logf("START: Open files for [%d]: %d\n\n%s", os.Getpid(), openingCount, openingOut) @@ -443,3 +443,9 @@ func DiffOpenFileCount(t *testing.T, log bool) { } }) } + +// UseProxy sets HTTP_PROXY and HTTPS_PROXY to localhost:9001. 
+func UseProxy(t testing.TB) { + t.Setenv("HTTP_PROXY", "http://localhost:9001") + t.Setenv("HTTPS_PROXY", "http://localhost:9001") +} diff --git a/testh/tu/tutil_test.go b/testh/tu/tu_test.go similarity index 100% rename from testh/tu/tutil_test.go rename to testh/tu/tu_test.go From 7e7741fcac4f087b58f0afd4cef4b26a913ce3cd Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 13 Dec 2023 23:29:22 -0700 Subject: [PATCH 114/195] Refactoring httpz.NewClient --- cli/cmd_x.go | 5 +- go.mod | 1 + go.sum | 3 +- libsq/core/ioz/download/cache.go | 4 +- libsq/core/ioz/download/download.go | 55 ++++++---- libsq/core/ioz/download/download_test.go | 4 +- libsq/core/ioz/httpz/httpz.go | 106 ++++++++++++++++--- libsq/core/ioz/httpz/httpz_test.go | 61 +---------- libsq/core/ioz/httpz/opts.go | 123 +++++++++++++++++++++++ 9 files changed, 261 insertions(+), 101 deletions(-) create mode 100644 libsq/core/ioz/httpz/opts.go diff --git a/cli/cmd_x.go b/cli/cmd_x.go index 8f2cc8b54..1a3eec511 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -157,12 +157,13 @@ func execXDownloadCmd(cmd *cobra.Command, args []string) error { } sum := checksum.Sum([]byte(u.String())) - cacheDir, err := ru.Files.CacheDirFor(&source.Source{Handle: "@download_" + sum}) + fakeSrc := &source.Source{Handle: "@download_" + sum} + cacheDir, err := ru.Files.CacheDirFor(fakeSrc) if err != nil { return err } - dl, err := download.New(httpz.NewDefaultClient(), u.String(), cacheDir) + dl, err := download.New(fakeSrc.Handle, httpz.NewDefaultClient2(), u.String(), cacheDir) if err != nil { return err } diff --git a/go.mod b/go.mod index d69ab6dc7..d36739d6f 100644 --- a/go.mod +++ b/go.mod @@ -97,5 +97,6 @@ require ( golang.org/x/crypto v0.16.0 // indirect golang.org/x/sys v0.15.0 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 6866c715d..8e21d1b2a 100644 --- a/go.sum +++ b/go.sum 
@@ -270,8 +270,9 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/libsq/core/ioz/download/cache.go b/libsq/core/ioz/download/cache.go index 49034d10c..ab0229624 100644 --- a/libsq/core/ioz/download/cache.go +++ b/libsq/core/ioz/download/cache.go @@ -140,8 +140,8 @@ func (c *cache) checksum(req *http.Request) (sum checksum.Checksum, ok bool) { return sum, ok } -// Clear deletes the cache entries from disk. -func (c *cache) Clear(ctx context.Context) error { +// clear deletes the cache entries from disk. 
+func (c *cache) clear(ctx context.Context) error { if c == nil { return nil } diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index 6ff6935a0..6373222c4 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -16,6 +16,7 @@ import ( "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/progress" "io" "net/http" "net/url" @@ -81,13 +82,16 @@ func OptDisableCaching(disable bool) Opt { type Download struct { // FIXME: Does Download need a sync.Mutex? + // name is a user-friendly name, such as a source handle like @data. + name string + // url is the URL of the download. It is parsed in download.New, // thus is guaranteed to be valid. url string c *http.Client - respCache *cache + cache *cache // markCachedResponses, if true, indicates that responses returned from the // cache will be given an extra header, X-From-cache. @@ -97,14 +101,16 @@ type Download struct { } // New returns a new Download for url that writes to cacheDir. -// If c is nil, http.DefaultClient is used. -func New(c *http.Client, dlURL, cacheDir string, opts ...Opt) (*Download, error) { +// Name is a user-friendly name, such as a source handle like @data. +// The name may show up in logs, or progress indicators etc. +// If c is nil, httpz.NewDefaultClient is used. 
+func New(name string, c *http.Client, dlURL, cacheDir string, opts ...Opt) (*Download, error) { _, err := url.ParseRequestURI(dlURL) if err != nil { return nil, errz.Wrap(err, "invalid download URL") } if c == nil { - c = http.DefaultClient + c = httpz.NewDefaultClient2() } if cacheDir, err = filepath.Abs(cacheDir); err != nil { @@ -112,6 +118,7 @@ func New(c *http.Client, dlURL, cacheDir string, opts ...Opt) (*Download, error) } t := &Download{ + name: name, c: c, url: dlURL, markCachedResponses: true, @@ -122,7 +129,7 @@ func New(c *http.Client, dlURL, cacheDir string, opts ...Opt) (*Download, error) } if !t.disableCaching { - t.respCache = &cache{dir: cacheDir} + t.cache = &cache{dir: cacheDir} } return t, nil @@ -138,7 +145,7 @@ func (dl *Download) get(req *http.Request, h Handler) { ctx := req.Context() log := lg.FromContext(ctx) log.Debug("Get download", lga.URL, dl.url) - _, fpBody, _ := dl.respCache.paths(req) + _, fpBody, _ := dl.cache.paths(req) state := dl.state(req) if state == Fresh { @@ -150,10 +157,10 @@ func (dl *Download) get(req *http.Request, h Handler) { cacheable := dl.isCacheable(req) var cachedResp *http.Response if cacheable { - cachedResp, err = dl.respCache.get(req.Context(), req) + cachedResp, err = dl.cache.get(req.Context(), req) } else { // Need to invalidate an existing value - if err = dl.respCache.Clear(req.Context()); err != nil { + if err = dl.cache.clear(req.Context()); err != nil { h.Error(err) return } @@ -211,7 +218,7 @@ func (dl *Download) get(req *http.Request, h Handler) { return } else { if err != nil || resp.StatusCode != http.StatusOK { - lg.WarnIfError(log, msgDeleteCache, dl.respCache.Clear(req.Context())) + lg.WarnIfError(log, msgDeleteCache, dl.cache.clear(req.Context())) } if err != nil { h.Error(err) @@ -243,8 +250,8 @@ func (dl *Download) get(req *http.Request, h Handler) { if resp == cachedResp { lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) - if err = dl.respCache.write(ctx, resp, true, nil); 
err != nil { - log.Error("Failed to update cache header", lga.Dir, dl.respCache.dir, lga.Err, err) + if err = dl.cache.write(ctx, resp, true, nil); err != nil { + log.Error("Failed to update cache header", lga.Dir, dl.cache.dir, lga.Err, err) h.Error(err) return } @@ -260,13 +267,13 @@ func (dl *Download) get(req *http.Request, h Handler) { } defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) - if err = dl.respCache.write(req.Context(), resp, false, destWrtr); err != nil { - log.Error("Failed to write download cache", lga.Dir, dl.respCache.dir, lga.Err, err) + if err = dl.cache.write(req.Context(), resp, false, destWrtr); err != nil { + log.Error("Failed to write download cache", lga.Dir, dl.cache.dir, lga.Err, err) //destWrtr.Error(err) } return } else { - lg.WarnIfError(log, "Delete resp cache", dl.respCache.Clear(req.Context())) + lg.WarnIfError(log, "Delete resp cache", dl.cache.clear(req.Context())) } // It's not cacheable, so we need to write it to the destWrtr, @@ -294,7 +301,12 @@ func (dl *Download) get(req *http.Request, h Handler) { // do executes the request. func (dl *Download) do(req *http.Request) (*http.Response, error) { - return dl.c.Do(req) + resp, err := dl.c.Do(req) + if err == nil && resp.Body != nil { + r := progress.NewReader(req.Context(), dl.name+": download", resp.ContentLength, resp.Body) + resp.Body = r.(io.ReadCloser) + } + return resp, err } // mustRequest creates a new request from dl.url. The url has already been @@ -306,14 +318,13 @@ func (dl *Download) mustRequest(ctx context.Context) *http.Request { panic(err) return nil } - return req } // Clear deletes the cache. 
func (dl *Download) Clear(ctx context.Context) error { - if dl.respCache != nil { - return dl.respCache.Clear(ctx) + if dl.cache != nil { + return dl.cache.clear(ctx) } return nil } @@ -331,11 +342,11 @@ func (dl *Download) state(req *http.Request) State { ctx := req.Context() log := lg.FromContext(ctx) - if !dl.respCache.exists(req) { + if !dl.cache.exists(req) { return Uncached } - fpHeader, _, _ := dl.respCache.paths(req) + fpHeader, _, _ := dl.cache.paths(req) f, err := os.Open(fpHeader) if err != nil { log.Error(msgCloseCacheHeaderFile, lga.File, fpHeader, lga.Err, err) @@ -355,13 +366,13 @@ func (dl *Download) state(req *http.Request) State { // Checksum returns the checksum of the cached download, if available. func (dl *Download) Checksum(ctx context.Context) (sum checksum.Checksum, ok bool) { - if dl.respCache == nil { + if dl.cache == nil { return "", false } req := dl.mustRequest(ctx) - _, _, fp := dl.respCache.paths(req) + _, _, fp := dl.cache.paths(req) if !ioz.FileAccessible(fp) { return "", false } diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index 1a9a6a386..132c2b0b1 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -77,7 +77,7 @@ func TestDownload_redirect(t *testing.T) { ctx := lg.NewContext(context.Background(), log.With("origin", "downloader")) loc := srvr.URL + "/redirect" - dl, err := download.New(httpz.NewDefaultClient(), loc, cacheDir) + dl, err := download.New(t.Name(), httpz.NewDefaultClient2(), loc, cacheDir) require.NoError(t, err) require.NoError(t, dl.Clear(ctx)) h := download.NewSinkHandler(log.With("origin", "handler")) @@ -144,7 +144,7 @@ func TestDownload_New(t *testing.T) { require.NoError(t, err) t.Logf("cacheDir: %s", cacheDir) - dl, err := download.New(httpz.NewDefaultClient(), dlURL, cacheDir) + dl, err := download.New(t.Name(), httpz.NewDefaultClient2(), dlURL, cacheDir) require.NoError(t, err) require.NoError(t, 
dl.Clear(ctx)) require.Equal(t, download.Uncached, dl.State(ctx)) diff --git a/libsq/core/ioz/httpz/httpz.go b/libsq/core/ioz/httpz/httpz.go index faff1a9c2..43b4ddfc0 100644 --- a/libsq/core/ioz/httpz/httpz.go +++ b/libsq/core/ioz/httpz/httpz.go @@ -26,7 +26,19 @@ import ( // NewDefaultClient returns a new HTTP client with default settings. func NewDefaultClient() *http.Client { - return NewClient(buildinfo.Get().UserAgent(), true, 0, 0) + return NewClient( + buildinfo.Get().UserAgent(), + true, + 0, + 0, + OptUserAgent(buildinfo.Get().UserAgent()), + ) +} // NewDefaultClient returns a new HTTP client with default settings. +func NewDefaultClient2() *http.Client { + return NewClient2( + OptInsecureSkipVerify(true), + OptUserAgent(buildinfo.Get().UserAgent()), + ) } // NewClient returns a new HTTP client. If userAgent is non-empty, the @@ -38,7 +50,7 @@ func NewDefaultClient() *http.Client { // to read. If bodyTimeout > 0, it is applied to the total lifecycle of // the request and response, including reading the response body. 
func NewClient(userAgent string, insecureSkipVerify bool, - headerTimeout, bodyTimeout time.Duration, + headerTimeout, bodyTimeout time.Duration, tripFuncs ...TripFunc, ) *http.Client { c := *http.DefaultClient var tr *http.Transport @@ -60,24 +72,92 @@ func NewClient(userAgent string, insecureSkipVerify bool, tr.TLSClientConfig.InsecureSkipVerify = insecureSkipVerify c.Transport = tr - if userAgent != "" { - c.Transport = &userAgentRoundTripper{ - userAgent: userAgent, - rt: c.Transport, - } + for i := range tripFuncs { + c.Transport = RoundTrip(c.Transport, tripFuncs[i]) } + // + //if userAgent != "" { + // //c.Transport = UserAgent2(c.Transport, userAgent) + // + // //var funcs []TripFunc + // + // + // + // //c.Transport = RoundTrip(c.Transport, func(next http.RoundTripper, req *http.Request) (*http.Response, error) { + // // req.Header.Set("User-Agent", userAgent) + // // return next.RoundTrip(req) + // //}) + // + // c.Transport = &userAgentRoundTripper{ + // userAgent: userAgent, + // rt: c.Transport, + // } + //} + // + //c.Timeout = bodyTimeout + //if headerTimeout > 0 { + // c.Transport = &headerTimeoutRoundTripper{ + // headerTimeout: headerTimeout, + // rt: c.Transport, + // } + //} - c.Timeout = bodyTimeout - if headerTimeout > 0 { - c.Transport = &headerTimeoutRoundTripper{ - headerTimeout: headerTimeout, - rt: c.Transport, - } + return &c +} + +// NewClient2 returns a new HTTP client configured with opts. 
+func NewClient2(opts ...Opt) *http.Client { + c := *http.DefaultClient + var tr *http.Transport + if c.Transport == nil { + tr = (http.DefaultTransport.(*http.Transport)).Clone() + } else { + tr = (c.Transport.(*http.Transport)).Clone() } + DefaultTLSVersion.apply(tr) + for _, opt := range opts { + opt.apply(tr) + } + + c.Transport = tr + for i := range opts { + if tf, ok := opts[i].(TripFunc); ok { + c.Transport = RoundTrip(c.Transport, tf) + } + } return &c } +var _ Opt = (*TripFunc)(nil) + +// TripFunc is a function that implements http.RoundTripper. +// It is commonly used with RoundTrip to decorate an existing http.RoundTripper. +type TripFunc func(next http.RoundTripper, req *http.Request) (*http.Response, error) + +func (tf TripFunc) apply(tr *http.Transport) {} + +// RoundTrip adapts a TripFunc to http.RoundTripper. +func RoundTrip(next http.RoundTripper, fn TripFunc) http.RoundTripper { + return roundTripFunc(func(req *http.Request) (*http.Response, error) { + return fn(next, req) + }) +} + +// NopTripFunc is a TripFunc that does nothing. +func NopTripFunc(next http.RoundTripper, req *http.Request) (*http.Response, error) { + return next.RoundTrip(req) +} + +// roundTripFunc is an adapter to allow use of functions as http.RoundTripper. +// It works with TripFunc and RoundTrip. +type roundTripFunc func(*http.Request) (*http.Response, error) + +// RoundTrip implements http.RoundTripper. +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + // userAgentRoundTripper applies a User-Agent header to each request. 
type userAgentRoundTripper struct { userAgent string diff --git a/libsq/core/ioz/httpz/httpz_test.go b/libsq/core/ioz/httpz/httpz_test.go index 1e0cd2df2..8540cec40 100644 --- a/libsq/core/ioz/httpz/httpz_test.go +++ b/libsq/core/ioz/httpz/httpz_test.go @@ -9,12 +9,10 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func TestNewHTTPClient_headerTimeout(t *testing.T) { +func TestOptHeaderTimeout(t *testing.T) { t.Parallel() const ( headerTimeout = time.Second * 2 @@ -42,7 +40,7 @@ func TestNewHTTPClient_headerTimeout(t *testing.T) { ctxFn: func(t *testing.T) context.Context { return context.Background() }, - c: httpz.NewClient("", false, headerTimeout, 0), + c: httpz.NewClient2(httpz.OptHeaderTimeout(headerTimeout)), wantErr: false, }, } @@ -93,58 +91,3 @@ func TestNewHTTPClient_headerTimeout(t *testing.T) { }) } } - -func TestTimeout1(t *testing.T) { - const urlPaymentLargeCSV = "https://sqio-public.s3.amazonaws.com/testdata/payment-large.gen.csv" - - const urlActorCSV = "https://sq.io/testdata/actor.csv" - const respTimeout = time.Second * 2 - const lines = 10 - const wantLen = lines * 2 - slowServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - for i := 0; i < lines; i++ { - select { - case <-r.Context().Done(): - t.Logf("Server exiting due to: %v", r.Context().Err()) - return - default: - } - _, _ = io.WriteString(w, string(rune('A'+i))+"\n") - w.(http.Flusher).Flush() - time.Sleep(time.Second) - } - })) - t.Cleanup(slowServer.Close) - - ctx, cancelFn := context.WithTimeout(context.Background(), respTimeout) - defer cancelFn() - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, slowServer.URL, nil) - require.NoError(t, err) - - resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) - require.Equal(t, http.StatusOK, resp.StatusCode) - - // cancelFn() - time.Sleep(time.Second * 3) - - select { - case <-ctx.Done(): - t.Logf("ctx is 
done: %v", ctx.Err()) - default: - t.Logf("ctx is not done") - cancelFn() - } - - // cancelFn() - b, err := io.ReadAll(resp.Body) - t.Logf("err: %T: %v", err, err) - t.Logf("len(b): %d", len(b)) - t.Logf("b:\n\n%s\n\n", b) - assert.Error(t, err) - // require.Nil(t, b) - _ = b - // require.Len(t, b, 0) - // require.Len(t, b, 7641) -} diff --git a/libsq/core/ioz/httpz/opts.go b/libsq/core/ioz/httpz/opts.go new file mode 100644 index 000000000..d5cc5e709 --- /dev/null +++ b/libsq/core/ioz/httpz/opts.go @@ -0,0 +1,123 @@ +package httpz + +import ( + "context" + "crypto/tls" + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" + "net/http" + "time" +) + +// Opt is an option that can be passed to [NewClient2] to +// configure the client. +type Opt interface { + apply(*http.Transport) +} + +var _ Opt = (*OptInsecureSkipVerify)(nil) + +// OptInsecureSkipVerify is an Opt that can be passed to NewClient that, +// when true, disables TLS verification. +type OptInsecureSkipVerify bool + +func (b OptInsecureSkipVerify) apply(tr *http.Transport) { + tr.TLSClientConfig.InsecureSkipVerify = bool(b) +} + +var _ Opt = (*minTLSVersion)(nil) + +type minTLSVersion uint16 + +func (v minTLSVersion) apply(tr *http.Transport) { + if tr.TLSClientConfig == nil { + // We allow tls.VersionTLS10, even though it's not considered + // secure these days. Ultimately this could become a config + // option. + tr.TLSClientConfig = &tls.Config{MinVersion: uint16(v)} //nolint:gosec + } else { + tr.TLSClientConfig = tr.TLSClientConfig.Clone() + tr.TLSClientConfig.MinVersion = uint16(v) //nolint:gosec + } +} + +// DefaultTLSVersion is the default minimum TLS version used by [NewClient2]. +var DefaultTLSVersion = minTLSVersion(tls.VersionTLS10) + +// OptUserAgent is passed to [NewClient2] to set the User-Agent header. 
+func OptUserAgent(ua string) TripFunc { + return func(next http.RoundTripper, req *http.Request) (*http.Response, error) { + req.Header.Set("User-Agent", ua) + return next.RoundTrip(req) + } +} + +// OptRequestTimeout is passed to [NewClient2] to set the total request timeout. +// If timeout is zero, this is a no-op. +// +// Contrast with [OptHeaderTimeout]. +func OptRequestTimeout(timeout time.Duration) TripFunc { + if timeout <= 0 { + return NopTripFunc + } + return func(next http.RoundTripper, req *http.Request) (*http.Response, error) { + ctx, cancelFn := context.WithTimeoutCause(req.Context(), timeout, + errz.Errorf("http request not completed in %s timeout", timeout)) + defer cancelFn() + req = req.WithContext(ctx) + return next.RoundTrip(req) + } +} + +// OptHeaderTimeout is passed to [NewClient2] to set a timeout for just +// getting the initial response headers. This is useful if you expect +// a response within, say, 5 seconds, but you expect the body to take longer +// to read. If bodyTimeout > 0, it is applied to the total lifecycle of +// the request and response, including reading the response body. +// If timeout <= zero, this is a no-op. +// +// Contrast with [OptRequestTimeout]. +func OptHeaderTimeout(timeout time.Duration) TripFunc { + if timeout <= 0 { + return NopTripFunc + } + return func(next http.RoundTripper, req *http.Request) (*http.Response, error) { + + timerCancelCh := make(chan struct{}) + ctx, cancelFn := context.WithCancelCause(req.Context()) + go func() { + t := time.NewTimer(timeout) + defer t.Stop() + select { + case <-ctx.Done(): + case <-t.C: + cancelFn(errz.Errorf("http response not received by %s timeout", + timeout)) + case <-timerCancelCh: + // Stop the timer goroutine. + } + }() + + resp, err := next.RoundTrip(req.WithContext(ctx)) + close(timerCancelCh) + + // Don't leak resources; ensure that cancelFn is eventually called. 
+ switch { + case err != nil: + + // It's possible that cancelFn has already been called by the + // timer goroutine, but we call it again just in case. + cancelFn(err) + case resp != nil && resp.Body != nil: + + // Wrap resp.Body with a ReadCloserNotifier, so that cancelFn + // is called when the body is closed. + resp.Body = ioz.ReadCloserNotifier(resp.Body, cancelFn) + default: + // Not sure if this can actually happen, but just in case. + cancelFn(context.Canceled) + } + + return resp, err + } +} From 57b5806d779ebc2f3b2f9b5ab9288feb08b2df45 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 13 Dec 2023 23:38:46 -0700 Subject: [PATCH 115/195] Refactoring httpz.NewClient: done --- cli/cmd_x.go | 9 +- libsq/core/ioz/checksum/checksum_test.go | 3 +- libsq/core/ioz/download/cache.go | 21 +-- libsq/core/ioz/download/download.go | 15 ++- libsq/core/ioz/download/download_test.go | 17 +-- libsq/core/ioz/download/handler.go | 5 +- libsq/core/ioz/httpz/httpz.go | 161 ++--------------------- libsq/core/ioz/httpz/httpz_test.go | 5 +- libsq/core/ioz/httpz/opts.go | 16 +-- libsq/source/download.go | 3 +- 10 files changed, 64 insertions(+), 191 deletions(-) diff --git a/cli/cmd_x.go b/cli/cmd_x.go index 1a3eec511..43975d0a9 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -3,13 +3,14 @@ package cli import ( "bufio" "fmt" + "net/url" + "os" + "time" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/ioz/download" "github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/sq/libsq/source" - "net/url" - "os" - "time" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/progress" @@ -163,7 +164,7 @@ func execXDownloadCmd(cmd *cobra.Command, args []string) error { return err } - dl, err := download.New(fakeSrc.Handle, httpz.NewDefaultClient2(), u.String(), cacheDir) + dl, err := download.New(fakeSrc.Handle, httpz.NewDefaultClient(), u.String(), cacheDir) if err != nil { return err } diff 
--git a/libsq/core/ioz/checksum/checksum_test.go b/libsq/core/ioz/checksum/checksum_test.go index f5eaf193c..0a46310d5 100644 --- a/libsq/core/ioz/checksum/checksum_test.go +++ b/libsq/core/ioz/checksum/checksum_test.go @@ -1,9 +1,10 @@ package checksum_test import ( - "github.com/stretchr/testify/require" "testing" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/assert" "github.com/neilotoole/sq/libsq/core/ioz/checksum" diff --git a/libsq/core/ioz/download/cache.go b/libsq/core/ioz/download/cache.go index ab0229624..237d94bfb 100644 --- a/libsq/core/ioz/download/cache.go +++ b/libsq/core/ioz/download/cache.go @@ -4,10 +4,6 @@ import ( "bufio" "bytes" "context" - "github.com/neilotoole/sq/libsq/core/ioz/checksum" - "github.com/neilotoole/sq/libsq/core/ioz/contextio" - "github.com/neilotoole/sq/libsq/core/ioz/httpz" - "github.com/neilotoole/sq/libsq/core/lg/lgm" "io" "net/http" "net/http/httputil" @@ -15,14 +11,21 @@ import ( "path/filepath" "sync" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" + "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" ) -const msgCloseCacheHeaderFile = "Close cached response header file" -const msgCloseCacheBodyFile = "Close cached response body file" +const ( + msgCloseCacheHeaderFile = "Close cached response header file" + msgCloseCacheBodyFile = "Close cached response body file" +) // cache is a cache for a individual download. The cached response is // stored in two files, one for the header and one for the body, with @@ -175,7 +178,8 @@ const msgDeleteCache = "Delete HTTP response cache" // A checksum file, computed from the body file, is also written to disk. 
The // response body is always closed. func (c *cache) write(ctx context.Context, resp *http.Response, - headerOnly bool, copyWrtr ioz.WriteErrorCloser) error { + headerOnly bool, copyWrtr ioz.WriteErrorCloser, +) error { c.mu.Lock() defer c.mu.Unlock() @@ -183,7 +187,8 @@ func (c *cache) write(ctx context.Context, resp *http.Response, } func (c *cache) doWrite(ctx context.Context, resp *http.Response, - headerOnly bool, copyWrtr ioz.WriteErrorCloser) (err error) { + headerOnly bool, copyWrtr ioz.WriteErrorCloser, +) (err error) { log := lg.FromContext(ctx) defer func() { diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index 6373222c4..797d93967 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -11,17 +11,18 @@ package download import ( "bufio" "context" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/progress" - "io" - "net/http" - "net/url" - "os" - "path/filepath" "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/lg" @@ -110,7 +111,7 @@ func New(name string, c *http.Client, dlURL, cacheDir string, opts ...Opt) (*Dow return nil, errz.Wrap(err, "invalid download URL") } if c == nil { - c = httpz.NewDefaultClient2() + c = httpz.NewDefaultClient() } if cacheDir, err = filepath.Abs(cacheDir); err != nil { @@ -269,7 +270,7 @@ func (dl *Download) get(req *http.Request, h Handler) { defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) if err = dl.cache.write(req.Context(), resp, false, destWrtr); err != nil { log.Error("Failed to write download cache", lga.Dir, dl.cache.dir, lga.Err, err) - //destWrtr.Error(err) + // destWrtr.Error(err) } 
return } else { diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index 132c2b0b1..30c7bcdf2 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -2,10 +2,6 @@ package download_test import ( "context" - "github.com/neilotoole/sq/libsq/core/ioz/httpz" - "github.com/neilotoole/sq/libsq/core/lg/lga" - "github.com/neilotoole/sq/testh/tu" - "github.com/stretchr/testify/assert" "net/http" "net/http/httptest" "os" @@ -13,6 +9,11 @@ import ( "testing" "time" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/testh/tu" + "github.com/stretchr/testify/assert" + "github.com/neilotoole/slogt" "github.com/neilotoole/sq/libsq/core/ioz/download" "github.com/neilotoole/sq/libsq/core/lg" @@ -27,9 +28,9 @@ const ( func TestDownload_redirect(t *testing.T) { const hello = `Hello World!` - var serveBody = hello + serveBody := hello lastModified := time.Now().UTC() - //cacheDir := t.TempDir() + // cacheDir := t.TempDir() // FIXME: switch back to temp dir cacheDir := filepath.Join("testdata", "download", tu.Name(t.Name())) @@ -77,7 +78,7 @@ func TestDownload_redirect(t *testing.T) { ctx := lg.NewContext(context.Background(), log.With("origin", "downloader")) loc := srvr.URL + "/redirect" - dl, err := download.New(t.Name(), httpz.NewDefaultClient2(), loc, cacheDir) + dl, err := download.New(t.Name(), httpz.NewDefaultClient(), loc, cacheDir) require.NoError(t, err) require.NoError(t, dl.Clear(ctx)) h := download.NewSinkHandler(log.With("origin", "handler")) @@ -144,7 +145,7 @@ func TestDownload_New(t *testing.T) { require.NoError(t, err) t.Logf("cacheDir: %s", cacheDir) - dl, err := download.New(t.Name(), httpz.NewDefaultClient2(), dlURL, cacheDir) + dl, err := download.New(t.Name(), httpz.NewDefaultClient(), dlURL, cacheDir) require.NoError(t, err) require.NoError(t, dl.Clear(ctx)) 
require.Equal(t, download.Uncached, dl.State(ctx)) diff --git a/libsq/core/ioz/download/handler.go b/libsq/core/ioz/download/handler.go index 50444f76d..265e6df63 100644 --- a/libsq/core/ioz/download/handler.go +++ b/libsq/core/ioz/download/handler.go @@ -2,10 +2,11 @@ package download import ( "bytes" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/neilotoole/sq/libsq/core/lg/lga" "log/slog" "sync" + + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/lg/lga" ) // Handler is a callback invoked by Download.Get. Exactly one of the diff --git a/libsq/core/ioz/httpz/httpz.go b/libsq/core/ioz/httpz/httpz.go index 43b4ddfc0..e5176fd3a 100644 --- a/libsq/core/ioz/httpz/httpz.go +++ b/libsq/core/ioz/httpz/httpz.go @@ -1,14 +1,16 @@ // Package httpz provides functionality supplemental to stdlib http. // Indeed, some of the functions are copied verbatim from stdlib. +// The jumping-off point is [httpz.NewClient]. +// +// Design note: this package contains generally fairly straightforward HTTP +// functionality, but the Opt / TripFunc config mechanism is a bit +// experimental. And probably tries to be a bit too clever. It may change. package httpz import ( "bufio" - "context" - "crypto/tls" "fmt" "github.com/neilotoole/sq/cli/buildinfo" - "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/stringz" "io" "log/slog" @@ -19,94 +21,18 @@ import ( "path/filepath" "strconv" "strings" - "time" - - "github.com/neilotoole/sq/libsq/core/errz" ) -// NewDefaultClient returns a new HTTP client with default settings. +// NewDefaultClient invokes NewClient with default settings. func NewDefaultClient() *http.Client { return NewClient( - buildinfo.Get().UserAgent(), - true, - 0, - 0, + OptInsecureSkipVerify(false), OptUserAgent(buildinfo.Get().UserAgent()), ) -} // NewDefaultClient returns a new HTTP client with default settings. 
-func NewDefaultClient2() *http.Client { - return NewClient2( - OptInsecureSkipVerify(true), - OptUserAgent(buildinfo.Get().UserAgent()), - ) -} - -// NewClient returns a new HTTP client. If userAgent is non-empty, the -// "User-Agent" header is applied to each request. If insecureSkipVerify is -// true, the client will skip TLS verification. If headerTimeout > 0, a -// timeout is applied to receiving the HTTP response, but that timeout is -// not applied to reading the response body. This is useful if you expect -// a response within, say, 5 seconds, but you expect the body to take longer -// to read. If bodyTimeout > 0, it is applied to the total lifecycle of -// the request and response, including reading the response body. -func NewClient(userAgent string, insecureSkipVerify bool, - headerTimeout, bodyTimeout time.Duration, tripFuncs ...TripFunc, -) *http.Client { - c := *http.DefaultClient - var tr *http.Transport - if c.Transport == nil { - tr = (http.DefaultTransport.(*http.Transport)).Clone() - } else { - tr = (c.Transport.(*http.Transport)).Clone() - } - - if tr.TLSClientConfig == nil { - // We allow tls.VersionTLS10, even though it's not considered - // secure these days. Ultimately this could become a config - // option. 
- tr.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS10} //nolint:gosec - } else { - tr.TLSClientConfig = tr.TLSClientConfig.Clone() - tr.TLSClientConfig.MinVersion = tls.VersionTLS10 //nolint:gosec - } - - tr.TLSClientConfig.InsecureSkipVerify = insecureSkipVerify - c.Transport = tr - for i := range tripFuncs { - c.Transport = RoundTrip(c.Transport, tripFuncs[i]) - } - // - //if userAgent != "" { - // //c.Transport = UserAgent2(c.Transport, userAgent) - // - // //var funcs []TripFunc - // - // - // - // //c.Transport = RoundTrip(c.Transport, func(next http.RoundTripper, req *http.Request) (*http.Response, error) { - // // req.Header.Set("User-Agent", userAgent) - // // return next.RoundTrip(req) - // //}) - // - // c.Transport = &userAgentRoundTripper{ - // userAgent: userAgent, - // rt: c.Transport, - // } - //} - // - //c.Timeout = bodyTimeout - //if headerTimeout > 0 { - // c.Transport = &headerTimeoutRoundTripper{ - // headerTimeout: headerTimeout, - // rt: c.Transport, - // } - //} - - return &c } -// NewClient2 returns a new HTTP client configured with opts. -func NewClient2(opts ...Opt) *http.Client { +// NewClient returns a new HTTP client configured with opts. +func NewClient(opts ...Opt) *http.Client { c := *http.DefaultClient var tr *http.Transport if c.Transport == nil { @@ -158,72 +84,7 @@ func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) } -// userAgentRoundTripper applies a User-Agent header to each request. -type userAgentRoundTripper struct { - userAgent string - rt http.RoundTripper -} - -// RoundTrip implements http.RoundTripper. -func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - req.Header.Set("User-Agent", rt.userAgent) - return rt.rt.RoundTrip(req) -} - -// headerTimeoutRoundTripper applies headerTimeout to the return of the http -// response, but headerTimeout is not applied to reading the body of the -// response. 
This is useful if you expect a response within, say, 5 seconds, -// but you expect the body to take longer to read. -type headerTimeoutRoundTripper struct { - headerTimeout time.Duration - rt http.RoundTripper -} - -// RoundTrip implements http.RoundTripper. -func (rt *headerTimeoutRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if rt.headerTimeout <= 0 { - return rt.rt.RoundTrip(req) - } - - timerCancelCh := make(chan struct{}) - ctx, cancelFn := context.WithCancelCause(req.Context()) - go func() { - t := time.NewTimer(rt.headerTimeout) - defer t.Stop() - select { - case <-ctx.Done(): - case <-t.C: - cancelFn(errz.Errorf("http response not received by %s timeout", - rt.headerTimeout)) - case <-timerCancelCh: - // Stop the timer goroutine. - } - }() - - resp, err := rt.rt.RoundTrip(req.WithContext(ctx)) - close(timerCancelCh) - - // Don't leak resources; ensure that cancelFn is eventually called. - switch { - case err != nil: - - // It's possible that cancelFn has already been called by the - // timer goroutine, but we call it again just in case. - cancelFn(err) - case resp != nil && resp.Body != nil: - - // Wrap resp.Body with a ReadCloserNotifier, so that cancelFn - // is called when the body is closed. - resp.Body = ioz.ReadCloserNotifier(resp.Body, cancelFn) - default: - // Not sure if this can actually happen, but just in case. - cancelFn(context.Canceled) - } - - return resp, err -} - -// ResponseLogValue implements slog.Valuer for resp. +// ResponseLogValue implements slog.LogValuer for resp. func ResponseLogValue(resp *http.Response) slog.Value { if resp == nil { return slog.Value{} @@ -252,7 +113,7 @@ func ResponseLogValue(resp *http.Response) slog.Value { return slog.GroupValue(attrs...) } -// RequestLogValue implements slog.Valuer for req. +// RequestLogValue implements slog.LogValuer for req. 
func RequestLogValue(req *http.Request) slog.Value { if req == nil { return slog.Value{} diff --git a/libsq/core/ioz/httpz/httpz_test.go b/libsq/core/ioz/httpz/httpz_test.go index 8540cec40..b0fc2d61f 100644 --- a/libsq/core/ioz/httpz/httpz_test.go +++ b/libsq/core/ioz/httpz/httpz_test.go @@ -2,13 +2,14 @@ package httpz_test import ( "context" - "github.com/neilotoole/sq/libsq/core/ioz/httpz" "io" "net/http" "net/http/httptest" "testing" "time" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" + "github.com/stretchr/testify/require" ) @@ -40,7 +41,7 @@ func TestOptHeaderTimeout(t *testing.T) { ctxFn: func(t *testing.T) context.Context { return context.Background() }, - c: httpz.NewClient2(httpz.OptHeaderTimeout(headerTimeout)), + c: httpz.NewClient(httpz.OptHeaderTimeout(headerTimeout)), wantErr: false, }, } diff --git a/libsq/core/ioz/httpz/opts.go b/libsq/core/ioz/httpz/opts.go index d5cc5e709..209d0da3c 100644 --- a/libsq/core/ioz/httpz/opts.go +++ b/libsq/core/ioz/httpz/opts.go @@ -3,13 +3,14 @@ package httpz import ( "context" "crypto/tls" - "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/ioz" "net/http" "time" + + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" ) -// Opt is an option that can be passed to [NewClient2] to +// Opt is an option that can be passed to [NewClient] to // configure the client. type Opt interface { apply(*http.Transport) @@ -41,10 +42,10 @@ func (v minTLSVersion) apply(tr *http.Transport) { } } -// DefaultTLSVersion is the default minimum TLS version used by [NewClient2]. +// DefaultTLSVersion is the default minimum TLS version used by [NewClient]. var DefaultTLSVersion = minTLSVersion(tls.VersionTLS10) -// OptUserAgent is passed to [NewClient2] to set the User-Agent header. +// OptUserAgent is passed to [NewClient] to set the User-Agent header. 
func OptUserAgent(ua string) TripFunc { return func(next http.RoundTripper, req *http.Request) (*http.Response, error) { req.Header.Set("User-Agent", ua) @@ -52,7 +53,7 @@ func OptUserAgent(ua string) TripFunc { } } -// OptRequestTimeout is passed to [NewClient2] to set the total request timeout. +// OptRequestTimeout is passed to [NewClient] to set the total request timeout. // If timeout is zero, this is a no-op. // // Contrast with [OptHeaderTimeout]. @@ -69,7 +70,7 @@ func OptRequestTimeout(timeout time.Duration) TripFunc { } } -// OptHeaderTimeout is passed to [NewClient2] to set a timeout for just +// OptHeaderTimeout is passed to [NewClient] to set a timeout for just // getting the initial response headers. This is useful if you expect // a response within, say, 5 seconds, but you expect the body to take longer // to read. If bodyTimeout > 0, it is applied to the total lifecycle of @@ -82,7 +83,6 @@ func OptHeaderTimeout(timeout time.Duration) TripFunc { return NopTripFunc } return func(next http.RoundTripper, req *http.Request) (*http.Response, error) { - timerCancelCh := make(chan struct{}) ctx, cancelFn := context.WithCancelCause(req.Context()) go func() { diff --git a/libsq/source/download.go b/libsq/source/download.go index 95a6730bf..faf8955da 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -3,7 +3,6 @@ package source import ( "bytes" "context" - "github.com/neilotoole/sq/libsq/core/ioz/httpz" "io" "log/slog" "net/http" @@ -14,6 +13,8 @@ import ( "sync" "time" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" + "golang.org/x/exp/maps" "github.com/neilotoole/sq/libsq/core/errz" From c923242a8a2f4aa68e7c4fffad1205a73a17e222 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 14 Dec 2023 07:34:24 -0700 Subject: [PATCH 116/195] Test httpz opts --- cli/cmd_x.go | 10 ++- cli/error.go | 2 +- cli/output.go | 3 +- cli/run/run.go | 3 +- libsq/core/ioz/checksum/checksum_test.go | 3 +- libsq/core/ioz/download/cache.go | 7 +- 
libsq/core/ioz/download/download.go | 7 +- libsq/core/ioz/download/download_test.go | 39 ++++++++++-- libsq/core/ioz/httpz/httpz.go | 12 ++-- libsq/core/ioz/httpz/httpz_test.go | 81 ++++++++++++++++++++++-- libsq/core/ioz/httpz/opts.go | 30 ++++++--- libsq/core/ioz/ioz.go | 3 +- libsq/core/progress/bars.go | 6 -- libsq/driver/grips.go | 3 +- libsq/source/download.go | 3 +- libsq/source/download_test.go | 3 +- testh/tu/tu.go | 11 ++++ 17 files changed, 170 insertions(+), 56 deletions(-) diff --git a/cli/cmd_x.go b/cli/cmd_x.go index 43975d0a9..944392f44 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -7,17 +7,15 @@ import ( "os" "time" + "github.com/spf13/cobra" + + "github.com/neilotoole/sq/cli/run" "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/ioz/download" "github.com/neilotoole/sq/libsq/core/ioz/httpz" - "github.com/neilotoole/sq/libsq/source" - "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/progress" - - "github.com/spf13/cobra" - - "github.com/neilotoole/sq/cli/run" + "github.com/neilotoole/sq/libsq/source" ) // newXCmd returns the root "x" command, which is the container diff --git a/cli/error.go b/cli/error.go index a8db0a173..504b09512 100644 --- a/cli/error.go +++ b/cli/error.go @@ -6,7 +6,6 @@ import ( "fmt" "os" - "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -14,6 +13,7 @@ import ( "github.com/neilotoole/sq/cli/output/format" "github.com/neilotoole/sq/cli/output/jsonw" "github.com/neilotoole/sq/cli/run" + "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" diff --git a/cli/output.go b/cli/output.go index 2692a4cf2..51b9b2aaa 100644 --- a/cli/output.go +++ b/cli/output.go @@ -7,8 +7,6 @@ import ( "strings" "time" - 
"github.com/neilotoole/sq/libsq/core/cleanup" - "github.com/fatih/color" colorable "github.com/mattn/go-colorable" wordwrap "github.com/mitchellh/go-wordwrap" @@ -26,6 +24,7 @@ import ( "github.com/neilotoole/sq/cli/output/xlsxw" "github.com/neilotoole/sq/cli/output/xmlw" "github.com/neilotoole/sq/cli/output/yamlw" + "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" diff --git a/cli/run/run.go b/cli/run/run.go index 02eac2839..c2a68545b 100644 --- a/cli/run/run.go +++ b/cli/run/run.go @@ -7,8 +7,6 @@ import ( "io" "os" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/config" @@ -16,6 +14,7 @@ import ( "github.com/neilotoole/sq/libsq" "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" diff --git a/libsq/core/ioz/checksum/checksum_test.go b/libsq/core/ioz/checksum/checksum_test.go index 0a46310d5..c5a3fdfd5 100644 --- a/libsq/core/ioz/checksum/checksum_test.go +++ b/libsq/core/ioz/checksum/checksum_test.go @@ -3,9 +3,8 @@ package checksum_test import ( "testing" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/neilotoole/sq/libsq/core/ioz/checksum" ) diff --git a/libsq/core/ioz/download/cache.go b/libsq/core/ioz/download/cache.go index 237d94bfb..b48b5c7da 100644 --- a/libsq/core/ioz/download/cache.go +++ b/libsq/core/ioz/download/cache.go @@ -11,15 +11,14 @@ import ( "path/filepath" "sync" + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" 
"github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/ioz/httpz" - "github.com/neilotoole/sq/libsq/core/lg/lgm" - - "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgm" ) const ( diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index 797d93967..de9273c8e 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -20,13 +20,12 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" - "github.com/neilotoole/sq/libsq/core/ioz/httpz" - "github.com/neilotoole/sq/libsq/core/lg/lgm" - "github.com/neilotoole/sq/libsq/core/progress" - "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/progress" ) // State is an enumeration of caching states based on the cache-control diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index 30c7bcdf2..5dc4c6a06 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -6,18 +6,20 @@ import ( "net/http/httptest" "os" "path/filepath" + "strconv" "testing" "time" - "github.com/neilotoole/sq/libsq/core/ioz/httpz" - "github.com/neilotoole/sq/libsq/core/lg/lga" - "github.com/neilotoole/sq/testh/tu" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/neilotoole/slogt" + "github.com/neilotoole/sq/libsq/core/ioz/download" + 
"github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/sq/libsq/core/lg" - "github.com/stretchr/testify/require" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/testh/tu" ) const ( @@ -26,6 +28,35 @@ const ( sizeActorCSV = int64(7641) ) +func TestSlowHeaderServer(t *testing.T) { + const hello = `Hello World!` + var srvr *httptest.Server + serverDelay := time.Second * 200 + srvr = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + t.Log("Server request context done") + return + case <-time.After(serverDelay): + } + + w.Header().Set("Content-Type", "text/plain") + w.Header().Set("Content-Length", strconv.Itoa(len(hello))) + _, err := w.Write([]byte(hello)) + assert.NoError(t, err) + })) + t.Cleanup(srvr.Close) + + clientHeaderTimeout := time.Second * 2 + c := httpz.NewClient(httpz.OptHeaderTimeout(clientHeaderTimeout)) + req, err := http.NewRequest(http.MethodGet, srvr.URL, nil) + require.NoError(t, err) + resp, err := c.Do(req) + require.Error(t, err) + require.Nil(t, resp) + t.Log(err) +} + func TestDownload_redirect(t *testing.T) { const hello = `Hello World!` serveBody := hello diff --git a/libsq/core/ioz/httpz/httpz.go b/libsq/core/ioz/httpz/httpz.go index e5176fd3a..2a90c8d95 100644 --- a/libsq/core/ioz/httpz/httpz.go +++ b/libsq/core/ioz/httpz/httpz.go @@ -10,8 +10,6 @@ package httpz import ( "bufio" "fmt" - "github.com/neilotoole/sq/cli/buildinfo" - "github.com/neilotoole/sq/libsq/core/stringz" "io" "log/slog" "mime" @@ -21,6 +19,10 @@ import ( "path/filepath" "strconv" "strings" + "time" + + "github.com/neilotoole/sq/cli/buildinfo" + "github.com/neilotoole/sq/libsq/core/stringz" ) // NewDefaultClient invokes NewClient with default settings. 
@@ -28,12 +30,14 @@ func NewDefaultClient() *http.Client { return NewClient( OptInsecureSkipVerify(false), OptUserAgent(buildinfo.Get().UserAgent()), + OptHeaderTimeout(time.Second*5), ) } // NewClient returns a new HTTP client configured with opts. func NewClient(opts ...Opt) *http.Client { c := *http.DefaultClient + c.Timeout = 0 var tr *http.Transport if c.Transport == nil { tr = (http.DefaultTransport.(*http.Transport)).Clone() @@ -96,7 +100,7 @@ func ResponseLogValue(resp *http.Response) slog.Value { } h := resp.Header - for k, _ := range h { + for k := range h { vals := h.Values(k) if len(vals) == 1 { attrs = append(attrs, slog.String(k, vals[0])) @@ -132,7 +136,7 @@ func RequestLogValue(req *http.Request) slog.Value { } h := req.Header - for k, _ := range h { + for k := range h { vals := h.Values(k) if len(vals) == 1 { attrs = append(attrs, slog.String(k, vals[0])) diff --git a/libsq/core/ioz/httpz/httpz_test.go b/libsq/core/ioz/httpz/httpz_test.go index b0fc2d61f..16347f9f6 100644 --- a/libsq/core/ioz/httpz/httpz_test.go +++ b/libsq/core/ioz/httpz/httpz_test.go @@ -8,15 +8,86 @@ import ( "testing" "time" - "github.com/neilotoole/sq/libsq/core/ioz/httpz" - + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/neilotoole/sq/libsq/core/ioz/httpz" + "github.com/neilotoole/sq/testh/tu" ) -func TestOptHeaderTimeout(t *testing.T) { +func TestOptRequestTimeout(t *testing.T) { + t.Parallel() + const srvrBody = `Hello World!` + serverDelay := time.Millisecond * 200 + srvr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + t.Log("Server request context done") + return + case <-time.After(serverDelay): + } + _, err := w.Write([]byte(srvrBody)) + assert.NoError(t, err) + })) + t.Cleanup(srvr.Close) + + clientRequestTimeout := time.Millisecond * 100 + c := httpz.NewClient(httpz.OptRequestTimeout(clientRequestTimeout)) + req, err := 
http.NewRequest(http.MethodGet, srvr.URL, nil) + require.NoError(t, err) + + resp, err := c.Do(req) + t.Log(err) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "http request not completed within") +} + +// TestOptHeaderTimeout_correct_error verifies that an HTTP request +// that fails via OptHeaderTimeout returns the correct error. +func TestOptHeaderTimeout_correct_error(t *testing.T) { + t.Parallel() + const srvrBody = `Hello World!` + serverDelay := time.Second * 2 + srvr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + t.Log("Server request context done") + return + case <-time.After(serverDelay): + } + _, err := w.Write([]byte(srvrBody)) + assert.NoError(t, err) + })) + t.Cleanup(srvr.Close) + + clientHeaderTimeout := time.Second * 1 + c := httpz.NewClient(httpz.OptHeaderTimeout(clientHeaderTimeout)) + req, err := http.NewRequest(http.MethodGet, srvr.URL, nil) + require.NoError(t, err) + + resp, err := c.Do(req) + t.Log(err) + require.Error(t, err) + require.Nil(t, resp) + require.Contains(t, err.Error(), "http response not received within") + + // Now let's try again, with a shorter server delay, so the + // request should succeed. + serverDelay = time.Millisecond + resp, err = c.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + got := tu.ReadToString(t, resp.Body) + require.Equal(t, srvrBody, got) +} + +// TestOptHeaderTimeout_vs_stdlib verifies that OptHeaderTimeout +// works as expected when compared to stdlib. 
+func TestOptHeaderTimeout_vs_stdlib(t *testing.T) { t.Parallel() const ( - headerTimeout = time.Second * 2 + headerTimeout = time.Millisecond * 200 numLines = 7 ) @@ -65,7 +136,7 @@ func TestOptHeaderTimeout(t *testing.T) { return } w.(http.Flusher).Flush() - time.Sleep(time.Second) + time.Sleep(time.Millisecond * 100) } })) t.Cleanup(slowServer.Close) diff --git a/libsq/core/ioz/httpz/opts.go b/libsq/core/ioz/httpz/opts.go index 209d0da3c..9877956a2 100644 --- a/libsq/core/ioz/httpz/opts.go +++ b/libsq/core/ioz/httpz/opts.go @@ -3,6 +3,7 @@ package httpz import ( "context" "crypto/tls" + "errors" "net/http" "time" @@ -63,10 +64,19 @@ func OptRequestTimeout(timeout time.Duration) TripFunc { } return func(next http.RoundTripper, req *http.Request) (*http.Response, error) { ctx, cancelFn := context.WithTimeoutCause(req.Context(), timeout, - errz.Errorf("http request not completed in %s timeout", timeout)) + errz.Errorf("http request not completed within %s timeout", timeout)) defer cancelFn() - req = req.WithContext(ctx) - return next.RoundTrip(req) + resp, err := next.RoundTrip(req.WithContext(ctx)) + if err == nil { + return resp, nil + } + + if errors.Is(err, ctx.Err()) { + // The lower-down RoundTripper probably returned ctx.Err(), + // not context.Cause(), so we swap it around here. + err = context.Cause(ctx) + } + return resp, err } } @@ -91,8 +101,9 @@ func OptHeaderTimeout(timeout time.Duration) TripFunc { select { case <-ctx.Done(): case <-t.C: - cancelFn(errz.Errorf("http response not received by %s timeout", - timeout)) + cancelErr := errz.Errorf("http response not received within %s timeout", + timeout) + cancelFn(cancelErr) case <-timerCancelCh: // Stop the timer goroutine. 
} @@ -100,12 +111,15 @@ func OptHeaderTimeout(timeout time.Duration) TripFunc { resp, err := next.RoundTrip(req.WithContext(ctx)) close(timerCancelCh) - + if err != nil && errors.Is(err, ctx.Err()) { + // The lower-down RoundTripper probably returned ctx.Err(), + // not context.Cause(), so we swap it around here. + err = context.Cause(ctx) + } // Don't leak resources; ensure that cancelFn is eventually called. switch { case err != nil: - - // It's possible that cancelFn has already been called by the + // It's probable that cancelFn has already been called by the // timer goroutine, but we call it again just in case. cancelFn(err) case resp != nil && resp.Body != nil: diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 4f38a2c84..43b3a35b2 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -13,13 +13,12 @@ import ( "sync" "time" - "github.com/neilotoole/sq/libsq/core/ioz/contextio" - "github.com/a8m/tree" "github.com/a8m/tree/ostree" yaml "github.com/goccy/go-yaml" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/lg" ) diff --git a/libsq/core/progress/bars.go b/libsq/core/progress/bars.go index 313bb3c17..e53dee90e 100644 --- a/libsq/core/progress/bars.go +++ b/libsq/core/progress/bars.go @@ -7,12 +7,6 @@ import ( "github.com/dustin/go-humanize/english" mpb "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" - // NewByteCounter returns a new progress bar whose metric is the count - // of bytes processed. If the size is unknown, set arg size to -1. The caller - // is ultimately responsible for calling [Bar.Stop] on the returned Bar. - // However, the returned Bar is also added to the Progress's cleanup list, - // so it will be called automatically when the Progress is shut down, but that - // may be later than the actual conclusion of the Bar's work. 
) func (p *Progress) NewByteCounter(msg string, size int64) *Bar { diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 2d389a44c..f0d8f602f 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -8,8 +8,6 @@ import ( "sync" "time" - "github.com/neilotoole/sq/libsq/core/progress" - "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" @@ -18,6 +16,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" ) diff --git a/libsq/source/download.go b/libsq/source/download.go index faf8955da..c99e6cfa0 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -13,14 +13,13 @@ import ( "sync" "time" - "github.com/neilotoole/sq/libsq/core/ioz/httpz" - "golang.org/x/exp/maps" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" diff --git a/libsq/source/download_test.go b/libsq/source/download_test.go index bbba5da49..ffd32335e 100644 --- a/libsq/source/download_test.go +++ b/libsq/source/download_test.go @@ -11,13 +11,12 @@ import ( "strconv" "testing" - "github.com/neilotoole/sq/libsq/core/ioz/checksum" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/neilotoole/slogt" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/testh/proj" 
"github.com/neilotoole/sq/testh/tu" diff --git a/testh/tu/tu.go b/testh/tu/tu.go index d171ee0b7..e82df0821 100644 --- a/testh/tu/tu.go +++ b/testh/tu/tu.go @@ -400,6 +400,17 @@ func ReadFileToString(t testing.TB, name string) string { return s } +// ReadToString reads all bytes from r and returns them as a string. +// If r is an io.Closer, it is closed. +func ReadToString(t testing.TB, r io.Reader) string { + b, err := io.ReadAll(r) + require.NoError(t, err) + if r, ok := r.(io.Closer); ok { + require.NoError(t, r.Close()) + } + return string(b) +} + // OpenFileCount is a debugging function that returns the count // of open file handles for the current process via shelling out // to lsof. This function is skipped on Windows. From 9f57a224af1a0c124c2912a9bd20bdb365e0c064 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 14 Dec 2023 08:01:20 -0700 Subject: [PATCH 117/195] httpz errors --- libsq/core/ioz/httpz/httpz_test.go | 3 +++ libsq/core/ioz/httpz/opts.go | 5 +++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/libsq/core/ioz/httpz/httpz_test.go b/libsq/core/ioz/httpz/httpz_test.go index 16347f9f6..bb6c24f8d 100644 --- a/libsq/core/ioz/httpz/httpz_test.go +++ b/libsq/core/ioz/httpz/httpz_test.go @@ -2,6 +2,7 @@ package httpz_test import ( "context" + "errors" "io" "net/http" "net/http/httptest" @@ -41,6 +42,7 @@ func TestOptRequestTimeout(t *testing.T) { require.Error(t, err) require.Nil(t, resp) require.Contains(t, err.Error(), "http request not completed within") + require.True(t, errors.Is(err, context.DeadlineExceeded)) } // TestOptHeaderTimeout_correct_error verifies that an HTTP request @@ -71,6 +73,7 @@ func TestOptHeaderTimeout_correct_error(t *testing.T) { require.Error(t, err) require.Nil(t, resp) require.Contains(t, err.Error(), "http response not received within") + require.True(t, errors.Is(err, context.DeadlineExceeded)) // Now let's try again, with a shorter server delay, so the // request should succeed. 
diff --git a/libsq/core/ioz/httpz/opts.go b/libsq/core/ioz/httpz/opts.go index 9877956a2..58e74c758 100644 --- a/libsq/core/ioz/httpz/opts.go +++ b/libsq/core/ioz/httpz/opts.go @@ -64,7 +64,7 @@ func OptRequestTimeout(timeout time.Duration) TripFunc { } return func(next http.RoundTripper, req *http.Request) (*http.Response, error) { ctx, cancelFn := context.WithTimeoutCause(req.Context(), timeout, - errz.Errorf("http request not completed within %s timeout", timeout)) + errz.Wrapf(context.DeadlineExceeded, "http request not completed within %s timeout", timeout)) defer cancelFn() resp, err := next.RoundTrip(req.WithContext(ctx)) if err == nil { @@ -101,7 +101,8 @@ func OptHeaderTimeout(timeout time.Duration) TripFunc { select { case <-ctx.Done(): case <-t.C: - cancelErr := errz.Errorf("http response not received within %s timeout", + cancelErr := errz.Wrapf(context.DeadlineExceeded, + "http response not received within %s timeout", timeout) cancelFn(cancelErr) case <-timerCancelCh: From 6f321982f87360def88615f582e6ac36a725577c Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 14 Dec 2023 11:49:11 -0700 Subject: [PATCH 118/195] refactoring slog/slogt --- cli/cli.go | 7 +- cli/cmd_x.go | 3 +- cli/error.go | 37 ++++++---- cli/output.go | 10 +-- libsq/core/ioz/download/cache.go | 102 +++++++++++++++++++++++++-- libsq/core/ioz/download/download.go | 21 +----- libsq/core/ioz/httpz/httpz.go | 11 ++- libsq/core/ioz/httpz/httpz_test.go | 6 +- libsq/core/ioz/httpz/opts.go | 66 ++++++++++++----- libsq/core/ioz/ioz.go | 17 +++-- libsq/core/lg/devlog/tint/handler.go | 20 +++--- libsq/core/lg/lgt/lgt.go | 19 +++++ libsq/core/progress/progress.go | 12 +++- testh/testh.go | 7 -- 14 files changed, 240 insertions(+), 98 deletions(-) create mode 100644 libsq/core/lg/lgt/lgt.go diff --git a/cli/cli.go b/cli/cli.go index 854339531..021dbd120 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -172,13 +172,12 @@ func ExecuteWith(ctx context.Context, ru *run.Run, args []string) error { // 
Execute rootCmd; cobra will find the appropriate // sub-command, and ultimately execute that command. err = rootCmd.ExecuteContext(ctx) - log.Warn("Closing run", lga.Err, err) lg.WarnIfCloseError(log, "Problem closing run", ru) if err != nil { - ctx2 := rootCmd.Context() // FIXME: delete - _ = ctx2 + //ctx2 := rootCmd.Context() // FIXME: delete + //_ = ctx2 - printError(ctx2, ru, err) + printError(ctx, ru, err) } return err diff --git a/cli/cmd_x.go b/cli/cmd_x.go index 944392f44..ec21cb0bb 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -162,7 +162,8 @@ func execXDownloadCmd(cmd *cobra.Command, args []string) error { return err } - dl, err := download.New(fakeSrc.Handle, httpz.NewDefaultClient(), u.String(), cacheDir) + c := httpz.NewClient(httpz.DefaultUserAgent, httpz.OptRequestTimeout(time.Second*5)) + dl, err := download.New(fakeSrc.Handle, c, u.String(), cacheDir) if err != nil { return err } diff --git a/cli/error.go b/cli/error.go index 504b09512..d42dd9183 100644 --- a/cli/error.go +++ b/cli/error.go @@ -26,10 +26,7 @@ import ( // ru or any of its fields are nil). func printError(ctx context.Context, ru *run.Run, err error) { log := lg.FromContext(ctx) - log.Warn("printError called", lga.Err, err) // FIXME: delete - // debug.PrintStack() - // stack := errz.Stack(err) - // fmt.Fprintln(ru.Out, "printError stack", "stack", stack) + if err == nil { log.Warn("printError called with nil error") return @@ -40,15 +37,6 @@ func printError(ctx context.Context, ru *run.Run, err error) { return } - switch { - // Friendlier messages for context errors. 
- default: - case errors.Is(err, context.Canceled): - err = errz.New("canceled") - case errors.Is(err, context.DeadlineExceeded): - err = errz.Wrap(err, "timeout") - } - var cmd *cobra.Command if ru != nil { cmd = ru.Cmd @@ -65,7 +53,7 @@ func printError(ctx context.Context, ru *run.Run, err error) { } else { log.Error("EXECUTION FAILED", lga.Err, err, lga.Cmd, cmdName) } - + err = humanizeContextErr(err) wrtrs := ru.Writers if wrtrs != nil && wrtrs.Error != nil { // If we have an errorWriter, we print to it @@ -77,6 +65,8 @@ func printError(ctx context.Context, ru *run.Run, err error) { // Else we don't have an errorWriter, so we fall through } + err = humanizeContextErr(err) + // If we get this far, something went badly wrong in bootstrap // (probably the config is corrupt). // At this point, we could just print err to os.Stderr and be done. @@ -165,3 +155,22 @@ func panicOn(err error) { panic(err) } } + +// humanizeContextErr returns a friendlier error message +// for context errors. +func humanizeContextErr(err error) error { + if err == nil { + return nil + } + + switch { + // Friendlier messages for context errors. + default: + case errors.Is(err, context.Canceled): + err = errz.New("canceled") + case errors.Is(err, context.DeadlineExceeded): + err = errz.New("timeout") + } + + return err +} diff --git a/cli/output.go b/cli/output.go index 51b9b2aaa..25ff7114f 100644 --- a/cli/output.go +++ b/cli/output.go @@ -474,14 +474,14 @@ func getPrinting(cmd *cobra.Command, clnup *cleanup.Cleanup, opts options.Option }) // On first write to stderr, we remove the progress widget. 
- errOut2 = ioz.NotifyOnceWriter(errOut2, func() { - lg.FromContext(ctx).Debug("Error stream is being written to; removing progress widget") - pb.Stop() - }) // FIXME: delete + //errOut2 = ioz.NotifyOnceWriter(errOut2, func() { + // lg.FromContext(ctx).Debug("Error stream is being written to; removing progress widget") + // pb.Stop() + //}) // FIXME: delete cmd.SetContext(progress.NewContext(ctx, pb)) } - logFrom(cmd).Debug("Constructed output.Printing", lga.Val, pr) + lg.FromContext(cmd.Context()).Debug("Constructed output.Printing", lga.Val, pr) return pr, out2, errOut2 } diff --git a/libsq/core/ioz/download/cache.go b/libsq/core/ioz/download/cache.go index b48b5c7da..743a13d10 100644 --- a/libsq/core/ioz/download/cache.go +++ b/libsq/core/ioz/download/cache.go @@ -55,17 +55,72 @@ func (c *cache) paths(req *http.Request) (header, body, checksum string) { filepath.Join(c.dir, req.Method+"_checksum.txt") } -// exists returns true if the cache contains a response for req. +// exists returns true if the cache exists and is consistent. +// If it's inconsistent, it will be automatically cleared. +// See also: clearIfInconsistent. func (c *cache) exists(req *http.Request) bool { c.mu.Lock() defer c.mu.Unlock() + if err := c.clearIfInconsistent(req); err != nil { + lg.FromContext(req.Context()).Error("Failed to clear inconsistent cache", + lga.Err, err, lga.Dir, c.dir) + return false + } + fpHeader, _, _ := c.paths(req) fi, err := os.Stat(fpHeader) if err != nil { return false } - return fi.Size() > 0 + + if fi.Size() == 0 { + return false + } + + _, ok := c.checksumsMatch(req) + return ok +} + +// clearIfInconsistent deletes the cache if it is inconsistent. +func (c *cache) clearIfInconsistent(req *http.Request) error { + if !ioz.DirExists(c.dir) { + return nil + } + + entries, err := ioz.ReadDir(c.dir, false, false, false) + if err != nil { + return err + } + + if len(entries) == 0 { + // If it's an empty cache, that's consistent. 
+ return nil + } + + // We know that there's at least one file in the cache. + // To be consistent, all three cache files must exist. + inconsistent := false + fpHeader, fpBody, fpChecksum := c.paths(req) + for _, fp := range []string{fpHeader, fpBody, fpChecksum} { + if !ioz.FileAccessible(fp) { + inconsistent = true + break + } + } + + if !inconsistent { + // All three cache files exist. Verify that checksums match. + if _, ok := c.checksumsMatch(req); !ok { + inconsistent = true + } + } + + if inconsistent { + lg.FromContext(req.Context()).Warn("Deleting inconsistent cache", lga.Dir, c.dir) + return c.doClear(req.Context()) + } + return nil } // Get returns the cached http.Response for req if present, and nil @@ -81,6 +136,13 @@ func (c *cache) get(ctx context.Context, req *http.Request) (*http.Response, err return nil, nil } + if _, ok := c.checksumsMatch(req); !ok { + // If the checksums don't match, it's a nil, nil situation. + + // REVISIT: should we clear the cache here? + return nil, nil + } + headerBytes, err := os.ReadFile(fpHeader) if err != nil { return nil, errz.Wrap(err, "failed to read cached response header file") @@ -115,8 +177,8 @@ func (c *cache) get(ctx context.Context, req *http.Request) (*http.Response, err return resp, nil } -// checksum returns the checksum of the cached body file, if available. -func (c *cache) checksum(req *http.Request) (sum checksum.Checksum, ok bool) { +// checksum returns the contents of the cached checksum file, if available. +func (c *cache) cachedChecksum(req *http.Request) (sum checksum.Checksum, ok bool) { if c == nil || req == nil { return "", false } @@ -142,6 +204,29 @@ func (c *cache) checksum(req *http.Request) (sum checksum.Checksum, ok bool) { return sum, ok } +// checksumsMatch returns true (and the valid checksum) if there is a cached +// checksum file for req, and there is a cached response body file, and a fresh +// checksum calculated from that body file matches the cached checksum. 
+func (c *cache) checksumsMatch(req *http.Request) (sum checksum.Checksum, ok bool) { + sum, ok = c.cachedChecksum(req) + if !ok { + return "", false + } + + _, fpBody, _ := c.paths(req) + calculatedSum, err := checksum.ForFile(fpBody) + if err != nil { + return "", false + } + + if calculatedSum != sum { + lg.FromContext(req.Context()).Warn("Inconsistent cache: checksums don't match", lga.Dir, c.dir) + return "", false + } + + return sum, true +} + // clear deletes the cache entries from disk. func (c *cache) clear(ctx context.Context) error { if c == nil { @@ -192,12 +277,14 @@ func (c *cache) doWrite(ctx context.Context, resp *http.Response, defer func() { lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) - if err == nil { - return - } if err != nil && copyWrtr != nil { copyWrtr.Error(err) } + + if err != nil { + log.Warn("Deleting cache because cache write failed", lga.Err, err, lga.Dir, c.dir) + lg.WarnIfError(log, msgDeleteCache, c.doClear(ctx)) + } }() if err = ioz.RequireDir(c.dir); err != nil { @@ -238,6 +325,7 @@ func (c *cache) doWrite(ctx context.Context, resp *http.Response, err = errz.Err(err) log.Error("Cache write: io.Copy failed", lga.Err, err) lg.WarnIfCloseError(log, msgCloseCacheBodyFile, cacheFile) + cacheFile = nil return err } diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index de9273c8e..f0006e8fb 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -18,7 +18,6 @@ import ( "path/filepath" "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/ioz/httpz" @@ -208,6 +207,7 @@ func (dl *Download) get(req *http.Request, h Handler) { for _, header := range endToEndHeaders { cachedResp.Header[header] = resp.Header[header] } + lg.WarnIfCloseError(log, 
lgm.CloseHTTPResponseBody, resp.Body) resp = cachedResp } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) && req.Method == http.MethodGet && canStaleOnError(cachedResp.Header, req.Header) { @@ -371,24 +371,7 @@ func (dl *Download) Checksum(ctx context.Context) (sum checksum.Checksum, ok boo } req := dl.mustRequest(ctx) - - _, _, fp := dl.cache.paths(req) - if !ioz.FileAccessible(fp) { - return "", false - } - - sums, err := checksum.ReadFile(fp) - if err != nil { - lg.FromContext(ctx).Warn("Failed to read checksum file", lga.File, fp, lga.Err, err) - return "", false - } - - if len(sums) != 1 { - return "", false - } - - sum, ok = sums["body"] - return sum, ok + return dl.cache.cachedChecksum(req) } func (dl *Download) isCacheable(req *http.Request) bool { diff --git a/libsq/core/ioz/httpz/httpz.go b/libsq/core/ioz/httpz/httpz.go index 2a90c8d95..149669b38 100644 --- a/libsq/core/ioz/httpz/httpz.go +++ b/libsq/core/ioz/httpz/httpz.go @@ -5,11 +5,14 @@ // Design note: this package contains generally fairly straightforward HTTP // functionality, but the Opt / TripFunc config mechanism is a bit // experimental. And probably tries to be a bit too clever. It may change. +// +// And one last thing: remember kids, ALWAYS close your response bodies. package httpz import ( "bufio" "fmt" + "github.com/neilotoole/sq/libsq/core/stringz" "io" "log/slog" "mime" @@ -19,18 +22,14 @@ import ( "path/filepath" "strconv" "strings" - "time" - - "github.com/neilotoole/sq/cli/buildinfo" - "github.com/neilotoole/sq/libsq/core/stringz" ) // NewDefaultClient invokes NewClient with default settings. 
func NewDefaultClient() *http.Client { return NewClient( OptInsecureSkipVerify(false), - OptUserAgent(buildinfo.Get().UserAgent()), - OptHeaderTimeout(time.Second*5), + DefaultUserAgent, + DefaultHeaderTimeout, ) } diff --git a/libsq/core/ioz/httpz/httpz_test.go b/libsq/core/ioz/httpz/httpz_test.go index bb6c24f8d..a42e4ebd0 100644 --- a/libsq/core/ioz/httpz/httpz_test.go +++ b/libsq/core/ioz/httpz/httpz_test.go @@ -3,6 +3,8 @@ package httpz_test import ( "context" "errors" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "io" "net/http" "net/http/httptest" @@ -49,6 +51,8 @@ func TestOptRequestTimeout(t *testing.T) { // that fails via OptHeaderTimeout returns the correct error. func TestOptHeaderTimeout_correct_error(t *testing.T) { t.Parallel() + ctx := lg.NewContext(context.Background(), lgt.New(t)) + const srvrBody = `Hello World!` serverDelay := time.Second * 2 srvr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -65,7 +69,7 @@ func TestOptHeaderTimeout_correct_error(t *testing.T) { clientHeaderTimeout := time.Second * 1 c := httpz.NewClient(httpz.OptHeaderTimeout(clientHeaderTimeout)) - req, err := http.NewRequest(http.MethodGet, srvr.URL, nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, srvr.URL, nil) require.NoError(t, err) resp, err := c.Do(req) diff --git a/libsq/core/ioz/httpz/opts.go b/libsq/core/ioz/httpz/opts.go index 58e74c758..a9b46ad5a 100644 --- a/libsq/core/ioz/httpz/opts.go +++ b/libsq/core/ioz/httpz/opts.go @@ -4,10 +4,12 @@ import ( "context" "crypto/tls" "errors" + "github.com/neilotoole/sq/cli/buildinfo" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" "net/http" "time" - "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" ) @@ -43,7 +45,8 @@ func (v minTLSVersion) apply(tr *http.Transport) { } } -// DefaultTLSVersion is the default minimum TLS 
version used by [NewClient]. +// DefaultTLSVersion is the default minimum TLS version, +// as used by [NewDefaultClient]. var DefaultTLSVersion = minTLSVersion(tls.VersionTLS10) // OptUserAgent is passed to [NewClient] to set the User-Agent header. @@ -54,6 +57,10 @@ func OptUserAgent(ua string) TripFunc { } } +// DefaultUserAgent is the default User-Agent header value, +// as used by [NewDefaultClient]. +var DefaultUserAgent = OptUserAgent(buildinfo.Get().UserAgent()) + // OptRequestTimeout is passed to [NewClient] to set the total request timeout. // If timeout is zero, this is a no-op. // @@ -63,19 +70,37 @@ func OptRequestTimeout(timeout time.Duration) TripFunc { return NopTripFunc } return func(next http.RoundTripper, req *http.Request) (*http.Response, error) { - ctx, cancelFn := context.WithTimeoutCause(req.Context(), timeout, - errz.Wrapf(context.DeadlineExceeded, "http request not completed within %s timeout", timeout)) - defer cancelFn() + timeoutErr := errors.New("http request timeout") + ctx, cancelFn := context.WithTimeoutCause(req.Context(), timeout, timeoutErr) + resp, err := next.RoundTrip(req.WithContext(ctx)) if err == nil { + if resp.Body == nil { + // Shouldn't happen, but just in case. + cancelFn() + } else { + // Wrap resp.Body with a ReadCloserNotifier, so that cancelFn + // is called when the body is closed. + resp.Body = ioz.ReadCloserNotifier(resp.Body, func(err error) { + if errors.Is(context.Cause(ctx), timeoutErr) { + lg.FromContext(ctx).Warn("HTTP request not completed within timeout", + lga.Timeout, timeout, lga.URL, req.URL.String()) + } + + cancelFn() + }) + } return resp, nil } - if errors.Is(err, ctx.Err()) { - // The lower-down RoundTripper probably returned ctx.Err(), - // not context.Cause(), so we swap it around here. 
- err = context.Cause(ctx) + // We've got an error + defer cancelFn() + + if errors.Is(context.Cause(ctx), timeoutErr) { + lg.FromContext(ctx).Warn("HTTP request not completed within timeout XYZ", // FIXME: delete + lga.Timeout, timeout, lga.URL, req.URL.String()) } + return resp, err } } @@ -101,10 +126,11 @@ func OptHeaderTimeout(timeout time.Duration) TripFunc { select { case <-ctx.Done(): case <-t.C: - cancelErr := errz.Wrapf(context.DeadlineExceeded, - "http response not received within %s timeout", - timeout) - cancelFn(cancelErr) + log := lg.FromContext(ctx) + _ = log + lg.FromContext(ctx).Warn("HTTP header response not received within timeout", + lga.Timeout, timeout, lga.URL, req.URL.String()) + cancelFn(context.DeadlineExceeded) case <-timerCancelCh: // Stop the timer goroutine. } @@ -115,24 +141,30 @@ func OptHeaderTimeout(timeout time.Duration) TripFunc { if err != nil && errors.Is(err, ctx.Err()) { // The lower-down RoundTripper probably returned ctx.Err(), // not context.Cause(), so we swap it around here. - err = context.Cause(ctx) + if cause := context.Cause(ctx); cause != nil { + err = cause + } } // Don't leak resources; ensure that cancelFn is eventually called. switch { case err != nil: // It's probable that cancelFn has already been called by the // timer goroutine, but we call it again just in case. - cancelFn(err) + cancelFn(context.DeadlineExceeded) case resp != nil && resp.Body != nil: // Wrap resp.Body with a ReadCloserNotifier, so that cancelFn // is called when the body is closed. - resp.Body = ioz.ReadCloserNotifier(resp.Body, cancelFn) + resp.Body = ioz.ReadCloserNotifier(resp.Body, func(error) { cancelFn(context.DeadlineExceeded) }) default: // Not sure if this can actually happen, but just in case. - cancelFn(context.Canceled) + cancelFn(context.DeadlineExceeded) } return resp, err } } + +// DefaultHeaderTimeout is the default header timeout as used +// by [NewDefaultClient]. 
+var DefaultHeaderTimeout = OptHeaderTimeout(time.Second * 5) diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 43b3a35b2..9cedb1128 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -390,7 +390,9 @@ func PrintTree(w io.Writer, loc string, showSize, colorize bool) error { // ReadCloserNotifier returns a new io.ReadCloser that invokes fn // after Close is called, passing along any error from Close. -// If rc or fn is nil, rc is returned. +// If rc or fn is nil, rc is returned. Note that any subsequent +// calls to Close are no-op, and return the same error (if any) +// as the first invocation of Close. func ReadCloserNotifier(rc io.ReadCloser, fn func(closeErr error)) io.ReadCloser { if rc == nil || fn == nil { return rc @@ -399,14 +401,19 @@ func ReadCloserNotifier(rc io.ReadCloser, fn func(closeErr error)) io.ReadCloser } type readCloserNotifier struct { - fn func(error) + once sync.Once + closeErr error + fn func(error) io.ReadCloser } func (c *readCloserNotifier) Close() error { - err := c.ReadCloser.Close() - c.fn(err) - return err + c.once.Do(func() { + c.closeErr = c.ReadCloser.Close() + c.fn(c.closeErr) + //c.closeErr = errz.New("huzzah") // FIXME: delete + }) + return c.closeErr } // WriteToFile writes the contents of r to fp. 
If fp doesn't exist, diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go index e8a79070e..2f868054f 100644 --- a/libsq/core/lg/devlog/tint/handler.go +++ b/libsq/core/lg/devlog/tint/handler.go @@ -70,14 +70,16 @@ import ( // ANSI modes // See: https://gist.github.com/JBlond/2fea43a3049b38287e5e9cefc87b2124 const ( - ansiReset = "\033[0m" - ansiFaint = "\033[2m" - ansiResetFaint = "\033[22m" - ansiBrightRed = "\033[91m" - ansiBrightGreen = "\033[92m" - ansiBrightYellow = "\033[93m" - ansiBlue = "\033[34m" - ansiBrightRedFaint = "\033[91;2m" + ansiReset = "\033[0m" + ansiFaint = "\033[2m" + ansiResetFaint = "\033[22m" + ansiBrightRed = "\033[91m" + ansiBrightGreen = "\033[92m" + ansiBrightGreenBold = "\033[1;92m" + ansiBrightYellow = "\033[93m" + ansiBlue = "\033[34m" + ansiBrightBlue = "\033[94m" + ansiBrightRedFaint = "\033[91;2m" ) const errKey = "err" @@ -229,7 +231,7 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { case slog.LevelError: msgColor = ansiBrightRed case slog.LevelInfo: - msgColor = ansiBlue + msgColor = ansiBrightGreenBold } // write message if rep == nil { diff --git a/libsq/core/lg/lgt/lgt.go b/libsq/core/lg/lgt/lgt.go new file mode 100644 index 000000000..a8772ec10 --- /dev/null +++ b/libsq/core/lg/lgt/lgt.go @@ -0,0 +1,19 @@ +// Package lgt provides a mechanism for getting a *slog.Logger +// that outputs to testing.T. +package lgt + +import ( + "github.com/neilotoole/slogt" + "github.com/neilotoole/sq/libsq/core/lg/devlog" + "io" + "log/slog" +) + +func init() { //nolint:gochecknoinits + slogt.Default = slogt.Factory(func(w io.Writer) slog.Handler { + return devlog.NewHandler(w, slog.LevelDebug) + }) +} + +// New delegates to slogt.New. 
+var New = slogt.New diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 329d5007f..17dbb85ca 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -211,8 +211,8 @@ func (p *Progress) Stop() { func (p *Progress) doStop() { p.stopOnce.Do(func() { p.pcInitFn = nil - lg.FromContext(p.ctx).Warn("Stopping progress widget") - defer lg.FromContext(p.ctx).Warn("Stopped progress widget") + lg.FromContext(p.ctx).Debug("Stopping progress widget") + defer lg.FromContext(p.ctx).Debug("Stopped progress widget") if p.pc == nil { close(p.stoppedCh) close(p.refreshCh) @@ -299,7 +299,11 @@ func (p *Progress) newBar(msg string, total int64, barStoppedCh: make(chan struct{}), } b.barInitFn = func() { + p.mu.Lock() // FIXME: not too sure about locking here? + defer p.mu.Unlock() + select { + case <-p.ctx.Done(): case <-p.stoppedCh: return case <-b.barStoppedCh: @@ -307,6 +311,9 @@ func (p *Progress) newBar(msg string, total int64, default: } + // REVISIT: It shouldn't be the case that it's possible that the + // progress has already been stopped. If it is stopped, the call + // below will panic. Maybe consider wrapping the call in a recover? b.bar = p.pc.New(total, style, mpb.BarWidth(barWidth), @@ -318,7 +325,6 @@ func (p *Progress) newBar(msg string, total int64, ) b.bar.IncrBy(int(b.incrStash.Load())) b.incrStash = nil - // b.incrStash.Store(0) } b.delayCh = barRenderDelay(b, p.delay) diff --git a/testh/testh.go b/testh/testh.go index c342547a2..23240142b 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -41,7 +41,6 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/devlog" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" @@ -64,12 +63,6 @@ import ( // or not. 
const defaultDBOpenTimeout = time.Second * 5 -func init() { //nolint:gochecknoinits - slogt.Default = slogt.Factory(func(w io.Writer) slog.Handler { - return devlog.NewHandler(w, slog.LevelDebug) - }) -} - // Option is a functional option type used with New to // configure the helper. type Option func(h *Helper) From 9666a269707a7bbd7b41b5b27e228e6bc649a800 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 14 Dec 2023 11:53:15 -0700 Subject: [PATCH 119/195] refactoring slog/slogt --- cli/complete_location_test.go | 5 ++--- cli/complete_test.go | 5 ++--- .../yamlstore/upgrades/v0.34.0/upgrade_test.go | 5 ++--- cli/hostinfo/hostinfo_test.go | 5 ++--- cli/options_test.go | 5 ++--- cli/output/jsonw/jsonw_test.go | 5 ++--- cli/testrun/testrun.go | 5 ++--- drivers/json/json_test.go | 5 ++--- drivers/mysql/metadata_test.go | 5 ++--- drivers/sqlite3/metadata_test.go | 5 ++--- libsq/ast/node_test.go | 8 ++++---- libsq/ast/parser_test.go | 13 ++++++------- libsq/ast/selector_test.go | 5 ++--- libsq/ast/walker_test.go | 4 ++-- libsq/core/errz/errz_test.go | 5 ++--- libsq/core/ioz/download/download_test.go | 7 +++---- libsq/core/ioz/httpz/httpz.go | 3 ++- libsq/core/ioz/httpz/httpz_test.go | 4 ++-- libsq/core/ioz/httpz/opts.go | 6 +++--- libsq/core/ioz/ioz.go | 2 +- libsq/core/ioz/lockfile/lockfile_test.go | 5 ++--- libsq/core/lg/lg_test.go | 5 ++--- libsq/core/lg/lgt/lgt.go | 10 ++++++---- libsq/core/options/options_test.go | 7 +++---- libsq/source/download_test.go | 5 ++--- libsq/source/files_test.go | 11 +++++------ libsq/source/internal_test.go | 5 ++--- libsq/source/source_test.go | 5 ++--- testh/testh.go | 5 ++--- 29 files changed, 73 insertions(+), 92 deletions(-) diff --git a/cli/complete_location_test.go b/cli/complete_location_test.go index 4c510e571..304331b8f 100644 --- a/cli/complete_location_test.go +++ b/cli/complete_location_test.go @@ -11,14 +11,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - 
"github.com/neilotoole/slogt" - "github.com/neilotoole/sq/cli" "github.com/neilotoole/sq/cli/testrun" "github.com/neilotoole/sq/drivers/postgres" "github.com/neilotoole/sq/drivers/sqlite3" "github.com/neilotoole/sq/drivers/sqlserver" "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/testh" @@ -1337,7 +1336,7 @@ func TestDoCompleteAddLocationFile(t *testing.T) { for i, tc := range testCases { tc := tc t.Run(tu.Name(i, tc.in), func(t *testing.T) { - ctx := lg.NewContext(context.Background(), slogt.New(t)) + ctx := lg.NewContext(context.Background(), lgt.New(t)) t.Logf("input: %s", tc.in) t.Logf("want: %s", tc.want) got := cli.DoCompleteAddLocationFile(ctx, tc.in) diff --git a/cli/complete_test.go b/cli/complete_test.go index ff0090303..0217ba8ef 100644 --- a/cli/complete_test.go +++ b/cli/complete_test.go @@ -10,12 +10,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/cli/cobraz" "github.com/neilotoole/sq/cli/flag" "github.com/neilotoole/sq/cli/testrun" "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/sakila" "github.com/neilotoole/sq/testh/tu" @@ -23,7 +22,7 @@ import ( // testComplete is a helper for testing cobra completion. func testComplete(t testing.TB, from *testrun.TestRun, args ...string) completion { - ctx := lg.NewContext(context.Background(), slogt.New(t)) + ctx := lg.NewContext(context.Background(), lgt.New(t)) tr := testrun.New(ctx, t, from) args = append([]string{"__complete"}, args...) 
diff --git a/cli/config/yamlstore/upgrades/v0.34.0/upgrade_test.go b/cli/config/yamlstore/upgrades/v0.34.0/upgrade_test.go index d4c168b56..e90ebee64 100644 --- a/cli/config/yamlstore/upgrades/v0.34.0/upgrade_test.go +++ b/cli/config/yamlstore/upgrades/v0.34.0/upgrade_test.go @@ -10,8 +10,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/cli" "github.com/neilotoole/sq/cli/config" "github.com/neilotoole/sq/cli/config/yamlstore" @@ -21,6 +19,7 @@ import ( "github.com/neilotoole/sq/drivers/postgres" "github.com/neilotoole/sq/drivers/xlsx" "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/testh" @@ -28,7 +27,7 @@ import ( ) func TestUpgrade(t *testing.T) { - log := slogt.New(t) + log := lgt.New(t) ctx := lg.NewContext(context.Background(), log) const ( diff --git a/cli/hostinfo/hostinfo_test.go b/cli/hostinfo/hostinfo_test.go index 78404992e..6cfced5a9 100644 --- a/cli/hostinfo/hostinfo_test.go +++ b/cli/hostinfo/hostinfo_test.go @@ -3,15 +3,14 @@ package hostinfo_test import ( "testing" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/cli/hostinfo" + "github.com/neilotoole/sq/libsq/core/lg/lgt" ) func TestGet(t *testing.T) { info := hostinfo.Get() - log := slogt.New(t) + log := lgt.New(t) log.Debug("Via slog", "sys", info) t.Logf("Via string: %s", info.String()) diff --git a/cli/options_test.go b/cli/options_test.go index f27421ffd..31cc8fcd7 100644 --- a/cli/options_test.go +++ b/cli/options_test.go @@ -5,14 +5,13 @@ import ( "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/cli" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/core/options" ) func TestRegisterDefaultOpts(t *testing.T) { - log := slogt.New(t) + 
log := lgt.New(t) reg := &options.Registry{} log.Debug("options.Registry (before)", "reg", reg) diff --git a/cli/output/jsonw/jsonw_test.go b/cli/output/jsonw/jsonw_test.go index 79ef9f12f..d93364365 100644 --- a/cli/output/jsonw/jsonw_test.go +++ b/cli/output/jsonw/jsonw_test.go @@ -12,11 +12,10 @@ import ( "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/cli/output" "github.com/neilotoole/sq/cli/output/jsonw" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/core/record" "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/fixt" @@ -215,7 +214,7 @@ func TestErrorWriter(t *testing.T) { pr.Compact = !tc.pretty pr.EnableColor(tc.color) - errw := jsonw.NewErrorWriter(slogt.New(t), buf, pr) + errw := jsonw.NewErrorWriter(lgt.New(t), buf, pr) errw.Error(errz.New("err1")) got := buf.String() diff --git a/cli/testrun/testrun.go b/cli/testrun/testrun.go index 2e8836a74..12dd495b7 100644 --- a/cli/testrun/testrun.go +++ b/cli/testrun/testrun.go @@ -14,14 +14,13 @@ import ( "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/cli" "github.com/neilotoole/sq/cli/config" "github.com/neilotoole/sq/cli/config/yamlstore" "github.com/neilotoole/sq/cli/run" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/source" ) @@ -55,7 +54,7 @@ func New(ctx context.Context, t testing.TB, from *TestRun) *TestRun { } if !lg.InContext(ctx) { - ctx = lg.NewContext(ctx, slogt.New(t)) + ctx = lg.NewContext(ctx, lgt.New(t)) } tr := &TestRun{T: t, Context: ctx, mu: &sync.Mutex{}} diff --git a/drivers/json/json_test.go b/drivers/json/json_test.go index da4a74eb6..e6aad1721 100644 --- a/drivers/json/json_test.go 
+++ b/drivers/json/json_test.go @@ -9,10 +9,9 @@ import ( "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/drivers/json" "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" "github.com/neilotoole/sq/testh/tu" @@ -96,7 +95,7 @@ func TestDriverDetectorFuncs(t *testing.T) { openFn := func(ctx context.Context) (io.ReadCloser, error) { return os.Open(filepath.Join("testdata", tc.fname)) } detectFn := detectFns[tc.fn] - ctx := lg.NewContext(context.Background(), slogt.New(t)) + ctx := lg.NewContext(context.Background(), lgt.New(t)) gotType, gotScore, gotErr := detectFn(ctx, openFn) if tc.wantErr { diff --git a/drivers/mysql/metadata_test.go b/drivers/mysql/metadata_test.go index c5bff62ab..e621b3f52 100644 --- a/drivers/mysql/metadata_test.go +++ b/drivers/mysql/metadata_test.go @@ -6,11 +6,10 @@ import ( "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/drivers/mysql" "github.com/neilotoole/sq/libsq/core/kind" "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/sakila" ) @@ -18,7 +17,7 @@ import ( func TestKindFromDBTypeName(t *testing.T) { t.Parallel() - ctx := lg.NewContext(context.Background(), slogt.New(t)) + ctx := lg.NewContext(context.Background(), lgt.New(t)) testCases := map[string]kind.Kind{ "": kind.Unknown, diff --git a/drivers/sqlite3/metadata_test.go b/drivers/sqlite3/metadata_test.go index 7fd73364a..209bd121b 100644 --- a/drivers/sqlite3/metadata_test.go +++ b/drivers/sqlite3/metadata_test.go @@ -10,12 +10,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/drivers/sqlite3" 
"github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/kind" "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/core/sqlz" "github.com/neilotoole/sq/libsq/core/tablefq" "github.com/neilotoole/sq/libsq/source/metadata" @@ -94,7 +93,7 @@ func TestCurrentTime(t *testing.T) { func TestKindFromDBTypeName(t *testing.T) { t.Parallel() - ctx := lg.NewContext(context.Background(), slogt.New(t)) + ctx := lg.NewContext(context.Background(), lgt.New(t)) testCases := map[string]kind.Kind{ "": kind.Bytes, diff --git a/libsq/ast/node_test.go b/libsq/ast/node_test.go index 38ffa08a4..1374047de 100644 --- a/libsq/ast/node_test.go +++ b/libsq/ast/node_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" + "github.com/neilotoole/sq/libsq/core/lg/lgt" ) func TestChildIndex(t *testing.T) { @@ -13,7 +13,7 @@ func TestChildIndex(t *testing.T) { p := getSLQParser(q1) query := p.Query() - ast, err := buildAST(slogt.New(t), query) + ast, err := buildAST(lgt.New(t), query) require.Nil(t, err) require.NotNil(t, ast) require.Equal(t, 4, len(ast.Segments())) @@ -35,7 +35,7 @@ func TestNodesWithType(t *testing.T) { func TestNodePrevNextSibling(t *testing.T) { const in = `@sakila | .actor | .actor_id == 2` - log := slogt.New(t) + log := lgt.New(t) a, err := Parse(log, in) require.NoError(t, err) @@ -83,7 +83,7 @@ func TestNodeUnwrap(t *testing.T) { func TestFindNodes(t *testing.T) { const in = `@sakila | .actor | .actor_id == 2 | .actor_id, .first_name, .last_name` - a, err := Parse(slogt.New(t), in) + a, err := Parse(lgt.New(t), in) require.NoError(t, err) handles := FindNodes[*HandleNode](a) diff --git a/libsq/ast/parser_test.go b/libsq/ast/parser_test.go index 86f3af84e..b59b60f69 100644 --- a/libsq/ast/parser_test.go +++ b/libsq/ast/parser_test.go @@ -6,9 +6,8 @@ import ( antlr "github.com/antlr4-go/antlr/v4" 
"github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/libsq/ast/internal/slq" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/testh/tu" ) @@ -24,7 +23,7 @@ func getSLQParser(input string) *slq.SLQParser { // buildInitialAST returns a new AST created by parseTreeVisitor. The AST has not // yet been processed. func buildInitialAST(t *testing.T, input string) (*AST, error) { - log := slogt.New(t) + log := lgt.New(t) p := getSLQParser(input) q, _ := p.Query().(*slq.QueryContext) @@ -39,7 +38,7 @@ func buildInitialAST(t *testing.T, input string) (*AST, error) { // mustParse builds a full AST from the input SLQ, or fails on any error. func mustParse(t *testing.T, input string) *AST { - log := slogt.New(t) + log := lgt.New(t) ast, err := Parse(log, input) require.NoError(t, err) @@ -48,7 +47,7 @@ func mustParse(t *testing.T, input string) *AST { func TestSimpleQuery(t *testing.T) { const q1 = `@mydb1 | .user | .uid, .username` - log := slogt.New(t) + log := lgt.New(t) ptree, err := parseSLQ(log, q1) require.Nil(t, err) @@ -83,7 +82,7 @@ func TestParseBuild(t *testing.T) { for i, tc := range testCases { t.Run(tu.Name(i, tc.name), func(t *testing.T) { t.Logf(tc.in) - log := slogt.New(t) + log := lgt.New(t) ptree, err := parseSLQ(log, tc.in) require.Nil(t, err) @@ -97,7 +96,7 @@ func TestParseBuild(t *testing.T) { } func TestInspector_FindWhereClauses(t *testing.T) { - log := slogt.New(t) + log := lgt.New(t) // Verify that "where(.uid > 4)" becomes a WHERE clause. 
const input = "@my1 | .actor | where(.uid > 4) | .uid, .username" diff --git a/libsq/ast/selector_test.go b/libsq/ast/selector_test.go index a98e5fc47..62c107e6e 100644 --- a/libsq/ast/selector_test.go +++ b/libsq/ast/selector_test.go @@ -5,8 +5,7 @@ import ( "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/testh/tu" ) @@ -31,7 +30,7 @@ func TestColumnAlias(t *testing.T) { t.Run(tu.Name(tc.in), func(t *testing.T) { t.Parallel() - log := slogt.New(t) + log := lgt.New(t) ast, err := Parse(log, tc.in) if tc.wantErr { diff --git a/libsq/ast/walker_test.go b/libsq/ast/walker_test.go index 0c754f4fc..11368f031 100644 --- a/libsq/ast/walker_test.go +++ b/libsq/ast/walker_test.go @@ -5,7 +5,7 @@ import ( "github.com/stretchr/testify/assert" - "github.com/neilotoole/slogt" + "github.com/neilotoole/sq/libsq/core/lg/lgt" ) func TestWalker(t *testing.T) { @@ -13,7 +13,7 @@ func TestWalker(t *testing.T) { p := getSLQParser(q1) query := p.Query() - ast, err := buildAST(slogt.New(t), query) + ast, err := buildAST(lgt.New(t), query) assert.Nil(t, err) assert.NotNil(t, ast) diff --git a/libsq/core/errz/errz_test.go b/libsq/core/errz/errz_test.go index 838b106fb..130294add 100644 --- a/libsq/core/errz/errz_test.go +++ b/libsq/core/errz/errz_test.go @@ -11,10 +11,9 @@ import ( "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/core/stringz" ) @@ -49,7 +48,7 @@ func (e *CustomError) Error() string { } func TestLogError_LogValue(t *testing.T) { - log := slogt.New(t) + log := lgt.New(t) nakedErr := sql.ErrNoRows log.Debug("naked", lga.Err, nakedErr) diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index 
5dc4c6a06..26d7e866e 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -13,12 +13,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/libsq/core/ioz/download" "github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/testh/tu" ) @@ -65,7 +64,7 @@ func TestDownload_redirect(t *testing.T) { // FIXME: switch back to temp dir cacheDir := filepath.Join("testdata", "download", tu.Name(t.Name())) - log := slogt.New(t) + log := lgt.New(t) var srvr *httptest.Server srvr = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log := log.With("origin", "server") @@ -167,7 +166,7 @@ func TestDownload_redirect(t *testing.T) { //loc := srvr.URL + "/actual" func TestDownload_New(t *testing.T) { - log := slogt.New(t) + log := lgt.New(t) ctx := lg.NewContext(context.Background(), log) const dlURL = urlActorCSV diff --git a/libsq/core/ioz/httpz/httpz.go b/libsq/core/ioz/httpz/httpz.go index 149669b38..c8858eee5 100644 --- a/libsq/core/ioz/httpz/httpz.go +++ b/libsq/core/ioz/httpz/httpz.go @@ -12,7 +12,6 @@ package httpz import ( "bufio" "fmt" - "github.com/neilotoole/sq/libsq/core/stringz" "io" "log/slog" "mime" @@ -22,6 +21,8 @@ import ( "path/filepath" "strconv" "strings" + + "github.com/neilotoole/sq/libsq/core/stringz" ) // NewDefaultClient invokes NewClient with default settings. 
diff --git a/libsq/core/ioz/httpz/httpz_test.go b/libsq/core/ioz/httpz/httpz_test.go index a42e4ebd0..c1de8d8f7 100644 --- a/libsq/core/ioz/httpz/httpz_test.go +++ b/libsq/core/ioz/httpz/httpz_test.go @@ -3,8 +3,6 @@ package httpz_test import ( "context" "errors" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lgt" "io" "net/http" "net/http/httptest" @@ -15,6 +13,8 @@ import ( "github.com/stretchr/testify/require" "github.com/neilotoole/sq/libsq/core/ioz/httpz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/testh/tu" ) diff --git a/libsq/core/ioz/httpz/opts.go b/libsq/core/ioz/httpz/opts.go index a9b46ad5a..ba89ba22e 100644 --- a/libsq/core/ioz/httpz/opts.go +++ b/libsq/core/ioz/httpz/opts.go @@ -4,13 +4,13 @@ import ( "context" "crypto/tls" "errors" - "github.com/neilotoole/sq/cli/buildinfo" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lga" "net/http" "time" + "github.com/neilotoole/sq/cli/buildinfo" "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" ) // Opt is an option that can be passed to [NewClient] to diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 9cedb1128..48fff5bd1 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -411,7 +411,7 @@ func (c *readCloserNotifier) Close() error { c.once.Do(func() { c.closeErr = c.ReadCloser.Close() c.fn(c.closeErr) - //c.closeErr = errz.New("huzzah") // FIXME: delete + // c.closeErr = errz.New("huzzah") // FIXME: delete }) return c.closeErr } diff --git a/libsq/core/ioz/lockfile/lockfile_test.go b/libsq/core/ioz/lockfile/lockfile_test.go index 46225d2e6..1ac0cf065 100644 --- a/libsq/core/ioz/lockfile/lockfile_test.go +++ b/libsq/core/ioz/lockfile/lockfile_test.go @@ -8,15 +8,14 @@ import ( 
"github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/libsq/core/ioz/lockfile" "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgt" ) // FIXME: Duh, this can't work, because we're in the same pid. func TestLockfile(t *testing.T) { - ctx := lg.NewContext(context.Background(), slogt.New(t)) + ctx := lg.NewContext(context.Background(), lgt.New(t)) pidfile := filepath.Join(t.TempDir(), "lock.pid") lock, err := lockfile.New(pidfile) diff --git a/libsq/core/lg/lg_test.go b/libsq/core/lg/lg_test.go index e1468e8ab..03b4dd56b 100644 --- a/libsq/core/lg/lg_test.go +++ b/libsq/core/lg/lg_test.go @@ -4,14 +4,13 @@ import ( "context" "testing" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgt" ) func TestContext(t *testing.T) { ctx := context.Background() - log := slogt.New(t) + log := lgt.New(t) ctx = lg.NewContext(ctx, log) log = lg.FromContext(ctx) diff --git a/libsq/core/lg/lgt/lgt.go b/libsq/core/lg/lgt/lgt.go index a8772ec10..168292069 100644 --- a/libsq/core/lg/lgt/lgt.go +++ b/libsq/core/lg/lgt/lgt.go @@ -1,12 +1,14 @@ // Package lgt provides a mechanism for getting a *slog.Logger -// that outputs to testing.T. +// that outputs to testing.T. See [lgt.New]. package lgt import ( - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/libsq/core/lg/devlog" "io" "log/slog" + + "github.com/neilotoole/slogt" + + "github.com/neilotoole/sq/libsq/core/lg/devlog" ) func init() { //nolint:gochecknoinits @@ -15,5 +17,5 @@ func init() { //nolint:gochecknoinits }) } -// New delegates to slogt.New. +// New delegates to [slogt.New]. 
var New = slogt.New diff --git a/libsq/core/options/options_test.go b/libsq/core/options/options_test.go index 77fac82b5..2401d0fd4 100644 --- a/libsq/core/options/options_test.go +++ b/libsq/core/options/options_test.go @@ -8,12 +8,11 @@ import ( "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/cli" "github.com/neilotoole/sq/cli/output/format" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/testh/tu" @@ -24,7 +23,7 @@ type config struct { } func TestOptions(t *testing.T) { - log := slogt.New(t) + log := lgt.New(t) b, err := os.ReadFile("testdata/good.01.yml") require.NoError(t, err) @@ -151,7 +150,7 @@ func TestMerge(t *testing.T) { func TestOptions_LogValue(t *testing.T) { o1 := options.Options{"a": 1, "b": true, "c": "hello"} - log := slogt.New(t) + log := lgt.New(t) log.Debug("Logging options", lga.Opts, o1) } diff --git a/libsq/source/download_test.go b/libsq/source/download_test.go index ffd32335e..76cd49ff7 100644 --- a/libsq/source/download_test.go +++ b/libsq/source/download_test.go @@ -14,10 +14,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/tu" ) @@ -48,7 +47,7 @@ func TestFetchHTTPHeader_sqio(t *testing.T) { } func TestDownloader_Download(t *testing.T) { - ctx := lg.NewContext(context.Background(), slogt.New(t)) + ctx := lg.NewContext(context.Background(), lgt.New(t)) const dlURL = urlActorCSV const wantContentLength = sizeActorCSV u, err := url.Parse(dlURL) diff --git 
a/libsq/source/files_test.go b/libsq/source/files_test.go index f104581c0..e5b22c2b0 100644 --- a/libsq/source/files_test.go +++ b/libsq/source/files_test.go @@ -11,8 +11,6 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/drivers/csv" "github.com/neilotoole/sq/drivers/mysql" "github.com/neilotoole/sq/drivers/postgres" @@ -20,6 +18,7 @@ import ( "github.com/neilotoole/sq/drivers/sqlserver" "github.com/neilotoole/sq/drivers/xlsx" "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" @@ -55,7 +54,7 @@ func TestFiles_Type(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tu.Name(source.RedactLocation(tc.loc)), func(t *testing.T) { - ctx := lg.NewContext(context.Background(), slogt.New(t)) + ctx := lg.NewContext(context.Background(), lgt.New(t)) fs, err := source.NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) require.NoError(t, err) @@ -98,7 +97,7 @@ func TestFiles_DetectType(t *testing.T) { tc := tc t.Run(filepath.Base(tc.loc), func(t *testing.T) { - ctx := lg.NewContext(context.Background(), slogt.New(t)) + ctx := lg.NewContext(context.Background(), lgt.New(t)) fs, err := source.NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) require.NoError(t, err) fs.AddDriverDetectors(testh.DriverDetectors()...) 
@@ -133,7 +132,7 @@ func TestDetectMagicNumber(t *testing.T) { t.Run(filepath.Base(tc.loc), func(t *testing.T) { rFn := func(ctx context.Context) (io.ReadCloser, error) { return os.Open(tc.loc) } - ctx := lg.NewContext(context.Background(), slogt.New(t)) + ctx := lg.NewContext(context.Background(), lgt.New(t)) typ, score, err := source.DetectMagicNumber(ctx, rFn) if tc.wantErr { @@ -149,7 +148,7 @@ func TestDetectMagicNumber(t *testing.T) { } func TestFiles_NewReader(t *testing.T) { - ctx := lg.NewContext(context.Background(), slogt.New(t)) + ctx := lg.NewContext(context.Background(), lgt.New(t)) fpath := sakila.PathCSVActor wantBytes := proj.ReadFile(fpath) diff --git a/libsq/source/internal_test.go b/libsq/source/internal_test.go index 2632391e0..d28a84aef 100644 --- a/libsq/source/internal_test.go +++ b/libsq/source/internal_test.go @@ -9,9 +9,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/source/drivertype" "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/sakila" @@ -28,7 +27,7 @@ var ( ) func TestFiles_Open(t *testing.T) { - ctx := lg.NewContext(context.Background(), slogt.New(t)) + ctx := lg.NewContext(context.Background(), lgt.New(t)) fs, err := NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) require.NoError(t, err) diff --git a/libsq/source/source_test.go b/libsq/source/source_test.go index d623a8b69..1d998598a 100644 --- a/libsq/source/source_test.go +++ b/libsq/source/source_test.go @@ -5,10 +5,9 @@ import ( "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/drivers/sqlite3" "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/core/options" 
"github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/testh/proj" @@ -546,7 +545,7 @@ func TestCollection_Tree(t *testing.T) { } func TestSource_LogValue(t *testing.T) { - log := slogt.New(t) + log := lgt.New(t) src := &source.Source{ Handle: "@sakila", diff --git a/testh/testh.go b/testh/testh.go index 23240142b..54fbab14c 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -18,8 +18,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/neilotoole/slogt" - "github.com/neilotoole/sq/cli" "github.com/neilotoole/sq/cli/buildinfo" "github.com/neilotoole/sq/cli/config" @@ -43,6 +41,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/core/sqlmodel" "github.com/neilotoole/sq/libsq/core/sqlz" @@ -110,7 +109,7 @@ type Helper struct { func New(t testing.TB, opts ...Option) *Helper { h := &Helper{ T: t, - Log: slogt.New(t), + Log: lgt.New(t), Cleanup: cleanup.New(), dbOpenTimeout: defaultDBOpenTimeout, } From c0097ef35d30901cf20283e8a1fcb0b812f18394 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 14 Dec 2023 12:23:22 -0700 Subject: [PATCH 120/195] tuning logging --- libsq/core/ioz/httpz/opts.go | 2 -- libsq/core/lg/devlog/tint/handler.go | 10 ++++++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/libsq/core/ioz/httpz/opts.go b/libsq/core/ioz/httpz/opts.go index ba89ba22e..1e8d37d85 100644 --- a/libsq/core/ioz/httpz/opts.go +++ b/libsq/core/ioz/httpz/opts.go @@ -126,8 +126,6 @@ func OptHeaderTimeout(timeout time.Duration) TripFunc { select { case <-ctx.Done(): case <-t.C: - log := lg.FromContext(ctx) - _ = log lg.FromContext(ctx).Warn("HTTP header response not received within timeout", lga.Timeout, timeout, lga.URL, req.URL.String()) 
cancelFn(context.DeadlineExceeded) diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go index 2f868054f..e12894287 100644 --- a/libsq/core/lg/devlog/tint/handler.go +++ b/libsq/core/lg/devlog/tint/handler.go @@ -74,12 +74,15 @@ const ( ansiFaint = "\033[2m" ansiResetFaint = "\033[22m" ansiBrightRed = "\033[91m" + ansiBrightRedBold = "\033[1;91m" ansiBrightGreen = "\033[92m" ansiBrightGreenBold = "\033[1;92m" ansiBrightYellow = "\033[93m" ansiBlue = "\033[34m" ansiBrightBlue = "\033[94m" ansiBrightRedFaint = "\033[91;2m" + ansiAttrColor = "\033[36;2m" + // ansiAttrColor = "\033[35;2m" ) const errKey = "err" @@ -229,7 +232,7 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { case slog.LevelWarn: msgColor = ansiBrightYellow case slog.LevelError: - msgColor = ansiBrightRed + msgColor = ansiBrightRedBold case slog.LevelInfo: msgColor = ansiBrightGreenBold } @@ -318,7 +321,7 @@ func (h *handler) appendLevel(buf *buffer, level slog.Level) { appendLevelDelta(buf, level-slog.LevelWarn) buf.WriteStringIf(!h.noColor, ansiReset) default: - buf.WriteStringIf(!h.noColor, ansiBrightRed) + buf.WriteStringIf(!h.noColor, ansiBrightRedBold) buf.WriteString("ERR") appendLevelDelta(buf, level-slog.LevelError) buf.WriteStringIf(!h.noColor, ansiReset) @@ -381,7 +384,10 @@ func (h *handler) appendAttr(buf *buffer, attr slog.Attr, groupsPrefix string, g buf.WriteByte(' ') } else { h.appendKey(buf, attr.Key, groupsPrefix) + buf.WriteStringIf(!h.noColor, ansiAttrColor) h.appendValue(buf, attr.Value, true) + buf.WriteStringIf(!h.noColor, ansiReset) + buf.WriteByte(' ') } } From 474009969e3e6b9e3642a33f26220b2d3c92a44e Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 14 Dec 2023 13:23:11 -0700 Subject: [PATCH 121/195] stacktraces working decently well --- cli/cmd_x.go | 12 +++++- cli/error.go | 7 ++-- cli/output/jsonw/errorwriter.go | 2 +- cli/output/tablew/errorwriter.go | 2 +- libsq/core/errz/errz_test.go | 11 ++++++ 
libsq/core/errz/stack.go | 37 +++++++++++++++-- libsq/core/lg/devlog/tint/handler.go | 59 ++++++++++++++++++++++++++-- libsq/core/lg/lga/lga.go | 1 + 8 files changed, 119 insertions(+), 12 deletions(-) diff --git a/cli/cmd_x.go b/cli/cmd_x.go index ec21cb0bb..23c2ddabe 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -3,6 +3,7 @@ package cli import ( "bufio" "fmt" + "github.com/neilotoole/sq/libsq/core/errz" "net/url" "os" "time" @@ -134,7 +135,16 @@ func newXDownloadCmd() *cobra.Command { Short: "Download a file", Hidden: true, Args: cobra.ExactArgs(1), - RunE: execXDownloadCmd, + //RunE: execXDownloadCmd, + RunE: func(cmd *cobra.Command, args []string) error { + err1 := errz.New("inner huzzah") + time.Sleep(time.Nanosecond) + err2 := errz.Wrap(err1, "outer huzzah") + time.Sleep(time.Nanosecond) + err3 := errz.Wrap(err2, "outer huzzah") + + return err3 + }, Example: ` $ sq x download https://sq.io/testdata/actor.csv # Download a big-ass file diff --git a/cli/error.go b/cli/error.go index d42dd9183..9639a5688 100644 --- a/cli/error.go +++ b/cli/error.go @@ -46,13 +46,14 @@ func printError(ctx context.Context, ru *run.Run, err error) { cmdName = cmd.Name() } + logFn := log.Error if errz.IsErrContext(err) { // If it's a context error, e.g. the user cancelled, we'll log it as // a warning instead of as an error. 
- log.Warn("EXECUTION FAILED", lga.Err, err, lga.Cmd, cmdName) - } else { - log.Error("EXECUTION FAILED", lga.Err, err, lga.Cmd, cmdName) + logFn = log.Warn + //log.Warn("EXECUTION FAILED", lga.Err, err, lga.Cmd, cmdName) } + logFn("EXECUTION FAILED", lga.Err, err, lga.Cmd, cmdName, lga.Stack, errz.Stacks(err)) err = humanizeContextErr(err) wrtrs := ru.Writers if wrtrs != nil && wrtrs.Error != nil { diff --git a/cli/output/jsonw/errorwriter.go b/cli/output/jsonw/errorwriter.go index 929af3430..d8ad1e15e 100644 --- a/cli/output/jsonw/errorwriter.go +++ b/cli/output/jsonw/errorwriter.go @@ -31,7 +31,7 @@ func (w *errorWriter) Error(err error) { } else { errMsg = err.Error() if w.pr.Verbose { - for _, st := range errz.Stack(err) { + for _, st := range errz.Stacks(err) { s := fmt.Sprintf("%+v", st) stack = append(stack, s) } diff --git a/cli/output/tablew/errorwriter.go b/cli/output/tablew/errorwriter.go index 0ba041012..523060bd7 100644 --- a/cli/output/tablew/errorwriter.go +++ b/cli/output/tablew/errorwriter.go @@ -28,7 +28,7 @@ func (w *errorWriter) Error(err error) { return } - stacks := errz.Stack(err) + stacks := errz.Stacks(err) for i, stack := range stacks { if i > 0 { fmt.Fprintln(w.w) diff --git a/libsq/core/errz/errz_test.go b/libsq/core/errz/errz_test.go index 130294add..2430adc0d 100644 --- a/libsq/core/errz/errz_test.go +++ b/libsq/core/errz/errz_test.go @@ -121,3 +121,14 @@ func TestAs(t *testing.T) { require.NotNil(t, pathErr) require.Equal(t, fp, pathErr.Path) } + +func TestStackTrace(t *testing.T) { + err := errz.New("huzzah") + + tracer, ok := err.(errz.StackTracer) + require.True(t, ok) + require.NotNil(t, tracer) + tr := tracer.StackTrace() + t.Logf("stack trace:%+v", tr) + +} diff --git a/libsq/core/errz/stack.go b/libsq/core/errz/stack.go index 0884125ac..64cb235bd 100644 --- a/libsq/core/errz/stack.go +++ b/libsq/core/errz/stack.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "io" + "log/slog" "path" "runtime" "strconv" @@ -139,6 +140,15 @@ func (st 
StackTrace) formatSlice(s fmt.State, verb rune) { _, _ = io.WriteString(s, "]") } +// LogValue implements slog.LogValuer. +func (st StackTrace) LogValue() slog.Value { + if len(st) == 0 { + return slog.Value{} + } + + return slog.StringValue(fmt.Sprintf("%+v", st)) +} + // stack represents a stack of program counters. type stack []uintptr @@ -179,11 +189,28 @@ func funcname(name string) string { return name[i+1:] } -// Stack returns any stack trace(s) attached to err. If err -// has been wrapped more than once, there may be multiple stack traces. +// Stack returns the last of any stack trace(s) attached to err. +// If err has been wrapped more than once, there may be multiple stack traces. // Generally speaking, the final stack trace is the most interesting. // The returned StackTrace can be printed using fmt "%+v". -func Stack(err error) []StackTrace { +func Stack(err error) StackTrace { + if err == nil { + return nil + } + + stacks := Stacks(err) + if len(stacks) == 0 { + return nil + } + + // Return the final element of the slice + return stacks[len(stacks)-1] +} + +// Stacks returns any stack trace(s) attached to err. If err +// has been wrapped more than once, there may be multiple stack traces. +// Generally speaking, the final stack trace is the most interesting. 
+func Stacks(err error) []StackTrace { if err == nil { return nil } @@ -208,3 +235,7 @@ func Stack(err error) []StackTrace { return stacks } + +type StackTracer interface { + StackTrace() StackTrace +} diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go index e12894287..550a4f8bd 100644 --- a/libsq/core/lg/devlog/tint/handler.go +++ b/libsq/core/lg/devlog/tint/handler.go @@ -56,6 +56,7 @@ import ( "context" "encoding" "fmt" + "github.com/neilotoole/sq/libsq/core/errz" "io" "log/slog" "path/filepath" @@ -81,8 +82,8 @@ const ( ansiBlue = "\033[34m" ansiBrightBlue = "\033[94m" ansiBrightRedFaint = "\033[91;2m" - ansiAttrColor = "\033[36;2m" - // ansiAttrColor = "\033[35;2m" + ansiAttr = "\033[36;2m" + ansiStack = "\033[0;35m" ) const errKey = "err" @@ -254,8 +255,16 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { buf.WriteString(h.attrsPrefix) } + const keyStack = "stack" + var stackAttrs []slog.Attr + // write attributes r.Attrs(func(attr slog.Attr) bool { + if attr.Key == keyStack { + // Special handling for stacktraces + stackAttrs = append(stackAttrs, attr) + return true + } h.appendAttr(buf, attr, h.groupPrefix, h.groups) return true }) @@ -265,6 +274,8 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { } (*buf)[len(*buf)-1] = '\n' // replace last space with newline + h.handleStackAttrs(buf, stackAttrs) + h.mu.Lock() defer h.mu.Unlock() @@ -272,6 +283,48 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { return err } +func (h *handler) handleStackAttrs(buf *buffer, attrs []slog.Attr) { + if len(attrs) == 0 { + return + } + var stacks []errz.StackTrace + for _, attr := range attrs { + v := attr.Value.Any() + switch v := v.(type) { + case errz.StackTrace: + stacks = append(stacks, v) + case []errz.StackTrace: + stacks = append(stacks, v...) 
+ } + } + + var count int + for _, stack := range stacks { + if stack == nil { + continue + } + + v := fmt.Sprintf("%+v", stack) + v = strings.TrimSpace(v) + v = strings.ReplaceAll(v, "\n\t", "\n ") + if v == "" { + continue + } + + if count > 0 { + buf.WriteString("\n\n") + } + buf.WriteStringIf(!h.noColor, ansiStack) + buf.WriteString(v) + buf.WriteStringIf(!h.noColor, ansiReset) + count++ + } + if count > 0 { + buf.WriteByte('\n') + } + +} + func (h *handler) WithAttrs(attrs []slog.Attr) slog.Handler { if len(attrs) == 0 { return h @@ -384,7 +437,7 @@ func (h *handler) appendAttr(buf *buffer, attr slog.Attr, groupsPrefix string, g buf.WriteByte(' ') } else { h.appendKey(buf, attr.Key, groupsPrefix) - buf.WriteStringIf(!h.noColor, ansiAttrColor) + buf.WriteStringIf(!h.noColor, ansiAttr) h.appendValue(buf, attr.Value, true) buf.WriteStringIf(!h.noColor, ansiReset) diff --git a/libsq/core/lg/lga/lga.go b/libsq/core/lg/lga/lga.go index 03d9ef511..33a9a2544 100644 --- a/libsq/core/lg/lga/lga.go +++ b/libsq/core/lg/lga/lga.go @@ -49,6 +49,7 @@ const ( SQL = "sql" Src = "src" ScanType = "scan_type" + Stack = "stack" Schema = "schema" Table = "table" Target = "target" From 79e5245fa7c1175b763009fd2cc3668365df6b8d Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 14 Dec 2023 16:44:56 -0700 Subject: [PATCH 122/195] log test --- cli/error.go | 1 - libsq/core/errz/errors.go | 30 +++++++++++++++++++ libsq/core/errz/errz.go | 45 ++++++++++++++-------------- libsq/core/errz/stack.go | 34 ++++++++++++--------- libsq/core/lg/devlog/devlog.go | 6 ++++ libsq/core/lg/devlog/devlog_test.go | 16 ++++++++++ libsq/core/lg/devlog/tint/handler.go | 6 ++-- 7 files changed, 97 insertions(+), 41 deletions(-) create mode 100644 libsq/core/lg/devlog/devlog_test.go diff --git a/cli/error.go b/cli/error.go index 9639a5688..f64cd95f5 100644 --- a/cli/error.go +++ b/cli/error.go @@ -51,7 +51,6 @@ func printError(ctx context.Context, ru *run.Run, err error) { // If it's a context error, e.g. 
the user cancelled, we'll log it as // a warning instead of as an error. logFn = log.Warn - //log.Warn("EXECUTION FAILED", lga.Err, err, lga.Cmd, cmdName) } logFn("EXECUTION FAILED", lga.Err, err, lga.Cmd, cmdName, lga.Stack, errz.Stacks(err)) err = humanizeContextErr(err) diff --git a/libsq/core/errz/errors.go b/libsq/core/errz/errors.go index 9a85fbd77..08941d631 100644 --- a/libsq/core/errz/errors.go +++ b/libsq/core/errz/errors.go @@ -141,6 +141,21 @@ func (f *fundamental) Format(s fmt.State, verb rune) { } } +var _ StackTracer = (*fundamental)(nil) + +// StackTrace implements StackTracer. +func (f *fundamental) StackTrace() *StackTrace { + if f == nil || f.stack == nil { + return nil + } + + st := f.stack.stackTrace() + if st != nil { + st.Error = f + } + return st +} + // LogValue implements slog.LogValuer. func (f *fundamental) LogValue() slog.Value { return logValue(f) @@ -151,6 +166,21 @@ type withStack struct { *stack } +var _ StackTracer = (*withStack)(nil) + +// StackTrace implements StackTracer. +func (w *withStack) StackTrace() *StackTrace { + if w == nil || w.stack == nil { + return nil + } + + st := w.stack.stackTrace() + if st != nil { + st.Error = w + } + return st +} + func (w *withStack) Cause() error { return w.error } // Unwrap provides compatibility for Go 1.13 error chains. diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go index 98813d95a..f72bfb70e 100644 --- a/libsq/core/errz/errz.go +++ b/libsq/core/errz/errz.go @@ -12,10 +12,8 @@ import ( "context" "errors" "fmt" - "log/slog" - "path/filepath" - "go.uber.org/multierr" + "log/slog" ) // Err annotates err with a stack trace at the point WithStack was called. @@ -40,8 +38,9 @@ var Combine = multierr.Combine var Errors = multierr.Errors // logValue return a slog.Value for err. +// Deprecated: Are we using logValue? 
func logValue(err error) slog.Value { - if err == nil { + if err == nil { return slog.Value{} } @@ -55,25 +54,25 @@ func logValue(err error) slog.Value { causeAttr := slog.String("cause", c.Error()) typeAttr := slog.String("type", fmt.Sprintf("%T", c)) - if ws, ok := err.(*withStack); ok { //nolint:errorlint - st := ws.stack.StackTrace() - - if len(st) > 0 { - f := st[0] - file := f.file() - funcName := f.name() - if funcName != unknown { - fp := filepath.Join(filepath.Base(filepath.Dir(file)), filepath.Base(file)) - return slog.GroupValue( - msgAttr, - causeAttr, - typeAttr, - slog.String("func", funcName), - slog.String("source", fmt.Sprintf("%s:%d", fp, f.line())), - ) - } - } - } + //if ws, ok := err.(*withStack); ok { //nolint:errorlint + // st := ws.stack.stackTrace() + // + // if st != nil && len(st.Frames) > 0 { + // f := st.Frames[0] + // file := f.file() + // funcName := f.name() + // if funcName != unknown { + // fp := filepath.Join(filepath.Base(filepath.Dir(file)), filepath.Base(file)) + // return slog.GroupValue( + // msgAttr, + // causeAttr, + // typeAttr, + // slog.String("func", funcName), + // slog.String("source", fmt.Sprintf("%s:%d", fp, f.line())), + // ) + // } + // } + //} return slog.GroupValue(msgAttr, causeAttr, typeAttr) } diff --git a/libsq/core/errz/stack.go b/libsq/core/errz/stack.go index 64cb235bd..a9bf6fc21 100644 --- a/libsq/core/errz/stack.go +++ b/libsq/core/errz/stack.go @@ -98,7 +98,10 @@ func (f Frame) MarshalText() ([]byte, error) { } // StackTrace is stack of Frames from innermost (newest) to outermost (oldest). -type StackTrace []Frame +type StackTrace struct { + Error error + Frames []Frame +} // Format formats the stack of Frames according to the fmt.Formatter interface. // @@ -108,17 +111,17 @@ type StackTrace []Frame // Format accepts flags that alter the printing of some verbs, as follows: // // %+v Prints filename, function, and line number for each Frame in the stack. 
-func (st StackTrace) Format(s fmt.State, verb rune) { +func (st *StackTrace) Format(s fmt.State, verb rune) { switch verb { case 'v': switch { case s.Flag('+'): - for _, f := range st { + for _, f := range st.Frames { _, _ = io.WriteString(s, "\n") f.Format(s, verb) } case s.Flag('#'): - fmt.Fprintf(s, "%#v", []Frame(st)) + fmt.Fprintf(s, "%#v", []Frame(st.Frames)) default: st.formatSlice(s, verb) } @@ -129,9 +132,9 @@ func (st StackTrace) Format(s fmt.State, verb rune) { // formatSlice will format this StackTrace into the given buffer as a slice of // Frame, only valid when called with '%s' or '%v'. -func (st StackTrace) formatSlice(s fmt.State, verb rune) { +func (st *StackTrace) formatSlice(s fmt.State, verb rune) { _, _ = io.WriteString(s, "[") - for i, f := range st { + for i, f := range st.Frames { if i > 0 { _, _ = io.WriteString(s, " ") } @@ -141,8 +144,8 @@ func (st StackTrace) formatSlice(s fmt.State, verb rune) { } // LogValue implements slog.LogValuer. -func (st StackTrace) LogValue() slog.Value { - if len(st) == 0 { +func (st *StackTrace) LogValue() slog.Value { + if st == nil || len(st.Frames) == 0 { return slog.Value{} } @@ -153,6 +156,9 @@ func (st StackTrace) LogValue() slog.Value { type stack []uintptr func (s *stack) Format(st fmt.State, verb rune) { + if s == nil { + fmt.Fprintf(st, "") + } switch verb { //nolint:gocritic case 'v': switch { //nolint:gocritic @@ -165,12 +171,12 @@ func (s *stack) Format(st fmt.State, verb rune) { } } -func (s *stack) StackTrace() StackTrace { +func (s *stack) stackTrace() *StackTrace { f := make([]Frame, len(*s)) for i := 0; i < len(f); i++ { f[i] = Frame((*s)[i]) } - return f + return &StackTrace{Frames: f} } func callers() *stack { @@ -193,7 +199,7 @@ func funcname(name string) string { // If err has been wrapped more than once, there may be multiple stack traces. // Generally speaking, the final stack trace is the most interesting. // The returned StackTrace can be printed using fmt "%+v". 
-func Stack(err error) StackTrace { +func Stack(err error) *StackTrace { if err == nil { return nil } @@ -210,12 +216,12 @@ func Stack(err error) StackTrace { // Stacks returns any stack trace(s) attached to err. If err // has been wrapped more than once, there may be multiple stack traces. // Generally speaking, the final stack trace is the most interesting. -func Stacks(err error) []StackTrace { +func Stacks(err error) []*StackTrace { if err == nil { return nil } - var stacks []StackTrace + var stacks []*StackTrace for { if err == nil { @@ -237,5 +243,5 @@ func Stacks(err error) []StackTrace { } type StackTracer interface { - StackTrace() StackTrace + StackTrace() *StackTrace } diff --git a/libsq/core/lg/devlog/devlog.go b/libsq/core/lg/devlog/devlog.go index 4bba3c060..3ed61c5a2 100644 --- a/libsq/core/lg/devlog/devlog.go +++ b/libsq/core/lg/devlog/devlog.go @@ -22,6 +22,12 @@ func NewHandler(w io.Writer, lvl slog.Leveler) slog.Handler { switch a.Key { case "pid": return slog.Attr{} + case "error": + if _, ok := a.Value.Any().(error); ok { + a.Key = "e" + } + a.Key = "wussah" + return a default: return a } diff --git a/libsq/core/lg/devlog/devlog_test.go b/libsq/core/lg/devlog/devlog_test.go new file mode 100644 index 000000000..229896cd5 --- /dev/null +++ b/libsq/core/lg/devlog/devlog_test.go @@ -0,0 +1,16 @@ +package devlog_test + +import ( + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgt" + "testing" +) + +func TestDevlog(t *testing.T) { + log := lgt.New(t) + //log.Debug("huzzah") + err := errz.New("oh noes") + log.Error("bah", lga.Err, err, lga.Stack, errz.Stacks(err)) + +} diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go index 550a4f8bd..2435f8312 100644 --- a/libsq/core/lg/devlog/tint/handler.go +++ b/libsq/core/lg/devlog/tint/handler.go @@ -287,13 +287,13 @@ func (h *handler) handleStackAttrs(buf *buffer, attrs 
[]slog.Attr) { if len(attrs) == 0 { return } - var stacks []errz.StackTrace + var stacks []*errz.StackTrace for _, attr := range attrs { v := attr.Value.Any() switch v := v.(type) { - case errz.StackTrace: + case *errz.StackTrace: stacks = append(stacks, v) - case []errz.StackTrace: + case []*errz.StackTrace: stacks = append(stacks, v...) } } From f371b21590a859da260b0d2df7626e4fa50d941b Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 14 Dec 2023 16:45:31 -0700 Subject: [PATCH 123/195] log merge --- libsq/core/lg/devlog/tint/handler.go | 107 ++------------------------- 1 file changed, 8 insertions(+), 99 deletions(-) diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go index 2435f8312..0cc319056 100644 --- a/libsq/core/lg/devlog/tint/handler.go +++ b/libsq/core/lg/devlog/tint/handler.go @@ -56,34 +56,25 @@ import ( "context" "encoding" "fmt" - "github.com/neilotoole/sq/libsq/core/errz" "io" "log/slog" "path/filepath" "runtime" "strconv" - "strings" "sync" "time" "unicode" ) // ANSI modes -// See: https://gist.github.com/JBlond/2fea43a3049b38287e5e9cefc87b2124 const ( - ansiReset = "\033[0m" - ansiFaint = "\033[2m" - ansiResetFaint = "\033[22m" - ansiBrightRed = "\033[91m" - ansiBrightRedBold = "\033[1;91m" - ansiBrightGreen = "\033[92m" - ansiBrightGreenBold = "\033[1;92m" - ansiBrightYellow = "\033[93m" - ansiBlue = "\033[34m" - ansiBrightBlue = "\033[94m" - ansiBrightRedFaint = "\033[91;2m" - ansiAttr = "\033[36;2m" - ansiStack = "\033[0;35m" + ansiReset = "\033[0m" + ansiFaint = "\033[2m" + ansiResetFaint = "\033[22m" + ansiBrightRed = "\033[91m" + ansiBrightGreen = "\033[92m" + ansiBrightYellow = "\033[93m" + ansiBrightRedFaint = "\033[91;2m" ) const errKey = "err" @@ -226,27 +217,12 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { } } - msgColor := ansiBrightGreen - switch r.Level { - case slog.LevelDebug: - msgColor = ansiBrightGreen - case slog.LevelWarn: - msgColor = ansiBrightYellow - case 
slog.LevelError: - msgColor = ansiBrightRedBold - case slog.LevelInfo: - msgColor = ansiBrightGreenBold - } // write message if rep == nil { - buf.WriteStringIf(!h.noColor, msgColor) buf.WriteString(r.Message) - buf.WriteStringIf(!h.noColor, ansiReset) buf.WriteByte(' ') } else if a := rep(nil /* groups */, slog.String(slog.MessageKey, r.Message)); a.Key != "" { - buf.WriteStringIf(!h.noColor, msgColor) h.appendValue(buf, a.Value, false) - buf.WriteStringIf(!h.noColor, ansiReset) buf.WriteByte(' ') } @@ -255,16 +231,8 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { buf.WriteString(h.attrsPrefix) } - const keyStack = "stack" - var stackAttrs []slog.Attr - // write attributes r.Attrs(func(attr slog.Attr) bool { - if attr.Key == keyStack { - // Special handling for stacktraces - stackAttrs = append(stackAttrs, attr) - return true - } h.appendAttr(buf, attr, h.groupPrefix, h.groups) return true }) @@ -274,8 +242,6 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { } (*buf)[len(*buf)-1] = '\n' // replace last space with newline - h.handleStackAttrs(buf, stackAttrs) - h.mu.Lock() defer h.mu.Unlock() @@ -283,48 +249,6 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { return err } -func (h *handler) handleStackAttrs(buf *buffer, attrs []slog.Attr) { - if len(attrs) == 0 { - return - } - var stacks []*errz.StackTrace - for _, attr := range attrs { - v := attr.Value.Any() - switch v := v.(type) { - case *errz.StackTrace: - stacks = append(stacks, v) - case []*errz.StackTrace: - stacks = append(stacks, v...) 
- } - } - - var count int - for _, stack := range stacks { - if stack == nil { - continue - } - - v := fmt.Sprintf("%+v", stack) - v = strings.TrimSpace(v) - v = strings.ReplaceAll(v, "\n\t", "\n ") - if v == "" { - continue - } - - if count > 0 { - buf.WriteString("\n\n") - } - buf.WriteStringIf(!h.noColor, ansiStack) - buf.WriteString(v) - buf.WriteStringIf(!h.noColor, ansiReset) - count++ - } - if count > 0 { - buf.WriteByte('\n') - } - -} - func (h *handler) WithAttrs(attrs []slog.Attr) slog.Handler { if len(attrs) == 0 { return h @@ -374,7 +298,7 @@ func (h *handler) appendLevel(buf *buffer, level slog.Level) { appendLevelDelta(buf, level-slog.LevelWarn) buf.WriteStringIf(!h.noColor, ansiReset) default: - buf.WriteStringIf(!h.noColor, ansiBrightRedBold) + buf.WriteStringIf(!h.noColor, ansiBrightRed) buf.WriteString("ERR") appendLevelDelta(buf, level-slog.LevelError) buf.WriteStringIf(!h.noColor, ansiReset) @@ -393,18 +317,6 @@ func appendLevelDelta(buf *buffer, delta slog.Level) { func (h *handler) appendSource(buf *buffer, src *slog.Source) { dir, file := filepath.Split(src.File) - fn := src.Function - parts := strings.Split(src.Function, "/") - if len(parts) > 0 { - fn = parts[len(parts)-1] - } - - if fn != "" { - buf.WriteStringIf(!h.noColor, ansiBlue) - buf.WriteString(fn) - buf.WriteStringIf(!h.noColor, ansiReset) - buf.WriteByte(' ') - } buf.WriteStringIf(!h.noColor, ansiFaint) buf.WriteString(filepath.Join(filepath.Base(dir), file)) buf.WriteByte(':') @@ -437,10 +349,7 @@ func (h *handler) appendAttr(buf *buffer, attr slog.Attr, groupsPrefix string, g buf.WriteByte(' ') } else { h.appendKey(buf, attr.Key, groupsPrefix) - buf.WriteStringIf(!h.noColor, ansiAttr) h.appendValue(buf, attr.Value, true) - buf.WriteStringIf(!h.noColor, ansiReset) - buf.WriteByte(' ') } } From e05f6a332bb5e4a25867f9410072437e9632de43 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 14 Dec 2023 17:24:31 -0700 Subject: [PATCH 124/195] tuning devlog --- libsq/core/errz/errz.go 
| 43 +++----- libsq/core/lg/devlog/devlog.go | 5 +- libsq/core/lg/devlog/devlog_test.go | 38 ++++++- libsq/core/lg/devlog/tint/handler.go | 148 +++++++++++++++++++++++---- 4 files changed, 180 insertions(+), 54 deletions(-) diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go index f72bfb70e..4f4784f25 100644 --- a/libsq/core/errz/errz.go +++ b/libsq/core/errz/errz.go @@ -38,9 +38,8 @@ var Combine = multierr.Combine var Errors = multierr.Errors // logValue return a slog.Value for err. -// Deprecated: Are we using logValue? func logValue(err error) slog.Value { - if err == nil { + if err == nil { return slog.Value{} } @@ -50,31 +49,21 @@ func logValue(err error) slog.Value { return slog.Value{} } - msgAttr := slog.String("msg", err.Error()) - causeAttr := slog.String("cause", c.Error()) - typeAttr := slog.String("type", fmt.Sprintf("%T", c)) - - //if ws, ok := err.(*withStack); ok { //nolint:errorlint - // st := ws.stack.stackTrace() - // - // if st != nil && len(st.Frames) > 0 { - // f := st.Frames[0] - // file := f.file() - // funcName := f.name() - // if funcName != unknown { - // fp := filepath.Join(filepath.Base(filepath.Dir(file)), filepath.Base(file)) - // return slog.GroupValue( - // msgAttr, - // causeAttr, - // typeAttr, - // slog.String("func", funcName), - // slog.String("source", fmt.Sprintf("%s:%d", fp, f.line())), - // ) - // } - // } - //} - - return slog.GroupValue(msgAttr, causeAttr, typeAttr) + attrs := []slog.Attr{slog.String("msg", err.Error())} + if !errors.Is(c, err) { + attrs = append(attrs, + slog.String("cause", c.Error()), + slog.String("type", fmt.Sprintf("%T", c)), + ) + + // If there's a cause c, "type" will be the type of c. + } else { + // If there's no cause, "type" will be the type of err. + // It's a bit wonky, but probably the most useful thing to show. + attrs = append(attrs, slog.String("type", fmt.Sprintf("%T", err))) + } + + return slog.GroupValue(attrs...) 
} // IsErrContext returns true if err is context.Canceled or context.DeadlineExceeded. diff --git a/libsq/core/lg/devlog/devlog.go b/libsq/core/lg/devlog/devlog.go index 3ed61c5a2..fd4fa020a 100644 --- a/libsq/core/lg/devlog/devlog.go +++ b/libsq/core/lg/devlog/devlog.go @@ -23,10 +23,7 @@ func NewHandler(w io.Writer, lvl slog.Leveler) slog.Handler { case "pid": return slog.Attr{} case "error": - if _, ok := a.Value.Any().(error); ok { - a.Key = "e" - } - a.Key = "wussah" + a.Key = "err" return a default: return a diff --git a/libsq/core/lg/devlog/devlog_test.go b/libsq/core/lg/devlog/devlog_test.go index 229896cd5..c8756d08b 100644 --- a/libsq/core/lg/devlog/devlog_test.go +++ b/libsq/core/lg/devlog/devlog_test.go @@ -4,13 +4,49 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgt" + "log/slog" + "os" "testing" ) func TestDevlog(t *testing.T) { + log := lgt.New(t) //log.Debug("huzzah") err := errz.New("oh noes") - log.Error("bah", lga.Err, err, lga.Stack, errz.Stacks(err)) + //stack := errs.Stacks(err) + // lga.Stack, errz.Stacks(err) + log.Error("bah", lga.Err, err) + +} + +func TestDevlogTextHandler(t *testing.T) { + o := &slog.HandlerOptions{ + ReplaceAttr: ReplaceAttr, + } + + h := slog.NewTextHandler(os.Stdout, o) + log := slog.New(h) + //log := lgt.New(t) + //log.Debug("huzzah") + err := errz.New("oh noes") + //stack := errs.Stacks(err) + // lga.Stack, errz.Stacks(err) + log.Error("bah", lga.Err, err) + +} +func ReplaceAttr(groups []string, a slog.Attr) slog.Attr { + switch a.Key { + case "pid": + return slog.Attr{} + case "error": + if _, ok := a.Value.Any().(error); ok { + a.Key = "e" + } + a.Key = "wussah" + return a + default: + return a + } } diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go index 0cc319056..ce26471a4 100644 --- a/libsq/core/lg/devlog/tint/handler.go +++ b/libsq/core/lg/devlog/tint/handler.go @@ 
-56,25 +56,36 @@ import ( "context" "encoding" "fmt" + "github.com/neilotoole/sq/libsq/core/errz" "io" "log/slog" "path/filepath" "runtime" "strconv" + "strings" "sync" "time" "unicode" ) // ANSI modes +// See: https://gist.github.com/JBlond/2fea43a3049b38287e5e9cefc87b2124 const ( - ansiReset = "\033[0m" - ansiFaint = "\033[2m" - ansiResetFaint = "\033[22m" - ansiBrightRed = "\033[91m" - ansiBrightGreen = "\033[92m" - ansiBrightYellow = "\033[93m" - ansiBrightRedFaint = "\033[91;2m" + ansiAttr = "\033[36;2m" + ansiBlue = "\033[34m" + ansiBrightBlue = "\033[94m" + ansiBrightGreen = "\033[92m" + ansiBrightGreenBold = "\033[1;92m" + ansiBrightRed = "\033[91m" + ansiBrightRedBold = "\033[1;91m" + ansiBrightRedFaint = "\033[91;2m" + ansiBrightYellow = "\033[93m" + ansiFaint = "\033[2m" + ansiReset = "\033[0m" + ansiResetFaint = "\033[22m" + ansiStack = "\033[0;35m" + ansiYellowBold = "\033[1;33m" + ansiStackErr = ansiYellowBold ) const errKey = "err" @@ -217,12 +228,27 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { } } + msgColor := ansiBrightGreen + switch r.Level { + case slog.LevelDebug: + msgColor = ansiBrightGreen + case slog.LevelWarn: + msgColor = ansiBrightYellow + case slog.LevelError: + msgColor = ansiBrightRedBold + case slog.LevelInfo: + msgColor = ansiBrightGreenBold + } // write message if rep == nil { + buf.WriteStringIf(!h.noColor, msgColor) buf.WriteString(r.Message) + buf.WriteStringIf(!h.noColor, ansiReset) buf.WriteByte(' ') } else if a := rep(nil /* groups */, slog.String(slog.MessageKey, r.Message)); a.Key != "" { + buf.WriteStringIf(!h.noColor, msgColor) h.appendValue(buf, a.Value, false) + buf.WriteStringIf(!h.noColor, ansiReset) buf.WriteByte(' ') } @@ -231,8 +257,16 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { buf.WriteString(h.attrsPrefix) } + const keyStack = "stack" + var stackAttrs []slog.Attr + // write attributes r.Attrs(func(attr slog.Attr) bool { + if attr.Key == keyStack { + // 
Special handling for stacktraces + stackAttrs = append(stackAttrs, attr) + return true + } h.appendAttr(buf, attr, h.groupPrefix, h.groups) return true }) @@ -242,6 +276,8 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { } (*buf)[len(*buf)-1] = '\n' // replace last space with newline + h.handleStackAttrs(buf, stackAttrs) + h.mu.Lock() defer h.mu.Unlock() @@ -249,6 +285,59 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { return err } +func (h *handler) handleStackAttrs(buf *buffer, attrs []slog.Attr) { + if len(attrs) == 0 { + return + } + var stacks []*errz.StackTrace + for _, attr := range attrs { + v := attr.Value.Any() + switch v := v.(type) { + case *errz.StackTrace: + stacks = append(stacks, v) + case []*errz.StackTrace: + stacks = append(stacks, v...) + } + } + + var count int + for _, stack := range stacks { + if stack == nil { + continue + } + + v := fmt.Sprintf("%+v", stack) + v = strings.TrimSpace(v) + v = strings.ReplaceAll(v, "\n\t", "\n ") + if v == "" { + continue + } + + if count > 0 { + buf.WriteString("\n") + } + + if stack.Error != nil { + buf.WriteStringIf(!h.noColor, ansiStackErr) + buf.WriteString(stack.Error.Error()) + buf.WriteStringIf(!h.noColor, ansiReset) + buf.WriteByte('\n') + } + lines := strings.Split(v, "\n") + for _, line := range lines { + buf.WriteStringIf(!h.noColor, ansiStack) + buf.WriteString(line) + buf.WriteStringIf(!h.noColor, ansiReset) + buf.WriteByte('\n') + } + count++ + } + if count > 0 { + buf.WriteByte('\n') + } + +} + func (h *handler) WithAttrs(attrs []slog.Attr) slog.Handler { if len(attrs) == 0 { return h @@ -298,7 +387,7 @@ func (h *handler) appendLevel(buf *buffer, level slog.Level) { appendLevelDelta(buf, level-slog.LevelWarn) buf.WriteStringIf(!h.noColor, ansiReset) default: - buf.WriteStringIf(!h.noColor, ansiBrightRed) + buf.WriteStringIf(!h.noColor, ansiBrightRedBold) buf.WriteString("ERR") appendLevelDelta(buf, level-slog.LevelError) buf.WriteStringIf(!h.noColor, 
ansiReset) @@ -317,6 +406,18 @@ func appendLevelDelta(buf *buffer, delta slog.Level) { func (h *handler) appendSource(buf *buffer, src *slog.Source) { dir, file := filepath.Split(src.File) + fn := src.Function + parts := strings.Split(src.Function, "/") + if len(parts) > 0 { + fn = parts[len(parts)-1] + } + + if fn != "" { + buf.WriteStringIf(!h.noColor, ansiBlue) + buf.WriteString(fn) + buf.WriteStringIf(!h.noColor, ansiReset) + buf.WriteByte(' ') + } buf.WriteStringIf(!h.noColor, ansiFaint) buf.WriteString(filepath.Join(filepath.Base(dir), file)) buf.WriteByte(':') @@ -324,32 +425,35 @@ func (h *handler) appendSource(buf *buffer, src *slog.Source) { buf.WriteStringIf(!h.noColor, ansiReset) } -func (h *handler) appendAttr(buf *buffer, attr slog.Attr, groupsPrefix string, groups []string) { - attr.Value = attr.Value.Resolve() - if rep := h.replaceAttr; rep != nil && attr.Value.Kind() != slog.KindGroup { - attr = rep(groups, attr) - attr.Value = attr.Value.Resolve() +func (h *handler) appendAttr(buf *buffer, a slog.Attr, groupsPrefix string, groups []string) { + if rep := h.replaceAttr; rep != nil && a.Value.Kind() != slog.KindGroup { + a.Value = a.Value.Resolve() + a = rep(groups, a) } + a.Value = a.Value.Resolve() - if attr.Equal(slog.Attr{}) { + if a.Equal(slog.Attr{}) { return } - if attr.Value.Kind() == slog.KindGroup { - if attr.Key != "" { - groupsPrefix += attr.Key + "." - groups = append(groups, attr.Key) + if a.Value.Kind() == slog.KindGroup { + if a.Key != "" { + groupsPrefix += a.Key + "." 
+ groups = append(groups, a.Key) } - for _, groupAttr := range attr.Value.Group() { + for _, groupAttr := range a.Value.Group() { h.appendAttr(buf, groupAttr, groupsPrefix, groups) } - } else if err, ok := attr.Value.Any().(tintError); ok { + } else if err, ok := a.Value.Any().(tintError); ok { // append tintError h.appendTintError(buf, err, groupsPrefix) buf.WriteByte(' ') } else { - h.appendKey(buf, attr.Key, groupsPrefix) - h.appendValue(buf, attr.Value, true) + h.appendKey(buf, a.Key, groupsPrefix) + buf.WriteStringIf(!h.noColor, ansiAttr) + h.appendValue(buf, a.Value, true) + buf.WriteStringIf(!h.noColor, ansiReset) + buf.WriteByte(' ') } } From a776b3ce2ad5a8d9e8ad275a84d52282dd44a793 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 14 Dec 2023 17:44:07 -0700 Subject: [PATCH 125/195] tuning devlog --- libsq/core/lg/devlog/tint/handler.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go index ce26471a4..0c184920c 100644 --- a/libsq/core/lg/devlog/tint/handler.go +++ b/libsq/core/lg/devlog/tint/handler.go @@ -321,6 +321,10 @@ func (h *handler) handleStackAttrs(buf *buffer, attrs []slog.Attr) { buf.WriteStringIf(!h.noColor, ansiStackErr) buf.WriteString(stack.Error.Error()) buf.WriteStringIf(!h.noColor, ansiReset) + buf.WriteByte(' ') + buf.WriteStringIf(!h.noColor, ansiFaint) + buf.WriteStringIf(!h.noColor, fmt.Sprintf("%T", stack.Error)) + buf.WriteStringIf(!h.noColor, ansiResetFaint) buf.WriteByte('\n') } lines := strings.Split(v, "\n") From 0c2dbf177b17461c17df264b66fe974a42bb029d Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 15 Dec 2023 11:20:04 -0700 Subject: [PATCH 126/195] cleanup --- cli/cli.go | 3 +- cli/cmd_x.go | 27 +++++----- cli/error.go | 66 ++++++++++++++--------- libsq/core/errz/errz.go | 13 ++--- libsq/core/errz/errz_test.go | 1 - libsq/core/errz/stack.go | 5 +- libsq/core/ioz/contextio/contextio.go | 50 ++++++++++------- libsq/core/ioz/download/cache.go 
| 27 ++++------ libsq/core/ioz/download/download.go | 8 ++- libsq/core/ioz/httpz/httpz.go | 2 + libsq/core/ioz/httpz/httpz_test.go | 2 +- libsq/core/ioz/httpz/opts.go | 77 +++++++++++++++++++-------- libsq/core/lg/devlog/devlog_test.go | 20 ++++--- libsq/core/lg/devlog/tint/handler.go | 18 +++++-- libsq/core/loz/loz.go | 24 +++++++++ libsq/source/files.go | 2 +- 16 files changed, 219 insertions(+), 126 deletions(-) diff --git a/cli/cli.go b/cli/cli.go index 021dbd120..5f00510fc 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -50,7 +50,8 @@ const ( ) // errNoMsg is a sentinel error indicating that a command -// has failed, but that no error message should be printed. +// has failed (and thus the program should exit with a non-zero +// code), but no error message should be printed. // This is useful in the case where any error information may // already have been printed as part of the command output. var errNoMsg = errors.New("") diff --git a/cli/cmd_x.go b/cli/cmd_x.go index 23c2ddabe..94850f36b 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -3,7 +3,6 @@ package cli import ( "bufio" "fmt" - "github.com/neilotoole/sq/libsq/core/errz" "net/url" "os" "time" @@ -135,16 +134,16 @@ func newXDownloadCmd() *cobra.Command { Short: "Download a file", Hidden: true, Args: cobra.ExactArgs(1), - //RunE: execXDownloadCmd, - RunE: func(cmd *cobra.Command, args []string) error { - err1 := errz.New("inner huzzah") - time.Sleep(time.Nanosecond) - err2 := errz.Wrap(err1, "outer huzzah") - time.Sleep(time.Nanosecond) - err3 := errz.Wrap(err2, "outer huzzah") - - return err3 - }, + RunE: execXDownloadCmd, + //RunE: func(cmd *cobra.Command, args []string) error { + // err1 := errz.New("inner huzzah") + // time.Sleep(time.Nanosecond) + // err2 := errz.Wrap(err1, "outer huzzah") + // time.Sleep(time.Nanosecond) + // err3 := errz.Wrap(err2, "outer huzzah") + // + // return err3 + //}, Example: ` $ sq x download https://sq.io/testdata/actor.csv # Download a big-ass file @@ -172,7 +171,11 @@ 
func execXDownloadCmd(cmd *cobra.Command, args []string) error { return err } - c := httpz.NewClient(httpz.DefaultUserAgent, httpz.OptRequestTimeout(time.Second*5)) + c := httpz.NewClient( + httpz.DefaultUserAgent, + //httpz.OptRequestTimeout(time.Second*2), + httpz.OptHeaderTimeout(time.Millisecond), + ) dl, err := download.New(fakeSrc.Handle, c, u.String(), cacheDir) if err != nil { return err diff --git a/cli/error.go b/cli/error.go index f64cd95f5..98a4a461f 100644 --- a/cli/error.go +++ b/cli/error.go @@ -4,10 +4,11 @@ import ( "context" "errors" "fmt" - "os" - "github.com/spf13/cobra" "github.com/spf13/pflag" + "net/url" + "os" + "strings" "github.com/neilotoole/sq/cli/flag" "github.com/neilotoole/sq/cli/output/format" @@ -37,36 +38,28 @@ func printError(ctx context.Context, ru *run.Run, err error) { return } + var cmdName = "unknown" var cmd *cobra.Command - if ru != nil { + if ru != nil && ru.Cmd != nil { cmd = ru.Cmd + cmdName = ru.Cmd.Name() + } - cmdName := "unknown" - if cmd != nil { - cmdName = cmd.Name() - } + log.Error("EXECUTION FAILED", + lga.Err, err, lga.Cmd, cmdName, lga.Stack, errz.Stacks(err)) - logFn := log.Error - if errz.IsErrContext(err) { - // If it's a context error, e.g. the user cancelled, we'll log it as - // a warning instead of as an error. - logFn = log.Warn - } - logFn("EXECUTION FAILED", lga.Err, err, lga.Cmd, cmdName, lga.Stack, errz.Stacks(err)) - err = humanizeContextErr(err) - wrtrs := ru.Writers - if wrtrs != nil && wrtrs.Error != nil { + err = humanizeError(err) + if ru != nil { + if wrtrs := ru.Writers; wrtrs != nil && wrtrs.Error != nil { // If we have an errorWriter, we print to it // and return. wrtrs.Error.Error(err) return } - // Else we don't have an errorWriter, so we fall through + // Else we don't have an error writer, so we fall through } - err = humanizeContextErr(err) - // If we get this far, something went badly wrong in bootstrap // (probably the config is corrupt). 
// At this point, we could just print err to os.Stderr and be done. @@ -156,20 +149,45 @@ func panicOn(err error) { } } -// humanizeContextErr returns a friendlier error message -// for context errors. -func humanizeContextErr(err error) error { +// humanizeError wrangles an error to make it more human-friendly before +// printing to stderr. The returned err may be a different error from the +// one passed in. This should be the final step before printing an error; +// the original error should have already been logged. +func humanizeError(err error) error { if err == nil { return nil } + // Download timeout errors are typically wrapped in an url.Error, resulting + // in a message like: + // + // Get "https://example.com": http response header not received within 1ms timeout + // + // We want to trim off that prefix, but we only do that if there's a wrapped + // error beneath (which should be the case). + if errz.IsType[*url.Error](err) && errors.Is(err, context.DeadlineExceeded) { + if e := errors.Unwrap(err); e != nil { + err = e + } + } + switch { // Friendlier messages for context errors. default: case errors.Is(err, context.Canceled): err = errz.New("canceled") case errors.Is(err, context.DeadlineExceeded): - err = errz.New("timeout") + errMsg := err.Error() + deadlineMsg := context.DeadlineExceeded.Error() + if errMsg == deadlineMsg { + // For generic context.DeadlineExceeded errors, we + // just return "timeout". + err = errz.New("timeout") + } else { + // But if the error is a wrapped context.DeadlineExceeded, we + // trim off the ": context deadline exceeded" suffix. 
+ return errz.New(strings.TrimSuffix(errMsg, ": "+deadlineMsg)) + } } return err diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go index 4f4784f25..77b76f568 100644 --- a/libsq/core/errz/errz.go +++ b/libsq/core/errz/errz.go @@ -12,8 +12,9 @@ import ( "context" "errors" "fmt" - "go.uber.org/multierr" "log/slog" + + "go.uber.org/multierr" ) // Err annotates err with a stack trace at the point WithStack was called. @@ -79,13 +80,9 @@ func IsErrContext(err error) bool { return false } -// IsErrContextDeadlineExceeded returns true if err is context.DeadlineExceeded. -func IsErrContextDeadlineExceeded(err error) bool { - return errors.Is(err, context.DeadlineExceeded) -} - -// Tuple returns t and err, wrapping err with errz.Err. -func Tuple[T any](t T, err error) (T, error) { +// Return returns t, and err wrapped with [errz.Err]. +// This is useful for the common case of returning a value and an error. +func Return[T any](t T, err error) (T, error) { return t, Err(err) } diff --git a/libsq/core/errz/errz_test.go b/libsq/core/errz/errz_test.go index 2430adc0d..f90cbef0b 100644 --- a/libsq/core/errz/errz_test.go +++ b/libsq/core/errz/errz_test.go @@ -130,5 +130,4 @@ func TestStackTrace(t *testing.T) { require.NotNil(t, tracer) tr := tracer.StackTrace() t.Logf("stack trace:%+v", tr) - } diff --git a/libsq/core/errz/stack.go b/libsq/core/errz/stack.go index a9bf6fc21..d93024d31 100644 --- a/libsq/core/errz/stack.go +++ b/libsq/core/errz/stack.go @@ -99,7 +99,10 @@ func (f Frame) MarshalText() ([]byte, error) { // StackTrace is stack of Frames from innermost (newest) to outermost (oldest). type StackTrace struct { - Error error + // Error is the error value that resulted in this stack trace. + Error error + + // Frames is the ordered list of frames that make up this stack trace. 
Frames []Frame } diff --git a/libsq/core/ioz/contextio/contextio.go b/libsq/core/ioz/contextio/contextio.go index a600e396f..82c37d428 100644 --- a/libsq/core/ioz/contextio/contextio.go +++ b/libsq/core/ioz/contextio/contextio.go @@ -20,10 +20,7 @@ package contextio import ( "context" - "errors" "io" - - "github.com/neilotoole/sq/libsq/core/errz" ) var _ io.Writer = (*writer)(nil) @@ -81,9 +78,11 @@ func NewWriter(ctx context.Context, w io.Writer) io.Writer { func (w *writer) Write(p []byte) (n int, err error) { select { case <-w.ctx.Done(): - return 0, w.ctx.Err() + return 0, cause(w.ctx, nil) default: - return w.w.Write(p) + n, err = w.w.Write(p) + err = cause(w.ctx, err) + return n, err } } @@ -96,7 +95,7 @@ func (w *writeCloser) Close() error { select { case <-w.ctx.Done(): - return w.ctx.Err() + return cause(w.ctx, nil) default: return closeErr } @@ -136,9 +135,11 @@ func NewReader(ctx context.Context, r io.Reader) io.Reader { func (r *reader) Read(p []byte) (n int, err error) { select { case <-r.ctx.Done(): - return 0, r.ctx.Err() + return 0, cause(r.ctx, nil) default: - return r.r.Read(p) + n, err = r.r.Read(p) + err = cause(r.ctx, err) + return n, err } } @@ -157,7 +158,7 @@ func (rc *readCloser) Close() error { select { case <-rc.ctx.Done(): - return rc.ctx.Err() + return cause(rc.ctx, nil) default: return closeErr @@ -174,11 +175,13 @@ func (w *copier) ReadFrom(r io.Reader) (n int64, err error) { } select { case <-w.ctx.Done(): - return 0, w.ctx.Err() + return 0, cause(w.ctx, nil) default: // The original Writer is not a ReaderFrom. // Let the Reader decide the chunk size. 
- return io.Copy(&w.writer, r) + n, err = io.Copy(&w.writer, r) + err = cause(w.ctx, err) + return n, err } } @@ -199,15 +202,24 @@ func (c *closer) Close() error { select { case <-c.ctx.Done(): - ctxErr := c.ctx.Err() - switch { - case closeErr == nil, - errz.IsErrContext(closeErr): - return ctxErr - default: - return errors.Join(ctxErr, closeErr) - } + return cause(c.ctx, nil) + default: return closeErr } } + +func cause(ctx context.Context, err error) error { + if err == nil { + return context.Cause(ctx) + } + + // err is non-nil + if ctx.Err() != err { + // err is not the context error, so err takes precedence. + return err + } + + // err is the context error. Return the cause. + return context.Cause(ctx) +} diff --git a/libsq/core/ioz/download/cache.go b/libsq/core/ioz/download/cache.go index 743a13d10..e59ca8297 100644 --- a/libsq/core/ioz/download/cache.go +++ b/libsq/core/ioz/download/cache.go @@ -296,7 +296,7 @@ func (c *cache) doWrite(ctx context.Context, resp *http.Response, headerBytes, err := httputil.DumpResponse(resp, false) if err != nil { - return err + return errz.Err(err) } if _, err = ioz.WriteToFile(ctx, fpHeader, bytes.NewReader(headerBytes)); err != nil { @@ -312,32 +312,28 @@ func (c *cache) doWrite(ctx context.Context, resp *http.Response, return err } - var cr io.Reader - if copyWrtr == nil { - cr = contextio.NewReader(ctx, resp.Body) - } else { - tr := io.TeeReader(resp.Body, copyWrtr) - cr = contextio.NewReader(ctx, tr) + var r io.Reader = resp.Body + if copyWrtr != nil { + // If copyWrtr is non-nil, we're copying the response body + // to that destination as well as to the cache file. 
+ r = io.TeeReader(resp.Body, copyWrtr) } var written int64 - if written, err = io.Copy(cacheFile, cr); err != nil { - err = errz.Err(err) + if written, err = errz.Return(io.Copy(cacheFile, r)); err != nil { log.Error("Cache write: io.Copy failed", lga.Err, err) lg.WarnIfCloseError(log, msgCloseCacheBodyFile, cacheFile) cacheFile = nil return err } - if err = cacheFile.Close(); err != nil { + if err = errz.Err(cacheFile.Close()); err != nil { cacheFile = nil - err = errz.Err(err) return err } if copyWrtr != nil { - if err = copyWrtr.Close(); err != nil { - err = errz.Err(err) + if err = errz.Err(copyWrtr.Close()); err != nil { copyWrtr = nil return err } @@ -352,11 +348,6 @@ func (c *cache) doWrite(ctx context.Context, resp *http.Response, return errz.Wrap(err, "failed to write checksum file for cache body") } - if resp.Body == nil { - resp.Body = http.NoBody - return nil - } - log.Info("Wrote HTTP response body to cache", lga.Size, written, lga.File, fpBody) return nil } diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index f0006e8fb..4bb1511a5 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -302,11 +302,15 @@ func (dl *Download) get(req *http.Request, h Handler) { // do executes the request. func (dl *Download) do(req *http.Request) (*http.Response, error) { resp, err := dl.c.Do(req) - if err == nil && resp.Body != nil { + if err != nil { + return nil, err + } + + if resp.Body != nil && resp.Body != http.NoBody { r := progress.NewReader(req.Context(), dl.name+": download", resp.ContentLength, resp.Body) resp.Body = r.(io.ReadCloser) } - return resp, err + return resp, nil } // mustRequest creates a new request from dl.url. 
The url has already been diff --git a/libsq/core/ioz/httpz/httpz.go b/libsq/core/ioz/httpz/httpz.go index c8858eee5..20e7730b1 100644 --- a/libsq/core/ioz/httpz/httpz.go +++ b/libsq/core/ioz/httpz/httpz.go @@ -51,6 +51,8 @@ func NewClient(opts ...Opt) *http.Client { } c.Transport = tr + c.Transport = RoundTrip(c.Transport, contextCause()) + for i := range opts { if tf, ok := opts[i].(TripFunc); ok { c.Transport = RoundTrip(c.Transport, tf) diff --git a/libsq/core/ioz/httpz/httpz_test.go b/libsq/core/ioz/httpz/httpz_test.go index c1de8d8f7..e1fe9f09c 100644 --- a/libsq/core/ioz/httpz/httpz_test.go +++ b/libsq/core/ioz/httpz/httpz_test.go @@ -76,7 +76,7 @@ func TestOptHeaderTimeout_correct_error(t *testing.T) { t.Log(err) require.Error(t, err) require.Nil(t, resp) - require.Contains(t, err.Error(), "http response not received within") + require.Contains(t, err.Error(), "http response header not received within") require.True(t, errors.Is(err, context.DeadlineExceeded)) // Now let's try again, with a shorter server delay, so the diff --git a/libsq/core/ioz/httpz/opts.go b/libsq/core/ioz/httpz/opts.go index 1e8d37d85..2cf7aafa4 100644 --- a/libsq/core/ioz/httpz/opts.go +++ b/libsq/core/ioz/httpz/opts.go @@ -8,9 +8,11 @@ import ( "time" "github.com/neilotoole/sq/cli/buildinfo" + "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/loz" ) // Opt is an option that can be passed to [NewClient] to @@ -57,20 +59,37 @@ func OptUserAgent(ua string) TripFunc { } } +// contextCause is a TripFunc that extracts the context.Cause error +// from the request context, if any, and returns it as the error. 
+func contextCause() TripFunc { + return func(next http.RoundTripper, req *http.Request) (*http.Response, error) { + resp, err := next.RoundTrip(req) + if err != nil { + if cause := context.Cause(req.Context()); cause != nil { + err = cause + } + } + return resp, err + } +} + // DefaultUserAgent is the default User-Agent header value, // as used by [NewDefaultClient]. var DefaultUserAgent = OptUserAgent(buildinfo.Get().UserAgent()) -// OptRequestTimeout is passed to [NewClient] to set the total request timeout. -// If timeout is zero, this is a no-op. +// OptRequestTimeout is passed to [NewClient] to set the total request timeout, +// including reading the body. This is basically the same as a traditional +// request timeout via context.WithTimeout. If timeout is zero, this is no-op. // // Contrast with [OptHeaderTimeout]. func OptRequestTimeout(timeout time.Duration) TripFunc { if timeout <= 0 { return NopTripFunc } + return func(next http.RoundTripper, req *http.Request) (*http.Response, error) { - timeoutErr := errors.New("http request timeout") + timeoutErr := errz.Wrapf(context.DeadlineExceeded, + "http request not completed within %s timeout", timeout) ctx, cancelFn := context.WithTimeoutCause(req.Context(), timeout, timeoutErr) resp, err := next.RoundTrip(req.WithContext(ctx)) @@ -93,11 +112,14 @@ func OptRequestTimeout(timeout time.Duration) TripFunc { return resp, nil } - // We've got an error + // We've got an error. It may or may not be our timeout error. + // Either which way, we need to cancel the context. defer cancelFn() if errors.Is(context.Cause(ctx), timeoutErr) { - lg.FromContext(ctx).Warn("HTTP request not completed within timeout XYZ", // FIXME: delete + // If it is our timeout error, we log it. 
+ + lg.FromContext(ctx).Warn("HTTP request not completed within timeout", lga.Timeout, timeout, lga.URL, req.URL.String()) } @@ -107,10 +129,8 @@ func OptRequestTimeout(timeout time.Duration) TripFunc { // OptHeaderTimeout is passed to [NewClient] to set a timeout for just // getting the initial response headers. This is useful if you expect -// a response within, say, 5 seconds, but you expect the body to take longer -// to read. If bodyTimeout > 0, it is applied to the total lifecycle of -// the request and response, including reading the response body. -// If timeout <= zero, this is a no-op. +// a response within, say, 2 seconds, but you expect the body to take longer +// to read. // // Contrast with [OptRequestTimeout]. func OptHeaderTimeout(timeout time.Duration) TripFunc { @@ -119,43 +139,54 @@ func OptHeaderTimeout(timeout time.Duration) TripFunc { } return func(next http.RoundTripper, req *http.Request) (*http.Response, error) { timerCancelCh := make(chan struct{}) + ctx, cancelFn := context.WithCancelCause(req.Context()) + t := time.NewTimer(timeout) go func() { - t := time.NewTimer(timeout) defer t.Stop() select { case <-ctx.Done(): case <-t.C: + cancelErr := errz.Wrapf(context.DeadlineExceeded, + "http response header not received within %s timeout", timeout) + lg.FromContext(ctx).Warn("HTTP header response not received within timeout", lga.Timeout, timeout, lga.URL, req.URL.String()) - cancelFn(context.DeadlineExceeded) + + cancelFn(cancelErr) case <-timerCancelCh: // Stop the timer goroutine. } }() - resp, err := next.RoundTrip(req.WithContext(ctx)) - close(timerCancelCh) - if err != nil && errors.Is(err, ctx.Err()) { - // The lower-down RoundTripper probably returned ctx.Err(), - // not context.Cause(), so we swap it around here. 
- if cause := context.Cause(ctx); cause != nil { - err = cause + resp, err := errz.Return(next.RoundTrip(req.WithContext(ctx))) + + if errz.IsErrContext(err) { + if loz.Take(ctx.Done()) { + // The lower-down RoundTripper probably returned ctx.Err(), + // not context.Cause(), so we swap it around here. + if cause := context.Cause(ctx); cause != nil { + err = cause + } } } + + // Signal completion of the timer goroutine (it may have already completed). + close(timerCancelCh) + // Don't leak resources; ensure that cancelFn is eventually called. switch { case err != nil: - // It's probable that cancelFn has already been called by the - // timer goroutine, but we call it again just in case. + // An error has occurred. It's probable that cancelFn has already been + // called by the timer goroutine, but we call it again just in case. cancelFn(context.DeadlineExceeded) case resp != nil && resp.Body != nil: - // Wrap resp.Body with a ReadCloserNotifier, so that cancelFn // is called when the body is closed. - resp.Body = ioz.ReadCloserNotifier(resp.Body, func(error) { cancelFn(context.DeadlineExceeded) }) + resp.Body = ioz.ReadCloserNotifier(resp.Body, + func(error) { cancelFn(context.DeadlineExceeded) }) default: - // Not sure if this can actually happen, but just in case. + // Not sure if this can actually be reached, but just in case. 
cancelFn(context.DeadlineExceeded) } diff --git a/libsq/core/lg/devlog/devlog_test.go b/libsq/core/lg/devlog/devlog_test.go index c8756d08b..0e380e6e0 100644 --- a/libsq/core/lg/devlog/devlog_test.go +++ b/libsq/core/lg/devlog/devlog_test.go @@ -1,23 +1,22 @@ package devlog_test import ( - "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/lg/lga" - "github.com/neilotoole/sq/libsq/core/lg/lgt" "log/slog" "os" "testing" + + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgt" ) func TestDevlog(t *testing.T) { - log := lgt.New(t) - //log.Debug("huzzah") + // log.Debug("huzzah") err := errz.New("oh noes") - //stack := errs.Stacks(err) + // stack := errs.Stacks(err) // lga.Stack, errz.Stacks(err) log.Error("bah", lga.Err, err) - } func TestDevlogTextHandler(t *testing.T) { @@ -27,13 +26,12 @@ func TestDevlogTextHandler(t *testing.T) { h := slog.NewTextHandler(os.Stdout, o) log := slog.New(h) - //log := lgt.New(t) - //log.Debug("huzzah") + // log := lgt.New(t) + // log.Debug("huzzah") err := errz.New("oh noes") - //stack := errs.Stacks(err) + // stack := errs.Stacks(err) // lga.Stack, errz.Stacks(err) log.Error("bah", lga.Err, err) - } func ReplaceAttr(groups []string, a slog.Attr) slog.Attr { diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go index 0c184920c..2f0c08b4d 100644 --- a/libsq/core/lg/devlog/tint/handler.go +++ b/libsq/core/lg/devlog/tint/handler.go @@ -55,8 +55,8 @@ package tint import ( "context" "encoding" + "errors" "fmt" - "github.com/neilotoole/sq/libsq/core/errz" "io" "log/slog" "path/filepath" @@ -66,6 +66,8 @@ import ( "sync" "time" "unicode" + + "github.com/neilotoole/sq/libsq/core/errz" ) // ANSI modes @@ -301,7 +303,7 @@ func (h *handler) handleStackAttrs(buf *buffer, attrs []slog.Attr) { } var count int - for _, stack := range stacks { + for i, stack := range stacks { if 
stack == nil { continue } @@ -318,12 +320,21 @@ func (h *handler) handleStackAttrs(buf *buffer, attrs []slog.Attr) { } if stack.Error != nil { + buf.WriteStringIf(!h.noColor, ansiStackErr) buf.WriteString(stack.Error.Error()) buf.WriteStringIf(!h.noColor, ansiReset) buf.WriteByte(' ') buf.WriteStringIf(!h.noColor, ansiFaint) - buf.WriteStringIf(!h.noColor, fmt.Sprintf("%T", stack.Error)) + // Now we'll print the type of the error. + buf.WriteString(fmt.Sprintf("%T", stack.Error)) + if i == len(stacks)-1 { + // If we're on the final stack, and there's a cause underneath, + // then we print that type too. + if cause := errors.Unwrap(stack.Error); cause != nil { + buf.WriteString(fmt.Sprintf(" -> %T", cause)) + } + } buf.WriteStringIf(!h.noColor, ansiResetFaint) buf.WriteByte('\n') } @@ -339,7 +350,6 @@ func (h *handler) handleStackAttrs(buf *buffer, attrs []slog.Attr) { if count > 0 { buf.WriteByte('\n') } - } func (h *handler) WithAttrs(attrs []slog.Attr) slog.Handler { diff --git a/libsq/core/loz/loz.go b/libsq/core/loz/loz.go index a456d4c45..3109d3738 100644 --- a/libsq/core/loz/loz.go +++ b/libsq/core/loz/loz.go @@ -153,3 +153,27 @@ func ZeroIfNil[T comparable](t *T) T { return *t } + +// Take returns true if ch is non-nil and a value is available +// from ch, or false otherwise. This is useful in for a succinct +// "if done" idiom, e.g.: +// +// if someCondition && loz.Take(doneCh) { +// return +// } +// +// Note that this function does read from the channel, so it's mostly +// intended for use with "done" channels, where the caller is not +// interested in the value sent on the channel, only the fact that +// a value was sent, e.g. by closing the channel. 
+func Take[C any](ch <-chan C) bool { + if ch == nil { + return false + } + select { + case <-ch: + return true + default: + return false + } +} diff --git a/libsq/source/files.go b/libsq/source/files.go index a91b81ded..96897dd05 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -438,7 +438,7 @@ func (fs *Files) openLocation(ctx context.Context, loc string) (*os.File, error) fpath, ok = isFpath(loc) if ok { // we have a legitimate fpath - return errz.Tuple(os.Open(fpath)) + return errz.Return(os.Open(fpath)) } // It's not a local file path, maybe it's remote (http) var u *url.URL From 993a402da707191d19771bc1341e24a0d44ce76e Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 15 Dec 2023 17:21:23 -0700 Subject: [PATCH 127/195] Refining error printing --- cli/cmd_x.go | 6 +- cli/error.go | 6 +- cli/output.go | 8 --- cli/output/jsonw/errorwriter.go | 8 +-- cli/output/jsonw/jsonw_test.go | 3 +- cli/output/printing.go | 20 ++++++- cli/output/tablew/errorwriter.go | 54 ++++++++++++++--- cli/output/writers.go | 6 +- libsq/core/errz/errz.go | 18 ++++++ libsq/core/lg/devlog/tint/handler.go | 88 ++++++++++++++++------------ libsq/core/stringz/stringz.go | 10 ++++ libsq/core/stringz/stringz_test.go | 12 ++++ 12 files changed, 171 insertions(+), 68 deletions(-) diff --git a/cli/cmd_x.go b/cli/cmd_x.go index 94850f36b..f6eddeb5d 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -3,6 +3,8 @@ package cli import ( "bufio" "fmt" + "github.com/neilotoole/sq/cli/flag" + "github.com/neilotoole/sq/libsq/core/errz" "net/url" "os" "time" @@ -150,7 +152,7 @@ func newXDownloadCmd() *cobra.Command { $ sq x download https://sqio-public.s3.amazonaws.com/testdata/payment-large.gen.csv `, } - + cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) return cmd } @@ -186,7 +188,7 @@ func execXDownloadCmd(cmd *cobra.Command, args []string) error { switch { case len(h.Errors) > 0: - return h.Errors[0] + return errz.Wrap(h.Errors[0], "huzzah") case 
len(h.WriteErrors) > 0: return h.WriteErrors[0] case len(h.CachedFiles) > 0: diff --git a/cli/error.go b/cli/error.go index 98a4a461f..4635b6f62 100644 --- a/cli/error.go +++ b/cli/error.go @@ -48,12 +48,12 @@ func printError(ctx context.Context, ru *run.Run, err error) { log.Error("EXECUTION FAILED", lga.Err, err, lga.Cmd, cmdName, lga.Stack, errz.Stacks(err)) - err = humanizeError(err) + humanErr := humanizeError(err) if ru != nil { if wrtrs := ru.Writers; wrtrs != nil && wrtrs.Error != nil { // If we have an errorWriter, we print to it // and return. - wrtrs.Error.Error(err) + wrtrs.Error.Error(err, humanErr) return } @@ -94,7 +94,7 @@ func printError(ctx context.Context, ru *run.Run, err error) { if bootstrapIsFormatJSON(ru) { // The user wants JSON, either via defaults or flags. jw := jsonw.NewErrorWriter(log, errOut, pr) - jw.Error(err) + jw.Error(err, humanErr) return } diff --git a/cli/output.go b/cli/output.go index 25ff7114f..bffe2e3dd 100644 --- a/cli/output.go +++ b/cli/output.go @@ -28,7 +28,6 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/stringz" @@ -473,16 +472,9 @@ func getPrinting(cmd *cobra.Command, clnup *cleanup.Cleanup, opts options.Option pb.Stop() }) - // On first write to stderr, we remove the progress widget. 
- //errOut2 = ioz.NotifyOnceWriter(errOut2, func() { - // lg.FromContext(ctx).Debug("Error stream is being written to; removing progress widget") - // pb.Stop() - //}) // FIXME: delete cmd.SetContext(progress.NewContext(ctx, pb)) } - lg.FromContext(cmd.Context()).Debug("Constructed output.Printing", lga.Val, pr) - return pr, out2, errOut2 } diff --git a/cli/output/jsonw/errorwriter.go b/cli/output/jsonw/errorwriter.go index d8ad1e15e..2b0156f1a 100644 --- a/cli/output/jsonw/errorwriter.go +++ b/cli/output/jsonw/errorwriter.go @@ -22,16 +22,16 @@ func NewErrorWriter(log *slog.Logger, out io.Writer, pr *output.Printing) output } // Error implements output.ErrorWriter. -func (w *errorWriter) Error(err error) { +func (w *errorWriter) Error(systemErr error, humanErr error) { var errMsg string var stack []string - if err == nil { + if systemErr == nil { errMsg = "nil error" } else { - errMsg = err.Error() + errMsg = systemErr.Error() if w.pr.Verbose { - for _, st := range errz.Stacks(err) { + for _, st := range errz.Stacks(systemErr) { s := fmt.Sprintf("%+v", st) stack = append(stack, s) } diff --git a/cli/output/jsonw/jsonw_test.go b/cli/output/jsonw/jsonw_test.go index d93364365..2f1bc0a7a 100644 --- a/cli/output/jsonw/jsonw_test.go +++ b/cli/output/jsonw/jsonw_test.go @@ -215,7 +215,8 @@ func TestErrorWriter(t *testing.T) { pr.EnableColor(tc.color) errw := jsonw.NewErrorWriter(lgt.New(t), buf, pr) - errw.Error(errz.New("err1")) + e := errz.New("err1") + errw.Error(e, e) got := buf.String() require.Equal(t, tc.want, got) diff --git a/cli/output/printing.go b/cli/output/printing.go index 142b51e8f..709173df8 100644 --- a/cli/output/printing.go +++ b/cli/output/printing.go @@ -159,6 +159,15 @@ type Printing struct { // Success is the color for success elements. Success *color.Color + // Stack is the color for stack traces. + Stack *color.Color + + // StackError is the color for errors attached to a stack trace. 
+ StackError *color.Color + + // StackErrorType is the color for the error types attached to a stack trace. + StackErrorType *color.Color + // Warning is the color for warning elements. Warning *color.Color } @@ -205,6 +214,9 @@ func NewPrinting() *Printing { Number: color.New(color.FgCyan), Punc: color.New(color.Bold), String: color.New(color.FgGreen), + Stack: color.New(color.Faint), + StackError: color.New(color.FgYellow, color.Faint), + StackErrorType: color.New(color.FgGreen, color.Faint), Success: color.New(color.FgGreen, color.Bold), Warning: color.New(color.FgYellow), } @@ -257,6 +269,9 @@ func (pr *Printing) Clone() *Printing { pr2.Punc = lo.ToPtr(*pr.Punc) pr2.String = lo.ToPtr(*pr.String) pr2.Success = lo.ToPtr(*pr.Success) + pr2.Stack = lo.ToPtr(*pr.Stack) + pr2.StackError = lo.ToPtr(*pr.StackError) + pr2.StackErrorType = lo.ToPtr(*pr.StackErrorType) pr2.Warning = lo.ToPtr(*pr.Warning) return pr2 @@ -289,8 +304,9 @@ func (pr *Printing) colors() []*color.Color { pr.DiffHeader, pr.DiffMinus, pr.DiffPlus, pr.DiffNormal, pr.DiffSection, pr.Disabled, pr.Enabled, pr.Error, pr.Faint, pr.Handle, pr.Header, pr.Hilite, - pr.Key, pr.Location, pr.Normal, pr.Null, pr.Number, - pr.Punc, pr.String, pr.Success, pr.Warning, + pr.Key, pr.Location, pr.Normal, pr.Null, pr.Number, pr.Punc, + pr.Stack, pr.StackError, pr.StackErrorType, + pr.String, pr.Success, pr.Warning, } } diff --git a/cli/output/tablew/errorwriter.go b/cli/output/tablew/errorwriter.go index 523060bd7..1c31ede15 100644 --- a/cli/output/tablew/errorwriter.go +++ b/cli/output/tablew/errorwriter.go @@ -1,7 +1,9 @@ package tablew import ( + "bytes" "fmt" + "github.com/neilotoole/sq/libsq/core/stringz" "io" "strings" @@ -22,20 +24,54 @@ func NewErrorWriter(w io.Writer, pr *output.Printing) output.ErrorWriter { } // Error implements output.ErrorWriter. 
-func (w *errorWriter) Error(err error) { - fmt.Fprintln(w.w, w.pr.Error.Sprintf("sq: %v", err)) +func (w *errorWriter) Error(systemErr error, humanErr error) { + fmt.Fprintln(w.w, w.pr.Error.Sprintf("sq: %v", humanErr)) if !w.pr.Verbose { return } - stacks := errz.Stacks(err) - for i, stack := range stacks { - if i > 0 { - fmt.Fprintln(w.w) + stacks := errz.Stacks(systemErr) + if len(stacks) == 0 { + return + } + + var buf = &bytes.Buffer{} + var count int + for _, stack := range stacks { + if stack == nil { + continue } - s := fmt.Sprintf("%+v", stack) - s = strings.TrimSpace(s) - w.pr.Faint.Fprintln(w.w, s) + stackPrint := fmt.Sprintf("%+v", stack) + stackPrint = strings.ReplaceAll(strings.TrimSpace(stackPrint), "\n\t", "\n ") + if stackPrint == "" { + continue + } + + if count > 0 { + buf.WriteString("\n") + } + + if stack.Error != nil { + errTypes := stringz.TypeNames(errz.Tree(stack.Error)...) + for i, typ := range errTypes { + w.pr.StackErrorType.Fprint(buf, typ) + if i < len(errTypes)-1 { + w.pr.Faint.Fprint(buf, ":") + buf.WriteByte(' ') + } + } + buf.WriteByte('\n') + w.pr.StackError.Fprintln(buf, stack.Error.Error()) + } + + lines := strings.Split(stackPrint, "\n") + for _, line := range lines { + w.pr.Stack.Fprint(buf, line) + buf.WriteByte('\n') + } + count++ } + + buf.WriteTo(w.w) } diff --git a/cli/output/writers.go b/cli/output/writers.go index 7c715f508..138b5ea6e 100644 --- a/cli/output/writers.go +++ b/cli/output/writers.go @@ -97,8 +97,10 @@ type SourceWriter interface { // ErrorWriter outputs errors. type ErrorWriter interface { - // Error outputs err. - Error(err error) + // Error outputs error conditions. It's possible that systemErr and + // humanErr differ; systemErr is the error that occurred, and humanErr + // is the error that should be presented to the user. + Error(systemErr error, humanErr error) } // PingWriter writes ping results. 
diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go index 77b76f568..c37f6cf90 100644 --- a/libsq/core/errz/errz.go +++ b/libsq/core/errz/errz.go @@ -112,3 +112,21 @@ func IsType[E error](err error) bool { var target E return errors.As(err, &target) } + +// Tree returns a slice of all the errors in err's tree. +func Tree(err error) []error { + if err == nil { + return nil + } + + var errs []error + for { + errs = append(errs, err) + err = errors.Unwrap(err) + if err == nil { + break + } + } + + return errs +} diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go index 2f0c08b4d..8976b5601 100644 --- a/libsq/core/lg/devlog/tint/handler.go +++ b/libsq/core/lg/devlog/tint/handler.go @@ -55,8 +55,8 @@ package tint import ( "context" "encoding" - "errors" "fmt" + "github.com/neilotoole/sq/libsq/core/stringz" "io" "log/slog" "path/filepath" @@ -73,21 +73,23 @@ import ( // ANSI modes // See: https://gist.github.com/JBlond/2fea43a3049b38287e5e9cefc87b2124 const ( - ansiAttr = "\033[36;2m" - ansiBlue = "\033[34m" - ansiBrightBlue = "\033[94m" - ansiBrightGreen = "\033[92m" - ansiBrightGreenBold = "\033[1;92m" - ansiBrightRed = "\033[91m" - ansiBrightRedBold = "\033[1;91m" - ansiBrightRedFaint = "\033[91;2m" - ansiBrightYellow = "\033[93m" - ansiFaint = "\033[2m" - ansiReset = "\033[0m" - ansiResetFaint = "\033[22m" - ansiStack = "\033[0;35m" - ansiYellowBold = "\033[1;33m" - ansiStackErr = ansiYellowBold + ansiAttr = "\033[36;2m" + ansiBlue = "\033[34m" + ansiBrightBlue = "\033[94m" + ansiBrightGreen = "\033[92m" + ansiBrightGreenBold = "\033[1;92m" + ansiBrightGreenFaint = "\033[92;2m" + ansiBrightRed = "\033[91m" + ansiBrightRedBold = "\033[1;91m" + ansiBrightRedFaint = "\033[91;2m" + ansiBrightYellow = "\033[93m" + ansiFaint = "\033[2m" + ansiReset = "\033[0m" + ansiResetFaint = "\033[22m" + ansiStack = "\033[0;35m" + ansiYellowBold = "\033[1;33m" + ansiStackErr = ansiYellowBold + ansiStackErrType = ansiBrightGreenFaint ) 
const errKey = "err" @@ -302,52 +304,64 @@ func (h *handler) handleStackAttrs(buf *buffer, attrs []slog.Attr) { } } - var count int - for i, stack := range stacks { + var printed int + for _, stack := range stacks { if stack == nil { continue } - v := fmt.Sprintf("%+v", stack) - v = strings.TrimSpace(v) - v = strings.ReplaceAll(v, "\n\t", "\n ") - if v == "" { + stackPrint := fmt.Sprintf("%+v", stack) + stackPrint = strings.ReplaceAll(strings.TrimSpace(stackPrint), "\n\t", "\n ") + if stackPrint == "" { continue } - if count > 0 { + if printed > 0 { buf.WriteString("\n") } if stack.Error != nil { + errTypes := stringz.TypeNames(errz.Tree(stack.Error)...) + for j, typ := range errTypes { + buf.WriteStringIf(!h.noColor, ansiStackErrType) + buf.WriteString(typ) + buf.WriteStringIf(!h.noColor, ansiReset) + if j < len(errTypes)-1 { + buf.WriteStringIf(!h.noColor, ansiFaint) + buf.WriteByte(':') + buf.WriteStringIf(!h.noColor, ansiResetFaint) + buf.WriteByte(' ') + } + } + buf.WriteByte('\n') buf.WriteStringIf(!h.noColor, ansiStackErr) buf.WriteString(stack.Error.Error()) buf.WriteStringIf(!h.noColor, ansiReset) buf.WriteByte(' ') - buf.WriteStringIf(!h.noColor, ansiFaint) - // Now we'll print the type of the error. - buf.WriteString(fmt.Sprintf("%T", stack.Error)) - if i == len(stacks)-1 { - // If we're on the final stack, and there's a cause underneath, - // then we print that type too. - if cause := errors.Unwrap(stack.Error); cause != nil { - buf.WriteString(fmt.Sprintf(" -> %T", cause)) - } - } - buf.WriteStringIf(!h.noColor, ansiResetFaint) + //buf.WriteStringIf(!h.noColor, ansiFaint) + //// Now we'll print the type of the error. + //buf.WriteString(fmt.Sprintf("%T", stack.Error)) + //if i == len(stacks)-1 { + // // If we're on the final stack, and there's a cause underneath, + // // then we print that type too. 
+ // if cause := errors.Unwrap(stack.Error); cause != nil { + // buf.WriteString(fmt.Sprintf(" -> %T", cause)) + // } + //} + //buf.WriteStringIf(!h.noColor, ansiResetFaint) buf.WriteByte('\n') } - lines := strings.Split(v, "\n") + lines := strings.Split(stackPrint, "\n") for _, line := range lines { buf.WriteStringIf(!h.noColor, ansiStack) buf.WriteString(line) buf.WriteStringIf(!h.noColor, ansiReset) buf.WriteByte('\n') } - count++ + printed++ } - if count > 0 { + if printed > 0 { buf.WriteByte('\n') } } diff --git a/libsq/core/stringz/stringz.go b/libsq/core/stringz/stringz.go index 2db3ad4ab..19954cd2c 100644 --- a/libsq/core/stringz/stringz.go +++ b/libsq/core/stringz/stringz.go @@ -792,3 +792,13 @@ func SanitizeFilename(name string) string { return name } } + +// TypeNames returns the go type of each element of a, as +// rendered by fmt "%T". +func TypeNames[T any](a ...T) []string { + types := make([]string, len(a)) + for i := range a { + types[i] = fmt.Sprintf("%T", a[i]) + } + return types +} diff --git a/libsq/core/stringz/stringz_test.go b/libsq/core/stringz/stringz_test.go index 208c79960..4d5e617c6 100644 --- a/libsq/core/stringz/stringz_test.go +++ b/libsq/core/stringz/stringz_test.go @@ -1,6 +1,8 @@ package stringz_test import ( + "errors" + "github.com/neilotoole/sq/libsq/core/errz" "strconv" "strings" "testing" @@ -684,3 +686,13 @@ func TestSanitizeFilename(t *testing.T) { }) } } + +func TestTypeNames(t *testing.T) { + errs := []error{errors.New("stdlib"), errz.New("errz")} + names := stringz.TypeNames(errs...) + require.Equal(t, []string{"*errors.errorString", "*errz.fundamental"}, names) + + a := []any{1, "hello", true, errs} + names = stringz.TypeNames(a...) 
+ require.Equal(t, []string{"int", "string", "bool", "[]error"}, names) +} From 7decb56d42779bd349b29c9ccb8892740e87e80f Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 15 Dec 2023 17:37:13 -0700 Subject: [PATCH 128/195] Refining error printing --- cli/output/tablew/errorwriter.go | 2 +- libsq/core/errz/stack.go | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/cli/output/tablew/errorwriter.go b/cli/output/tablew/errorwriter.go index 1c31ede15..3b03e4b84 100644 --- a/cli/output/tablew/errorwriter.go +++ b/cli/output/tablew/errorwriter.go @@ -43,7 +43,7 @@ func (w *errorWriter) Error(systemErr error, humanErr error) { } stackPrint := fmt.Sprintf("%+v", stack) - stackPrint = strings.ReplaceAll(strings.TrimSpace(stackPrint), "\n\t", "\n ") + //stackPrint = strings.ReplaceAll(strings.TrimSpace(stackPrint), "\n\t", "\n ") if stackPrint == "" { continue } diff --git a/libsq/core/errz/stack.go b/libsq/core/errz/stack.go index d93024d31..3241a6d40 100644 --- a/libsq/core/errz/stack.go +++ b/libsq/core/errz/stack.go @@ -63,7 +63,7 @@ func (f Frame) name() string { // Format accepts flags that alter the printing of some verbs, as follows: // // %+s function name and path of source file relative to the compile time -// GOPATH separated by \n\t (\n\t) +// GOPATH separated by \nSPSP (\nSPSP) // %+v equivalent to %+s:%d func (f Frame) Format(s fmt.State, verb rune) { switch verb { @@ -71,7 +71,7 @@ func (f Frame) Format(s fmt.State, verb rune) { switch { case s.Flag('+'): _, _ = io.WriteString(s, f.name()) - _, _ = io.WriteString(s, "\n\t") + _, _ = io.WriteString(s, "\n ") _, _ = io.WriteString(s, f.file()) default: _, _ = io.WriteString(s, path.Base(f.file())) @@ -119,8 +119,10 @@ func (st *StackTrace) Format(s fmt.State, verb rune) { case 'v': switch { case s.Flag('+'): - for _, f := range st.Frames { - _, _ = io.WriteString(s, "\n") + for i, f := range st.Frames { + if i != 0 { + _, _ = io.WriteString(s, "\n") + } f.Format(s, verb) } case 
s.Flag('#'): From 073c4e7f2b75edff6c4cac47fac38c360d24bc31 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 16 Dec 2023 03:33:35 -0700 Subject: [PATCH 129/195] Refining error printing --- cli/error.go | 14 ------ cli/output/jsonw/errorwriter.go | 68 +++++++++++++++++++---------- libsq/core/ioz/download/download.go | 13 ++++++ 3 files changed, 59 insertions(+), 36 deletions(-) diff --git a/cli/error.go b/cli/error.go index 4635b6f62..44b2f1d90 100644 --- a/cli/error.go +++ b/cli/error.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/spf13/cobra" "github.com/spf13/pflag" - "net/url" "os" "strings" @@ -158,19 +157,6 @@ func humanizeError(err error) error { return nil } - // Download timeout errors are typically wrapped in an url.Error, resulting - // in a message like: - // - // Get "https://example.com": http response header not received within 1ms timeout - // - // We want to trim off that prefix, but we only do that if there's a wrapped - // error beneath (which should be the case). - if errz.IsType[*url.Error](err) && errors.Is(err, context.DeadlineExceeded) { - if e := errors.Unwrap(err); e != nil { - err = e - } - } - switch { // Friendlier messages for context errors. 
default: diff --git a/cli/output/jsonw/errorwriter.go b/cli/output/jsonw/errorwriter.go index 2b0156f1a..f416ec651 100644 --- a/cli/output/jsonw/errorwriter.go +++ b/cli/output/jsonw/errorwriter.go @@ -2,6 +2,7 @@ package jsonw import ( "fmt" + "github.com/neilotoole/sq/libsq/core/stringz" "io" "log/slog" @@ -21,33 +22,56 @@ func NewErrorWriter(log *slog.Logger, out io.Writer, pr *output.Printing) output return &errorWriter{log: log, out: out, pr: pr} } +type errorDetail struct { + Error string `json:"error,"` + BaseError string `json:"base_error,omitempty"` + Stack []*stack `json:"stack,omitempty"` +} + +type stackError struct { + Message string `json:"msg"` + Tree []string `json:"tree,omitempty"` +} + +type stack struct { + Error *stackError `json:"error,omitempty"` + Trace string `json:"trace,omitempty"` +} + // Error implements output.ErrorWriter. func (w *errorWriter) Error(systemErr error, humanErr error) { - var errMsg string - var stack []string - - if systemErr == nil { - errMsg = "nil error" - } else { - errMsg = systemErr.Error() - if w.pr.Verbose { - for _, st := range errz.Stacks(systemErr) { - s := fmt.Sprintf("%+v", st) - stack = append(stack, s) - } - } + pr := w.pr.Clone() + pr.String = pr.Warning + //pr.Key = pr.Warning + + if !w.pr.Verbose { + ed := errorDetail{Error: humanErr.Error()} + _ = writeJSON(w.out, pr, ed) + return } - t := struct { - Error string `json:"error"` - Stack []string `json:"stack,omitempty"` - }{ - Error: errMsg, - Stack: stack, + ed := errorDetail{ + Error: humanErr.Error(), + BaseError: systemErr.Error(), } - pr := w.pr.Clone() - pr.String = pr.Error + stacks := errz.Stacks(systemErr) + if len(stacks) > 0 { + for _, sysStack := range stacks { + if sysStack == nil { + continue + } + + st := &stack{ + Trace: fmt.Sprintf("%+v", sysStack), + Error: &stackError{ + Message: sysStack.Error.Error(), + Tree: stringz.TypeNames(errz.Tree(sysStack.Error)...), + }} + + ed.Stack = append(ed.Stack, st) + } + } - _ = writeJSON(w.out, 
pr, t) + _ = writeJSON(w.out, pr, ed) } diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index 4bb1511a5..c31037faa 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -11,6 +11,7 @@ package download import ( "bufio" "context" + "errors" "io" "net/http" "net/url" @@ -303,6 +304,18 @@ func (dl *Download) get(req *http.Request, h Handler) { func (dl *Download) do(req *http.Request) (*http.Response, error) { resp, err := dl.c.Do(req) if err != nil { + // Download timeout errors are typically wrapped in an url.Error, resulting + // in a message like: + // + // Get "https://example.com": http response header not received within 1ms timeout + // + // We want to trim off that `GET "URL"` prefix, but we only do that if + // there's a wrapped error beneath (which should be the case). + if errz.IsType[*url.Error](err) && errors.Is(err, context.DeadlineExceeded) { + if e := errors.Unwrap(err); e != nil { + err = e + } + } return nil, err } From db7691dc35a0bfb6f92b65caf916945637c0e1ae Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 16 Dec 2023 04:01:50 -0700 Subject: [PATCH 130/195] errz: no more errz.withMessage --- libsq/core/errz/errors.go | 114 ++++++++++++++++---------------------- libsq/core/errz/errz.go | 1 + 2 files changed, 49 insertions(+), 66 deletions(-) diff --git a/libsq/core/errz/errors.go b/libsq/core/errz/errors.go index 08941d631..3592b8576 100644 --- a/libsq/core/errz/errors.go +++ b/libsq/core/errz/errors.go @@ -163,6 +163,7 @@ func (f *fundamental) LogValue() slog.Value { type withStack struct { error + msg *string *stack } @@ -181,6 +182,13 @@ func (w *withStack) StackTrace() *StackTrace { return st } +func (w *withStack) Error() string { + if w.msg == nil { + return w.error.Error() + } + return *w.msg + ": " + w.error.Error() +} + func (w *withStack) Cause() error { return w.error } // Unwrap provides compatibility for Go 1.13 error chains. 
@@ -214,12 +222,10 @@ func Wrap(err error, message string) error { if err == nil { return nil } - err = &withMessage{ - cause: err, - msg: message, - } + return &withStack{ err, + &message, callers(), } } @@ -231,71 +237,50 @@ func Wrapf(err error, format string, args ...any) error { if err == nil { return nil } - err = &withMessage{ - cause: err, - msg: fmt.Sprintf(format, args...), - } + //err = &withMessage{ + // cause: err, + // msg: fmt.Sprintf(format, args...), + //} + msg := fmt.Sprintf(format, args...) return &withStack{ err, + &msg, callers(), } } -// WithMessage annotates err with a new message. -// If err is nil, WithMessage returns nil. -func WithMessage(err error, message string) error { - if err == nil { - return nil - } - return &withMessage{ - cause: err, - msg: message, - } -} - -// WithMessagef annotates err with the format specifier. -// If err is nil, WithMessagef returns nil. -func WithMessagef(err error, format string, args ...any) error { - if err == nil { - return nil - } - return &withMessage{ - cause: err, - msg: fmt.Sprintf(format, args...), - } -} - -type withMessage struct { //nolint:errname - cause error - msg string -} - -func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() } -func (w *withMessage) Cause() error { return w.cause } - -// LogValue implements slog.LogValuer. -func (w *withMessage) LogValue() slog.Value { - return logValue(w) -} - -// Unwrap provides compatibility for Go 1.13 error chains. 
-func (w *withMessage) Unwrap() error { return w.cause } - -func (w *withMessage) Format(s fmt.State, verb rune) { - switch verb { - case 'v': - if s.Flag('+') { - _, _ = fmt.Fprintf(s, "%+v\n", w.Cause()) - _, _ = io.WriteString(s, w.msg) - return - } - fallthrough - case 's', 'q': - _, _ = io.WriteString(s, w.Error()) - } -} +// +//type withMessage struct { //nolint:errname +// cause error +// msg string +//} +// +//func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() } +//func (w *withMessage) Cause() error { return w.cause } +// +//// LogValue implements slog.LogValuer. +//func (w *withMessage) LogValue() slog.Value { +// return logValue(w) +//} +// +//// Unwrap provides compatibility for Go 1.13 error chains. +//func (w *withMessage) Unwrap() error { return w.cause } +// +//func (w *withMessage) Format(s fmt.State, verb rune) { +// switch verb { +// case 'v': +// if s.Flag('+') { +// _, _ = fmt.Fprintf(s, "%+v\n", w.Cause()) +// _, _ = io.WriteString(s, w.msg) +// return +// } +// fallthrough +// case 's', 'q': +// _, _ = io.WriteString(s, w.Error()) +// } +//} -// Cause returns the underlying cause of the error, if possible. +// Cause returns the underlying *root* cause of the error, if possible. // An error value has a cause if it implements the following // interface: // @@ -304,10 +289,7 @@ func (w *withMessage) Format(s fmt.State, verb rune) { // } // // If the error does not implement Cause, the original error will -// be returned. If the error is nil, nil will be returned without further -// investigation. -// -// Deprecated: Use errors.Unwrap or errors.As. +// be returned. Nil is returned if err is nil. 
func Cause(err error) error { type causer interface { Cause() error diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go index c37f6cf90..93708ed77 100644 --- a/libsq/core/errz/errz.go +++ b/libsq/core/errz/errz.go @@ -25,6 +25,7 @@ func Err(err error) error { } return &withStack{ err, + nil, callers(), } } From 36f673aa644378f9f262954a7fc3308ee435be7c Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 16 Dec 2023 05:56:23 -0700 Subject: [PATCH 131/195] errz: largely refactored --- cli/output/jsonw/errorwriter.go | 6 +- cli/output/tablew/errorwriter.go | 4 +- libsq/core/errz/errors.go | 218 +++++++++++++++++---------- libsq/core/errz/errz.go | 77 ++++------ libsq/core/errz/errz_test.go | 46 +++--- libsq/core/errz/errz_types.go | 5 + libsq/core/errz/stack.go | 9 +- libsq/core/ioz/download/download.go | 2 +- libsq/core/ioz/lockfile/lockfile.go | 2 +- libsq/core/lg/devlog/tint/handler.go | 2 +- 10 files changed, 210 insertions(+), 161 deletions(-) diff --git a/cli/output/jsonw/errorwriter.go b/cli/output/jsonw/errorwriter.go index f416ec651..6fb7a3386 100644 --- a/cli/output/jsonw/errorwriter.go +++ b/cli/output/jsonw/errorwriter.go @@ -5,6 +5,7 @@ import ( "github.com/neilotoole/sq/libsq/core/stringz" "io" "log/slog" + "strings" "github.com/neilotoole/sq/cli/output" "github.com/neilotoole/sq/libsq/core/errz" @@ -42,7 +43,6 @@ type stack struct { func (w *errorWriter) Error(systemErr error, humanErr error) { pr := w.pr.Clone() pr.String = pr.Warning - //pr.Key = pr.Warning if !w.pr.Verbose { ed := errorDetail{Error: humanErr.Error()} @@ -63,10 +63,10 @@ func (w *errorWriter) Error(systemErr error, humanErr error) { } st := &stack{ - Trace: fmt.Sprintf("%+v", sysStack), + Trace: strings.ReplaceAll(fmt.Sprintf("%+v", sysStack), "\n\t", "\n "), Error: &stackError{ Message: sysStack.Error.Error(), - Tree: stringz.TypeNames(errz.Tree(sysStack.Error)...), + Tree: stringz.TypeNames(errz.Chain(sysStack.Error)...), }} ed.Stack = append(ed.Stack, st) diff 
--git a/cli/output/tablew/errorwriter.go b/cli/output/tablew/errorwriter.go index 3b03e4b84..667036c6a 100644 --- a/cli/output/tablew/errorwriter.go +++ b/cli/output/tablew/errorwriter.go @@ -43,7 +43,7 @@ func (w *errorWriter) Error(systemErr error, humanErr error) { } stackPrint := fmt.Sprintf("%+v", stack) - //stackPrint = strings.ReplaceAll(strings.TrimSpace(stackPrint), "\n\t", "\n ") + stackPrint = strings.ReplaceAll(strings.TrimSpace(stackPrint), "\n\t", "\n ") if stackPrint == "" { continue } @@ -53,7 +53,7 @@ func (w *errorWriter) Error(systemErr error, humanErr error) { } if stack.Error != nil { - errTypes := stringz.TypeNames(errz.Tree(stack.Error)...) + errTypes := stringz.TypeNames(errz.Chain(stack.Error)...) for i, typ := range errTypes { w.pr.StackErrorType.Fprint(buf, typ) if i < len(errTypes)-1 { diff --git a/libsq/core/errz/errors.go b/libsq/core/errz/errors.go index 3592b8576..ef76af37f 100644 --- a/libsq/core/errz/errors.go +++ b/libsq/core/errz/errors.go @@ -1,5 +1,7 @@ // Package errz provides simple error handling primitives. // +// FIXME: update docs +// // The traditional error handling idiom in Go is roughly akin to // // if err != nil { @@ -35,14 +37,14 @@ // for inspection. Any error value which implements this interface // // type causer interface { -// Cause() error +// UnwrapFully() error // } // -// can be inspected by errors.Cause. errors.Cause will recursively retrieve +// can be inspected by errors.UnwrapFully. errors.UnwrapFully will recursively retrieve // the topmost error that does not implement causer, which is assumed to be // the original cause. For example: // -// switch err := errors.Cause(err).(type) { +// switch err := errors.UnwrapFully(err).(type) { // case *MyError: // // handle specifically // default: @@ -57,7 +59,7 @@ // All error values returned from this package implement fmt.Formatter and can // be formatted by the fmt package. The following verbs are supported: // -// %s print the error. 
If the error has a Cause it will be +// %s print the error. If the error has a UnwrapFully it will be // printed recursively. // %v see %s // %+v extended format. Each Frame of the error's StackTrace will @@ -93,6 +95,7 @@ package errz import ( + "errors" "fmt" "io" "log/slog" @@ -101,7 +104,7 @@ import ( // New returns an error with the supplied message. // New also records the stack trace at the point it was called. func New(message string) error { - return &fundamental{ + return &withStack{ msg: message, stack: callers(), } @@ -111,65 +114,68 @@ func New(message string) error { // as a value that satisfies error. // Errorf also records the stack trace at the point it was called. func Errorf(format string, args ...any) error { - return &fundamental{ + return &withStack{ msg: fmt.Sprintf(format, args...), stack: callers(), } } -// fundamental is an error that has a message and a stack, but no caller. -type fundamental struct { - msg string - *stack -} - -func (f *fundamental) Error() string { return f.msg } - -func (f *fundamental) Format(s fmt.State, verb rune) { - switch verb { - case 'v': - if s.Flag('+') { - _, _ = io.WriteString(s, f.msg) - f.stack.Format(s, verb) - return - } - fallthrough - case 's': - _, _ = io.WriteString(s, f.msg) - case 'q': - _, _ = fmt.Fprintf(s, "{%s}", f.msg) - } -} - -var _ StackTracer = (*fundamental)(nil) - -// StackTrace implements StackTracer. -func (f *fundamental) StackTrace() *StackTrace { - if f == nil || f.stack == nil { - return nil - } - - st := f.stack.stackTrace() - if st != nil { - st.Error = f - } - return st -} - -// LogValue implements slog.LogValuer. -func (f *fundamental) LogValue() slog.Value { - return logValue(f) -} +// +//// fundamental is an error that has a message and a stack, but no caller. 
+//type fundamental struct { +// msg string +// *stack +//} +// +//func (f *fundamental) Error() string { return f.msg } +// +//func (f *fundamental) Format(s fmt.State, verb rune) { +// switch verb { +// case 'v': +// if s.Flag('+') { +// _, _ = io.WriteString(s, f.msg) +// f.stack.Format(s, verb) +// return +// } +// fallthrough +// case 's': +// _, _ = io.WriteString(s, f.msg) +// case 'q': +// _, _ = fmt.Fprintf(s, "{%s}", f.msg) +// } +//} +// +//var _ StackTracer = (*fundamental)(nil) +// +//// StackTrace implements StackTracer. +//func (f *fundamental) StackTrace() *StackTrace { +// if f == nil || f.stack == nil { +// return nil +// } +// +// st := f.stack.stackTrace() +// if st != nil { +// st.Error = f +// } +// return st +//} +// +//// LogValue implements slog.LogValuer. +//func (f *fundamental) LogValue() slog.Value { +// return logValue(f) +//} type withStack struct { error - msg *string + msg string *stack } var _ StackTracer = (*withStack)(nil) // StackTrace implements StackTracer. +// REVISIT: consider making StackTrace private, or removing +// it in favor of the Stack function. func (w *withStack) StackTrace() *StackTrace { if w == nil || w.stack == nil { return nil @@ -182,39 +188,97 @@ func (w *withStack) StackTrace() *StackTrace { return st } +// Error implements stdlib error interface. func (w *withStack) Error() string { - if w.msg == nil { + if w.msg == "" { + if w.error == nil { + return "" + } return w.error.Error() } - return *w.msg + ": " + w.error.Error() + if w.error == nil { + return w.msg + } + return w.msg + ": " + w.error.Error() } +// LogValue implements slog.LogValuer. +func (w *withStack) LogValue() slog.Value { + if w == nil { + return slog.Value{} + } + + attrs := make([]slog.Attr, 2, 4) + attrs[0] = slog.String("msg", w.Error()) + attrs[1] = slog.String("type", fmt.Sprintf("%T", w)) + + // If there's a wrapped error, "cause" and "type" will be + // for that wrapped error. 
+ if w.error != nil { + attrs[1] = slog.String("cause", w.error.Error()) + attrs = append(attrs, slog.String("type", fmt.Sprintf("%T", w.error))) + } else { + // If there's no wrapped error, "type" will be the type of w. + attrs[1] = slog.String("type", fmt.Sprintf("%T", w)) + } + + return slog.GroupValue(attrs...) +} + +// UnwrapFully returns the underlying *root* cause of the error, if possible. +// +// Deprecated: get rid of UnwrapFully in favor of errors.Unwrap. func (w *withStack) Cause() error { return w.error } // Unwrap provides compatibility for Go 1.13 error chains. func (w *withStack) Unwrap() error { return w.error } +//func (f *fundamental) Format(s fmt.State, verb rune) { +// switch verb { +// case 'v': +// if s.Flag('+') { +// _, _ = io.WriteString(s, f.msg) +// f.stack.Format(s, verb) +// return +// } +// fallthrough +// case 's': +// _, _ = io.WriteString(s, f.msg) +// case 'q': +// _, _ = fmt.Fprintf(s, "{%s}", f.msg) +// } +//} + func (w *withStack) Format(s fmt.State, verb rune) { switch verb { case 'v': if s.Flag('+') { - _, _ = fmt.Fprintf(s, "%+v", w.Cause()) - w.stack.Format(s, verb) + if w.error == nil { + _, _ = io.WriteString(s, w.msg) + w.stack.Format(s, verb) + return + } else { + _, _ = fmt.Fprintf(s, "%+v", w.Cause()) + w.stack.Format(s, verb) + } return } fallthrough case 's': + if w.error == nil { + _, _ = io.WriteString(s, w.msg) + return + } _, _ = io.WriteString(s, w.Error()) case 'q': + if w.error == nil { + _, _ = fmt.Fprintf(s, "{%s}", w.msg) + return + } _, _ = fmt.Fprintf(s, "{%s}", w.Error()) } } -// LogValue implements slog.LogValuer. -func (w *withStack) LogValue() slog.Value { - return logValue(w) -} - // Wrap returns an error annotating err with a stack trace // at the point Wrap is called, and the supplied message. // If err is nil, Wrap returns nil. 
@@ -225,7 +289,7 @@ func Wrap(err error, message string) error { return &withStack{ err, - &message, + message, callers(), } } @@ -241,10 +305,9 @@ func Wrapf(err error, format string, args ...any) error { // cause: err, // msg: fmt.Sprintf(format, args...), //} - msg := fmt.Sprintf(format, args...) return &withStack{ err, - &msg, + fmt.Sprintf(format, args...), callers(), } } @@ -256,7 +319,7 @@ func Wrapf(err error, format string, args ...any) error { //} // //func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() } -//func (w *withMessage) Cause() error { return w.cause } +//func (w *withMessage) UnwrapFully() error { return w.cause } // //// LogValue implements slog.LogValuer. //func (w *withMessage) LogValue() slog.Value { @@ -270,7 +333,7 @@ func Wrapf(err error, format string, args ...any) error { // switch verb { // case 'v': // if s.Flag('+') { -// _, _ = fmt.Fprintf(s, "%+v\n", w.Cause()) +// _, _ = fmt.Fprintf(s, "%+v\n", w.UnwrapFully()) // _, _ = io.WriteString(s, w.msg) // return // } @@ -280,27 +343,20 @@ func Wrapf(err error, format string, args ...any) error { // } //} -// Cause returns the underlying *root* cause of the error, if possible. -// An error value has a cause if it implements the following -// interface: -// -// type causer interface { -// Cause() error -// } -// -// If the error does not implement Cause, the original error will -// be returned. Nil is returned if err is nil. -func Cause(err error) error { - type causer interface { - Cause() error +// UnwrapFully returns the underlying *root* cause of the error. That is +// to say, UnwrapFully returns the final error in the error chain. +// UnwrapFully returns nil if err is nil, but otherwise will not return nil. 
+func UnwrapFully(err error) error { + if err == nil { + return nil } - for err != nil { - cause, ok := err.(causer) //nolint:errorlint - if !ok { + var cause error + for { + if cause = errors.Unwrap(err); cause == nil { break } - err = cause.Cause() + err = cause } return err } diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go index 93708ed77..4fe57b2e3 100644 --- a/libsq/core/errz/errz.go +++ b/libsq/core/errz/errz.go @@ -11,13 +11,10 @@ package errz import ( "context" "errors" - "fmt" - "log/slog" - "go.uber.org/multierr" ) -// Err annotates err with a stack trace at the point WithStack was called. +// Err annotates err with a stack trace at the point Err was called. // If err is nil, Err returns nil. func Err(err error) error { if err == nil { @@ -25,7 +22,7 @@ func Err(err error) error { } return &withStack{ err, - nil, + "", callers(), } } @@ -39,35 +36,6 @@ var Combine = multierr.Combine // Errors is documented by multierr.Errors. var Errors = multierr.Errors -// logValue return a slog.Value for err. -func logValue(err error) slog.Value { - if err == nil { - return slog.Value{} - } - - c := Cause(err) - if c == nil { - // Shouldn't happen - return slog.Value{} - } - - attrs := []slog.Attr{slog.String("msg", err.Error())} - if !errors.Is(c, err) { - attrs = append(attrs, - slog.String("cause", c.Error()), - slog.String("type", fmt.Sprintf("%T", c)), - ) - - // If there's a cause c, "type" will be the type of c. - } else { - // If there's no cause, "type" will be the type of err. - // It's a bit wonky, but probably the most useful thing to show. - attrs = append(attrs, slog.String("type", fmt.Sprintf("%T", err))) - } - - return slog.GroupValue(attrs...) -} - // IsErrContext returns true if err is context.Canceled or context.DeadlineExceeded. func IsErrContext(err error) bool { if err == nil { @@ -81,8 +49,11 @@ func IsErrContext(err error) bool { return false } -// Return returns t, and err wrapped with [errz.Err]. 
-// This is useful for the common case of returning a value and an error. +// Return returns t with err wrapped via [errz.Err]. +// This is useful for the common case of returning a value and +// an error from a function. +// +// written, err = errz.Return(io.Copy(w, r)) func Return[T any](t T, err error) (T, error) { return t, Err(err) } @@ -90,43 +61,47 @@ func Return[T any](t T, err error) (T, error) { // As is a convenience wrapper around errors.As. // // _, err := os.Open("non-existing") -// ok, pathErr := errz.As[*fs.PathError](err) +// pathErr, ok := errz.As[*fs.PathError](err) // require.True(t, ok) // require.Equal(t, "non-existing", pathErr.Path) // -// Under the covers, As delegates to errors.As. -func As[E error](err error) (bool, E) { +// If err is nil, As returns false. +func As[E error](err error) (E, bool) { var target E + if err == nil { + return target, false + } + if errors.As(err, &target) { - return true, target + return target, true } - return false, target + return target, false } -// IsType returns true if err, or an error in its tree, if of type E. +// Has returns true if err, or an error in its error tree, if of type E. // // _, err := os.Open("non-existing") -// isPathErr := errz.IsType[*fs.PathError](err) +// isPathErr := errz.Has[*fs.PathError](err) // -// Under the covers, IsType uses errors.As. -func IsType[E error](err error) bool { +// If err is nil, Has returns false. +func Has[E error](err error) bool { + if err == nil { + return false + } var target E return errors.As(err, &target) } -// Tree returns a slice of all the errors in err's tree. -func Tree(err error) []error { +// Chain returns a slice of all the errors in err's tree. 
+func Chain(err error) []error { if err == nil { return nil } var errs []error - for { + for err != nil { errs = append(errs, err) err = errors.Unwrap(err) - if err == nil { - break - } } return errs diff --git a/libsq/core/errz/errz_test.go b/libsq/core/errz/errz_test.go index f90cbef0b..85d268d94 100644 --- a/libsq/core/errz/errz_test.go +++ b/libsq/core/errz/errz_test.go @@ -7,16 +7,26 @@ import ( "io/fs" "net/url" "os" + "runtime/debug" "testing" "github.com/stretchr/testify/require" "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/core/stringz" ) +func TestErrEmpty(t *testing.T) { + err := errz.New("") + gotMsg := err.Error() + require.Equal(t, "", gotMsg) + gotCause := errz.UnwrapFully(err) + require.NotNil(t, gotCause) + + t.Log(gotMsg) +} + func TestIs(t *testing.T) { err := errz.Wrap(sql.ErrNoRows, "wrap") @@ -24,26 +34,26 @@ func TestIs(t *testing.T) { require.True(t, errors.Is(err, sql.ErrNoRows)) } -func TestWrapCauseAs(t *testing.T) { +func TestUnwrapFully(t *testing.T) { var originalErr error //nolint:gosimple - originalErr = &CustomError{msg: "huzzah"} + originalErr = &customError{msg: "huzzah"} - err := errz.Wrap(errz.Wrap(originalErr, "wrap"), "wrap") - require.Equal(t, "wrap: wrap: huzzah", err.Error()) + err := errz.Wrap(errz.Wrap(originalErr, "wrap1"), "wrap2") + require.Equal(t, "wrap2: wrap1: huzzah", err.Error()) - var gotCustomErr *CustomError + var gotCustomErr *customError require.True(t, errors.As(err, &gotCustomErr)) require.Equal(t, "huzzah", gotCustomErr.msg) - gotUnwrap := errz.Cause(err) - require.Equal(t, *originalErr.(*CustomError), *gotUnwrap.(*CustomError)) //nolint:errorlint + gotUnwrap := errz.UnwrapFully(err) + require.Equal(t, *originalErr.(*customError), *gotUnwrap.(*customError)) //nolint:errorlint } -type CustomError struct { +type customError struct { msg string } -func (e *CustomError) 
Error() string { +func (e *customError) Error() string { return e.msg } @@ -51,13 +61,13 @@ func TestLogError_LogValue(t *testing.T) { log := lgt.New(t) nakedErr := sql.ErrNoRows - log.Debug("naked", lga.Err, nakedErr) + log.Debug("naked", "err", nakedErr) zErr := errz.Err(nakedErr) - log.Debug("via errz.Err", lga.Err, zErr) + log.Debug("via errz.Err", "err", zErr) wrapErr := errz.Wrap(nakedErr, "wrap me") - log.Debug("via errz.Wrap", lga.Err, wrapErr) + log.Debug("via errz.Wrap", "err", wrapErr) } func TestIsErrNotExist(t *testing.T) { @@ -103,10 +113,10 @@ func TestIsType(t *testing.T) { require.Error(t, err) t.Logf("err: %T %v", err, err) - got := errz.IsType[*fs.PathError](err) + got := errz.Has[*fs.PathError](err) require.True(t, got) - got = errz.IsType[*url.Error](err) + got = errz.Has[*url.Error](err) require.False(t, got) } @@ -116,7 +126,7 @@ func TestAs(t *testing.T) { require.Error(t, err) t.Logf("err: %T %v", err, err) - ok, pathErr := errz.As[*fs.PathError](err) + pathErr, ok := errz.As[*fs.PathError](err) require.True(t, ok) require.NotNil(t, pathErr) require.Equal(t, fp, pathErr.Path) @@ -129,5 +139,7 @@ func TestStackTrace(t *testing.T) { require.True(t, ok) require.NotNil(t, tracer) tr := tracer.StackTrace() - t.Logf("stack trace:%+v", tr) + t.Logf("stack trace:\n%+v", tr) + + debug.PrintStack() } diff --git a/libsq/core/errz/errz_types.go b/libsq/core/errz/errz_types.go index 43b629324..82cc7c4f0 100644 --- a/libsq/core/errz/errz_types.go +++ b/libsq/core/errz/errz_types.go @@ -6,6 +6,8 @@ import ( // NotExistError indicates that a DB object, such // as a table, does not exist. +// +// REVISIT: Consider moving NotExistError to libsq/driver? type NotExistError struct { error } @@ -34,6 +36,9 @@ func IsErrNotExist(err error) bool { // NoDataError indicates that there's no data, e.g. an empty document. // This is subtly different to NotExistError, which would indicate that // the document doesn't exist. 
+// +// REVISIT: Consider moving NoDataError to libsq/driver? +// REVISIT: Consider renaming NoDataError to EmptyDataError? type NoDataError struct { error } diff --git a/libsq/core/errz/stack.go b/libsq/core/errz/stack.go index 3241a6d40..df39003b8 100644 --- a/libsq/core/errz/stack.go +++ b/libsq/core/errz/stack.go @@ -63,7 +63,7 @@ func (f Frame) name() string { // Format accepts flags that alter the printing of some verbs, as follows: // // %+s function name and path of source file relative to the compile time -// GOPATH separated by \nSPSP (\nSPSP) +// GOPATH separated by \n\t (\n\t) // %+v equivalent to %+s:%d func (f Frame) Format(s fmt.State, verb rune) { switch verb { @@ -71,7 +71,7 @@ func (f Frame) Format(s fmt.State, verb rune) { switch { case s.Flag('+'): _, _ = io.WriteString(s, f.name()) - _, _ = io.WriteString(s, "\n ") + _, _ = io.WriteString(s, "\n\t") _, _ = io.WriteString(s, f.file()) default: _, _ = io.WriteString(s, path.Base(f.file())) @@ -221,6 +221,7 @@ func Stack(err error) *StackTrace { // Stacks returns any stack trace(s) attached to err. If err // has been wrapped more than once, there may be multiple stack traces. // Generally speaking, the final stack trace is the most interesting. +// The returned StackTrace items can be printed using fmt "%+v". 
func Stacks(err error) []*StackTrace { if err == nil { return nil @@ -236,8 +237,8 @@ func Stacks(err error) []*StackTrace { switch err := err.(type) { //nolint:errorlint case *withStack: stacks = append(stacks, err.StackTrace()) - case *fundamental: - stacks = append(stacks, err.StackTrace()) + //case *fundamental: + // stacks = append(stacks, err.StackTrace()) default: } diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index c31037faa..2d5651a12 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -311,7 +311,7 @@ func (dl *Download) do(req *http.Request) (*http.Response, error) { // // We want to trim off that `GET "URL"` prefix, but we only do that if // there's a wrapped error beneath (which should be the case). - if errz.IsType[*url.Error](err) && errors.Is(err, context.DeadlineExceeded) { + if errz.Has[*url.Error](err) && errors.Is(err, context.DeadlineExceeded) { if e := errors.Unwrap(err); e != nil { err = e } diff --git a/libsq/core/ioz/lockfile/lockfile.go b/libsq/core/ioz/lockfile/lockfile.go index af267a555..9997a9cc4 100644 --- a/libsq/core/ioz/lockfile/lockfile.go +++ b/libsq/core/ioz/lockfile/lockfile.go @@ -64,7 +64,7 @@ func (l Lockfile) Lock(ctx context.Context, timeout time.Duration) error { // log.Debug("Failed to acquire pid lock, may retry", lga.Attempts, attempts, lga.Err, err) return err }, - errz.IsType[lockfile.TemporaryError], + errz.Has[lockfile.TemporaryError], ) elapsed := time.Since(start) diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go index 8976b5601..dca5950be 100644 --- a/libsq/core/lg/devlog/tint/handler.go +++ b/libsq/core/lg/devlog/tint/handler.go @@ -321,7 +321,7 @@ func (h *handler) handleStackAttrs(buf *buffer, attrs []slog.Attr) { } if stack.Error != nil { - errTypes := stringz.TypeNames(errz.Tree(stack.Error)...) + errTypes := stringz.TypeNames(errz.Chain(stack.Error)...) 
for j, typ := range errTypes { buf.WriteStringIf(!h.noColor, ansiStackErrType) buf.WriteString(typ) From ff418ee9fd69410bda0067b612641061675995e1 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 16 Dec 2023 05:57:46 -0700 Subject: [PATCH 132/195] errz: renamed UnwrapFully to UnwrapChain --- libsq/core/errz/errors.go | 24 ++++++++++++------------ libsq/core/errz/errz_test.go | 6 +++--- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/libsq/core/errz/errors.go b/libsq/core/errz/errors.go index ef76af37f..984a31a46 100644 --- a/libsq/core/errz/errors.go +++ b/libsq/core/errz/errors.go @@ -37,14 +37,14 @@ // for inspection. Any error value which implements this interface // // type causer interface { -// UnwrapFully() error +// UnwrapChain() error // } // -// can be inspected by errors.UnwrapFully. errors.UnwrapFully will recursively retrieve +// can be inspected by errors.UnwrapChain. errors.UnwrapChain will recursively retrieve // the topmost error that does not implement causer, which is assumed to be // the original cause. For example: // -// switch err := errors.UnwrapFully(err).(type) { +// switch err := errors.UnwrapChain(err).(type) { // case *MyError: // // handle specifically // default: @@ -59,7 +59,7 @@ // All error values returned from this package implement fmt.Formatter and can // be formatted by the fmt package. The following verbs are supported: // -// %s print the error. If the error has a UnwrapFully it will be +// %s print the error. If the error has a UnwrapChain it will be // printed recursively. // %v see %s // %+v extended format. Each Frame of the error's StackTrace will @@ -225,9 +225,9 @@ func (w *withStack) LogValue() slog.Value { return slog.GroupValue(attrs...) } -// UnwrapFully returns the underlying *root* cause of the error, if possible. +// UnwrapChain returns the underlying *root* cause of the error, if possible. // -// Deprecated: get rid of UnwrapFully in favor of errors.Unwrap. 
+// Deprecated: get rid of UnwrapChain in favor of errors.Unwrap. func (w *withStack) Cause() error { return w.error } // Unwrap provides compatibility for Go 1.13 error chains. @@ -319,7 +319,7 @@ func Wrapf(err error, format string, args ...any) error { //} // //func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() } -//func (w *withMessage) UnwrapFully() error { return w.cause } +//func (w *withMessage) UnwrapChain() error { return w.cause } // //// LogValue implements slog.LogValuer. //func (w *withMessage) LogValue() slog.Value { @@ -333,7 +333,7 @@ func Wrapf(err error, format string, args ...any) error { // switch verb { // case 'v': // if s.Flag('+') { -// _, _ = fmt.Fprintf(s, "%+v\n", w.UnwrapFully()) +// _, _ = fmt.Fprintf(s, "%+v\n", w.UnwrapChain()) // _, _ = io.WriteString(s, w.msg) // return // } @@ -343,10 +343,10 @@ func Wrapf(err error, format string, args ...any) error { // } //} -// UnwrapFully returns the underlying *root* cause of the error. That is -// to say, UnwrapFully returns the final error in the error chain. -// UnwrapFully returns nil if err is nil, but otherwise will not return nil. -func UnwrapFully(err error) error { +// UnwrapChain returns the underlying *root* cause of the error. That is +// to say, UnwrapChain returns the final non-nil error in the error chain. +// UnwrapChain returns nil if err is nil. 
+func UnwrapChain(err error) error { if err == nil { return nil } diff --git a/libsq/core/errz/errz_test.go b/libsq/core/errz/errz_test.go index 85d268d94..6ea9227c2 100644 --- a/libsq/core/errz/errz_test.go +++ b/libsq/core/errz/errz_test.go @@ -21,7 +21,7 @@ func TestErrEmpty(t *testing.T) { err := errz.New("") gotMsg := err.Error() require.Equal(t, "", gotMsg) - gotCause := errz.UnwrapFully(err) + gotCause := errz.UnwrapChain(err) require.NotNil(t, gotCause) t.Log(gotMsg) @@ -34,7 +34,7 @@ func TestIs(t *testing.T) { require.True(t, errors.Is(err, sql.ErrNoRows)) } -func TestUnwrapFully(t *testing.T) { +func TestUnwrapChain(t *testing.T) { var originalErr error //nolint:gosimple originalErr = &customError{msg: "huzzah"} @@ -45,7 +45,7 @@ func TestUnwrapFully(t *testing.T) { require.True(t, errors.As(err, &gotCustomErr)) require.Equal(t, "huzzah", gotCustomErr.msg) - gotUnwrap := errz.UnwrapFully(err) + gotUnwrap := errz.UnwrapChain(err) require.Equal(t, *originalErr.(*customError), *gotUnwrap.(*customError)) //nolint:errorlint } From 4611b574b6899ad72650f2dc0cc730cd9324ef64 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 16 Dec 2023 06:44:02 -0700 Subject: [PATCH 133/195] errz: more refactoring --- cli/output/tablew/errorwriter.go | 2 +- libsq/core/errz/errors.go | 117 ++++++++------------------- libsq/core/errz/errz_test.go | 9 +-- libsq/core/errz/internal_test.go | 23 ++++++ libsq/core/errz/stack.go | 86 +++++++++----------- libsq/core/lg/devlog/tint/handler.go | 2 +- 6 files changed, 98 insertions(+), 141 deletions(-) create mode 100644 libsq/core/errz/internal_test.go diff --git a/cli/output/tablew/errorwriter.go b/cli/output/tablew/errorwriter.go index 667036c6a..d1568c63a 100644 --- a/cli/output/tablew/errorwriter.go +++ b/cli/output/tablew/errorwriter.go @@ -42,7 +42,7 @@ func (w *errorWriter) Error(systemErr error, humanErr error) { continue } - stackPrint := fmt.Sprintf("%+v", stack) + stackPrint := fmt.Sprintf("%+v", stack.Frames) stackPrint = 
strings.ReplaceAll(strings.TrimSpace(stackPrint), "\n\t", "\n ") if stackPrint == "" { continue diff --git a/libsq/core/errz/errors.go b/libsq/core/errz/errors.go index 984a31a46..113b841c7 100644 --- a/libsq/core/errz/errors.go +++ b/libsq/core/errz/errors.go @@ -74,16 +74,16 @@ // StackTrace() errors.StackTrace // } // -// The returned errors.StackTrace type is defined as +// The returned errors.stackTrace type is defined as // -// type StackTrace []Frame +// type stackTrace []Frame // // The Frame type represents a call site in the stack trace. Frame supports // the fmt.Formatter interface that can be used for printing information about // the stack trace of this error. For example: // // if err, ok := err.(stackTracer); ok { -// for _, f := range err.StackTrace() { +// for _, f := range err.stackTrace() { // fmt.Printf("%+s:%d\n", f, f) // } // } @@ -120,63 +120,13 @@ func Errorf(format string, args ...any) error { } } -// -//// fundamental is an error that has a message and a stack, but no caller. -//type fundamental struct { -// msg string -// *stack -//} -// -//func (f *fundamental) Error() string { return f.msg } -// -//func (f *fundamental) Format(s fmt.State, verb rune) { -// switch verb { -// case 'v': -// if s.Flag('+') { -// _, _ = io.WriteString(s, f.msg) -// f.stack.Format(s, verb) -// return -// } -// fallthrough -// case 's': -// _, _ = io.WriteString(s, f.msg) -// case 'q': -// _, _ = fmt.Fprintf(s, "{%s}", f.msg) -// } -//} -// -//var _ StackTracer = (*fundamental)(nil) -// -//// StackTrace implements StackTracer. -//func (f *fundamental) StackTrace() *StackTrace { -// if f == nil || f.stack == nil { -// return nil -// } -// -// st := f.stack.stackTrace() -// if st != nil { -// st.Error = f -// } -// return st -//} -// -//// LogValue implements slog.LogValuer. 
-//func (f *fundamental) LogValue() slog.Value { -// return logValue(f) -//} - type withStack struct { error msg string *stack } -var _ StackTracer = (*withStack)(nil) - -// StackTrace implements StackTracer. -// REVISIT: consider making StackTrace private, or removing -// it in favor of the Stack function. -func (w *withStack) StackTrace() *StackTrace { +func (w *withStack) stackTrace() *StackTrace { if w == nil || w.stack == nil { return nil } @@ -202,53 +152,50 @@ func (w *withStack) Error() string { return w.msg + ": " + w.error.Error() } -// LogValue implements slog.LogValuer. +// LogValue implements slog.LogValuer. It returns a slog.GroupValue, +// having attributes "msg" and "type". If the error has a cause that +// from outside this package, the cause's type is include in the +// "cause" attribute. func (w *withStack) LogValue() slog.Value { if w == nil { return slog.Value{} } - attrs := make([]slog.Attr, 2, 4) + attrs := make([]slog.Attr, 2, 3) attrs[0] = slog.String("msg", w.Error()) attrs[1] = slog.String("type", fmt.Sprintf("%T", w)) - // If there's a wrapped error, "cause" and "type" will be - // for that wrapped error. - if w.error != nil { - attrs[1] = slog.String("cause", w.error.Error()) - attrs = append(attrs, slog.String("type", fmt.Sprintf("%T", w.error))) - } else { - // If there's no wrapped error, "type" will be the type of w. - attrs[1] = slog.String("type", fmt.Sprintf("%T", w)) + if cause := w.foreignCause(); cause != nil { + attrs = append(attrs, slog.String("cause", fmt.Sprintf("%T", cause))) } return slog.GroupValue(attrs...) } -// UnwrapChain returns the underlying *root* cause of the error, if possible. -// -// Deprecated: get rid of UnwrapChain in favor of errors.Unwrap. -func (w *withStack) Cause() error { return w.error } +// foreignCause returns the first error in the chain that is +// not of type *withStack, or returns nil if no such error. 
+func (w *withStack) foreignCause() error { + if w == nil { + return nil + } + + inner := w.error + for { + switch v := inner.(type) { + case nil: + return v + case *withStack: + inner = v.error + default: + return v + } + } +} // Unwrap provides compatibility for Go 1.13 error chains. func (w *withStack) Unwrap() error { return w.error } -//func (f *fundamental) Format(s fmt.State, verb rune) { -// switch verb { -// case 'v': -// if s.Flag('+') { -// _, _ = io.WriteString(s, f.msg) -// f.stack.Format(s, verb) -// return -// } -// fallthrough -// case 's': -// _, _ = io.WriteString(s, f.msg) -// case 'q': -// _, _ = fmt.Fprintf(s, "{%s}", f.msg) -// } -//} - +// Format implements fmt.Formatter. func (w *withStack) Format(s fmt.State, verb rune) { switch verb { case 'v': @@ -258,7 +205,7 @@ func (w *withStack) Format(s fmt.State, verb rune) { w.stack.Format(s, verb) return } else { - _, _ = fmt.Fprintf(s, "%+v", w.Cause()) + _, _ = fmt.Fprintf(s, "%+v", w.error) w.stack.Format(s, verb) } return diff --git a/libsq/core/errz/errz_test.go b/libsq/core/errz/errz_test.go index 6ea9227c2..5c5e362f6 100644 --- a/libsq/core/errz/errz_test.go +++ b/libsq/core/errz/errz_test.go @@ -57,7 +57,7 @@ func (e *customError) Error() string { return e.msg } -func TestLogError_LogValue(t *testing.T) { +func TestLogValue(t *testing.T) { log := lgt.New(t) nakedErr := sql.ErrNoRows @@ -135,11 +135,8 @@ func TestAs(t *testing.T) { func TestStackTrace(t *testing.T) { err := errz.New("huzzah") - tracer, ok := err.(errz.StackTracer) - require.True(t, ok) - require.NotNil(t, tracer) - tr := tracer.StackTrace() - t.Logf("stack trace:\n%+v", tr) + tr := errz.FinalStack(err) + t.Logf("stack trace:\n%+v", tr.Frames) debug.PrintStack() } diff --git a/libsq/core/errz/internal_test.go b/libsq/core/errz/internal_test.go new file mode 100644 index 000000000..5177fb5dd --- /dev/null +++ b/libsq/core/errz/internal_test.go @@ -0,0 +1,23 @@ +package errz + +import ( + "context" + 
"github.com/stretchr/testify/require" + "testing" +) + +func TestForeignCause(t *testing.T) { + err := New("boo") + + cause := err.(*withStack).foreignCause() + require.Nil(t, cause) + + err = Err(context.DeadlineExceeded) + cause = err.(*withStack).foreignCause() + require.Equal(t, context.DeadlineExceeded, cause) + + err = Err(context.DeadlineExceeded) + err = Wrap(err, "wrap") + cause = err.(*withStack).foreignCause() + require.Equal(t, context.DeadlineExceeded, cause) +} diff --git a/libsq/core/errz/stack.go b/libsq/core/errz/stack.go index df39003b8..bc5ad344b 100644 --- a/libsq/core/errz/stack.go +++ b/libsq/core/errz/stack.go @@ -79,7 +79,7 @@ func (f Frame) Format(s fmt.State, verb rune) { case 'd': _, _ = io.WriteString(s, strconv.Itoa(f.line())) case 'n': - _, _ = io.WriteString(s, funcname(f.name())) + _, _ = io.WriteString(s, funcName(f.name())) case 'v': f.Format(s, 's') _, _ = io.WriteString(s, ":") @@ -97,15 +97,20 @@ func (f Frame) MarshalText() ([]byte, error) { return []byte(fmt.Sprintf("%s %s:%d", name, f.file(), f.line())), nil } -// StackTrace is stack of Frames from innermost (newest) to outermost (oldest). +// StackTrace contains a stack of Frames from innermost (newest) +// to outermost (oldest), as well as the error value that resulted +// in this stack trace. type StackTrace struct { // Error is the error value that resulted in this stack trace. Error error // Frames is the ordered list of frames that make up this stack trace. - Frames []Frame + Frames Frames } +// Frames is the ordered list of frames that make up a stack trace. +type Frames []Frame + // Format formats the stack of Frames according to the fmt.Formatter interface. // // %s lists source files for each Frame in the stack @@ -114,32 +119,32 @@ type StackTrace struct { // Format accepts flags that alter the printing of some verbs, as follows: // // %+v Prints filename, function, and line number for each Frame in the stack. 
-func (st *StackTrace) Format(s fmt.State, verb rune) { +func (fs Frames) Format(s fmt.State, verb rune) { switch verb { case 'v': switch { case s.Flag('+'): - for i, f := range st.Frames { + for i, f := range fs { if i != 0 { _, _ = io.WriteString(s, "\n") } f.Format(s, verb) } case s.Flag('#'): - fmt.Fprintf(s, "%#v", []Frame(st.Frames)) + fmt.Fprintf(s, "%#v", []Frame(fs)) default: - st.formatSlice(s, verb) + fs.formatSlice(s, verb) } case 's': - st.formatSlice(s, verb) + fs.formatSlice(s, verb) } } -// formatSlice will format this StackTrace into the given buffer as a slice of +// formatSlice will format this Frames into the given buffer as a slice of // Frame, only valid when called with '%s' or '%v'. -func (st *StackTrace) formatSlice(s fmt.State, verb rune) { +func (fs Frames) formatSlice(s fmt.State, verb rune) { _, _ = io.WriteString(s, "[") - for i, f := range st.Frames { + for i, f := range fs { if i > 0 { _, _ = io.WriteString(s, " ") } @@ -154,7 +159,7 @@ func (st *StackTrace) LogValue() slog.Value { return slog.Value{} } - return slog.StringValue(fmt.Sprintf("%+v", st)) + return slog.StringValue(fmt.Sprintf("%+v", st.Frames)) } // stack represents a stack of program counters. @@ -162,7 +167,7 @@ type stack []uintptr func (s *stack) Format(st fmt.State, verb rune) { if s == nil { - fmt.Fprintf(st, "") + fmt.Fprint(st, "") } switch verb { //nolint:gocritic case 'v': @@ -192,54 +197,27 @@ func callers() *stack { return &st } -// funcname removes the path prefix component of a function's name reported by func.Name(). -func funcname(name string) string { +// funcName removes the path prefix component of a function's name reported by func.Name(). +func funcName(name string) string { i := strings.LastIndex(name, "/") name = name[i+1:] i = strings.Index(name, ".") return name[i+1:] } -// Stack returns the last of any stack trace(s) attached to err. -// If err has been wrapped more than once, there may be multiple stack traces. 
-// Generally speaking, the final stack trace is the most interesting. -// The returned StackTrace can be printed using fmt "%+v". -func Stack(err error) *StackTrace { - if err == nil { - return nil - } - - stacks := Stacks(err) - if len(stacks) == 0 { - return nil - } - - // Return the final element of the slice - return stacks[len(stacks)-1] -} - // Stacks returns any stack trace(s) attached to err. If err // has been wrapped more than once, there may be multiple stack traces. // Generally speaking, the final stack trace is the most interesting. -// The returned StackTrace items can be printed using fmt "%+v". +// The returned StackTrace.Frames can be printed using fmt "%+v". func Stacks(err error) []*StackTrace { if err == nil { return nil } var stacks []*StackTrace - - for { - if err == nil { - break - } - - switch err := err.(type) { //nolint:errorlint - case *withStack: - stacks = append(stacks, err.StackTrace()) - //case *fundamental: - // stacks = append(stacks, err.StackTrace()) - default: + for err != nil { + if ez, ok := err.(*withStack); ok { + stacks = append(stacks, ez.stackTrace()) } err = errors.Unwrap(err) @@ -248,6 +226,18 @@ func Stacks(err error) []*StackTrace { return stacks } -type StackTracer interface { - StackTrace() *StackTrace +// FinalStack returns the last of any stack trace(s) attached to err. +// It is a convenience function to return the last element of errz.Stacks(err). 
+func FinalStack(err error) *StackTrace { + if err == nil { + return nil + } + + stacks := Stacks(err) + if len(stacks) == 0 { + return nil + } + + // Return the final element of the slice + return stacks[len(stacks)-1] } diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go index dca5950be..679ccc874 100644 --- a/libsq/core/lg/devlog/tint/handler.go +++ b/libsq/core/lg/devlog/tint/handler.go @@ -310,7 +310,7 @@ func (h *handler) handleStackAttrs(buf *buffer, attrs []slog.Attr) { continue } - stackPrint := fmt.Sprintf("%+v", stack) + stackPrint := fmt.Sprintf("%+v", stack.Frames) stackPrint = strings.ReplaceAll(strings.TrimSpace(stackPrint), "\n\t", "\n ") if stackPrint == "" { continue From 0e0da30e2e64bb7d96d4bf5119cdae99d5c61496 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 16 Dec 2023 07:53:13 -0700 Subject: [PATCH 134/195] errz: almost done --- libsq/core/errz/errors.go | 267 +++++++++---------------------- libsq/core/errz/errz.go | 2 +- libsq/core/errz/errz_test.go | 14 +- libsq/core/errz/internal_test.go | 6 +- libsq/core/errz/stack.go | 51 ++++-- 5 files changed, 121 insertions(+), 219 deletions(-) diff --git a/libsq/core/errz/errors.go b/libsq/core/errz/errors.go index 113b841c7..75ccc5f2f 100644 --- a/libsq/core/errz/errors.go +++ b/libsq/core/errz/errors.go @@ -1,171 +1,86 @@ -// Package errz provides simple error handling primitives. -// -// FIXME: update docs -// -// The traditional error handling idiom in Go is roughly akin to -// -// if err != nil { -// return err -// } -// -// which when applied recursively up the call stack results in error reports -// without context or debugging information. The errors package allows -// programmers to add context to the failure path in their code in a way -// that does not destroy the original value of the error. 
-// -// # Adding context to an error -// -// The errors.Wrap function returns a new error that adds context to the -// original error by recording a stack trace at the point Wrap is called, -// together with the supplied message. For example -// -// _, err := ioutil.ReadAll(r) -// if err != nil { -// return errors.Wrap(err, "read failed") -// } -// -// If additional control is required, the errors.WithStack and -// errors.WithMessage functions destructure errors.Wrap into its component -// operations: annotating an error with a stack trace and with a message, -// respectively. -// -// # Retrieving the cause of an error -// -// Using errors.Wrap constructs a stack of errors, adding context to the -// preceding error. Depending on the nature of the error it may be necessary -// to reverse the operation of errors.Wrap to retrieve the original error -// for inspection. Any error value which implements this interface -// -// type causer interface { -// UnwrapChain() error -// } -// -// can be inspected by errors.UnwrapChain. errors.UnwrapChain will recursively retrieve -// the topmost error that does not implement causer, which is assumed to be -// the original cause. For example: -// -// switch err := errors.UnwrapChain(err).(type) { -// case *MyError: -// // handle specifically -// default: -// // unknown error -// } -// -// Although the causer interface is not exported by this package, it is -// considered a part of its stable public interface. -// -// # Formatted printing of errors -// -// All error values returned from this package implement fmt.Formatter and can -// be formatted by the fmt package. The following verbs are supported: -// -// %s print the error. If the error has a UnwrapChain it will be -// printed recursively. -// %v see %s -// %+v extended format. Each Frame of the error's StackTrace will -// be printed in detail. 
-// -// # Retrieving the stack trace of an error or wrapper -// -// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are -// invoked. This information can be retrieved with the following interface: -// -// type stackTracer interface { -// StackTrace() errors.StackTrace -// } -// -// The returned errors.stackTrace type is defined as -// -// type stackTrace []Frame -// -// The Frame type represents a call site in the stack trace. Frame supports -// the fmt.Formatter interface that can be used for printing information about -// the stack trace of this error. For example: -// -// if err, ok := err.(stackTracer); ok { -// for _, f := range err.stackTrace() { -// fmt.Printf("%+s:%d\n", f, f) -// } -// } -// -// Although the stackTracer interface is not exported by this package, it is -// considered a part of its stable public interface. -// -// See the documentation for Frame.Format for more details. package errz +// ACKNOWLEDGEMENT: The code in this file has its origins +// in Dave Cheney's pkg/errors. + import ( "errors" "fmt" "io" "log/slog" + "strings" ) -// New returns an error with the supplied message. -// New also records the stack trace at the point it was called. +// New returns an error with the supplied message, recording the +// stack trace at the point it was called. func New(message string) error { - return &withStack{ - msg: message, - stack: callers(), - } + return &errz{stack: callers(), msg: message} } -// Errorf formats according to a format specifier and returns the string -// as a value that satisfies error. -// Errorf also records the stack trace at the point it was called. +// Errorf works like [fmt.Errorf], but it also records the stack trace +// at the point it was called. If the format string includes the %w verb, +// [fmt.Errorf] is first called to construct the error, which is then wrapped. 
func Errorf(format string, args ...any) error { - return &withStack{ - msg: fmt.Sprintf(format, args...), - stack: callers(), + if strings.Contains(format, "%w") { + return &errz{stack: callers(), error: fmt.Errorf(format, args...)} } + return &errz{stack: callers(), msg: fmt.Sprintf(format, args...)} + + //return Err(fmt.Errorf(format, args...)) + //return &errz{ + // // REVISIT: should we delegate to fmt.Errorf instead? + // msg: fmt.Sprintf(format, args...), + // stack: callers(), + //} } -type withStack struct { +// errz is the error implementation used by this package. +type errz struct { error msg string *stack } -func (w *withStack) stackTrace() *StackTrace { - if w == nil || w.stack == nil { +func (e *errz) stackTrace() *StackTrace { + if e == nil || e.stack == nil { return nil } - st := w.stack.stackTrace() + st := e.stack.stackTrace() if st != nil { - st.Error = w + st.Error = e } return st } // Error implements stdlib error interface. -func (w *withStack) Error() string { - if w.msg == "" { - if w.error == nil { +func (e *errz) Error() string { + if e.msg == "" { + if e.error == nil { return "" } - return w.error.Error() + return e.error.Error() } - if w.error == nil { - return w.msg + if e.error == nil { + return e.msg } - return w.msg + ": " + w.error.Error() + return e.msg + ": " + e.error.Error() } // LogValue implements slog.LogValuer. It returns a slog.GroupValue, // having attributes "msg" and "type". If the error has a cause that -// from outside this package, the cause's type is include in the +// from outside this package, the cause's type is included in a // "cause" attribute. 
-func (w *withStack) LogValue() slog.Value { - if w == nil { +func (e *errz) LogValue() slog.Value { + if e == nil { return slog.Value{} } attrs := make([]slog.Attr, 2, 3) - attrs[0] = slog.String("msg", w.Error()) - attrs[1] = slog.String("type", fmt.Sprintf("%T", w)) + attrs[0] = slog.String("msg", e.Error()) + attrs[1] = slog.String("type", fmt.Sprintf("%T", e)) - if cause := w.foreignCause(); cause != nil { + if cause := e.foreignCause(); cause != nil { attrs = append(attrs, slog.String("cause", fmt.Sprintf("%T", cause))) } @@ -173,122 +88,84 @@ func (w *withStack) LogValue() slog.Value { } // foreignCause returns the first error in the chain that is -// not of type *withStack, or returns nil if no such error. -func (w *withStack) foreignCause() error { - if w == nil { +// not of type *errz, or returns nil if no such error. +func (e *errz) foreignCause() error { + if e == nil { return nil } - inner := w.error - for { - switch v := inner.(type) { - case nil: - return v - case *withStack: + inner := e.error + for inner != nil { + // Note: don't use errors.As here; we want the direct type assertion. + if v, ok := inner.(*errz); ok { inner = v.error - default: - return v + continue } + return inner } + return nil } // Unwrap provides compatibility for Go 1.13 error chains. -func (w *withStack) Unwrap() error { return w.error } +func (e *errz) Unwrap() error { return e.error } // Format implements fmt.Formatter. 
-func (w *withStack) Format(s fmt.State, verb rune) { +func (e *errz) Format(s fmt.State, verb rune) { switch verb { case 'v': if s.Flag('+') { - if w.error == nil { - _, _ = io.WriteString(s, w.msg) - w.stack.Format(s, verb) + if e.error == nil { + _, _ = io.WriteString(s, e.msg) + e.stack.Format(s, verb) return } else { - _, _ = fmt.Fprintf(s, "%+v", w.error) - w.stack.Format(s, verb) + _, _ = fmt.Fprintf(s, "%+v", e.error) + e.stack.Format(s, verb) } return } fallthrough case 's': - if w.error == nil { - _, _ = io.WriteString(s, w.msg) + if e.error == nil { + _, _ = io.WriteString(s, e.msg) return } - _, _ = io.WriteString(s, w.Error()) + _, _ = io.WriteString(s, e.Error()) case 'q': - if w.error == nil { - _, _ = fmt.Fprintf(s, "{%s}", w.msg) + if e.error == nil { + _, _ = fmt.Fprintf(s, "{%s}", e.msg) return } - _, _ = fmt.Fprintf(s, "{%s}", w.Error()) + _, _ = fmt.Fprintf(s, "{%s}", e.Error()) } } // Wrap returns an error annotating err with a stack trace // at the point Wrap is called, and the supplied message. -// If err is nil, Wrap returns nil. +// If err is nil, Wrap returns nil. See also: Wrapf. func Wrap(err error, message string) error { if err == nil { return nil } - return &withStack{ - err, - message, - callers(), - } + return &errz{stack: callers(), error: err, msg: message} } // Wrapf returns an error annotating err with a stack trace -// at the point Wrapf is called, and the format specifier. -// If err is nil, Wrapf returns nil. +// at the point Wrapf is called. Wrapf will panic if format +// includes the %w verb: use errz.Errorf for that situation. +// If err is nil, Wrapf returns nil. See also: Wrap, Errorf. 
func Wrapf(err error, format string, args ...any) error { if err == nil { return nil } - //err = &withMessage{ - // cause: err, - // msg: fmt.Sprintf(format, args...), - //} - return &withStack{ - err, - fmt.Sprintf(format, args...), - callers(), + + if strings.Contains(format, "%w") { + panic("errz.Wrapf does not support %w verb: use errz.Errorf instead") } -} -// -//type withMessage struct { //nolint:errname -// cause error -// msg string -//} -// -//func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() } -//func (w *withMessage) UnwrapChain() error { return w.cause } -// -//// LogValue implements slog.LogValuer. -//func (w *withMessage) LogValue() slog.Value { -// return logValue(w) -//} -// -//// Unwrap provides compatibility for Go 1.13 error chains. -//func (w *withMessage) Unwrap() error { return w.cause } -// -//func (w *withMessage) Format(s fmt.State, verb rune) { -// switch verb { -// case 'v': -// if s.Flag('+') { -// _, _ = fmt.Fprintf(s, "%+v\n", w.UnwrapChain()) -// _, _ = io.WriteString(s, w.msg) -// return -// } -// fallthrough -// case 's', 'q': -// _, _ = io.WriteString(s, w.Error()) -// } -//} + return &errz{error: err, msg: fmt.Sprintf(format, args...), stack: callers()} +} // UnwrapChain returns the underlying *root* cause of the error. That is // to say, UnwrapChain returns the final non-nil error in the error chain. 
diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go index 4fe57b2e3..a4a22d6f0 100644 --- a/libsq/core/errz/errz.go +++ b/libsq/core/errz/errz.go @@ -20,7 +20,7 @@ func Err(err error) error { if err == nil { return nil } - return &withStack{ + return &errz{ err, "", callers(), diff --git a/libsq/core/errz/errz_test.go b/libsq/core/errz/errz_test.go index 5c5e362f6..7644c8d94 100644 --- a/libsq/core/errz/errz_test.go +++ b/libsq/core/errz/errz_test.go @@ -7,7 +7,6 @@ import ( "io/fs" "net/url" "os" - "runtime/debug" "testing" "github.com/stretchr/testify/require" @@ -108,7 +107,7 @@ func TestIsErrNoData(t *testing.T) { require.Equal(t, "me doesn't exist", err.Error()) } -func TestIsType(t *testing.T) { +func TestHas(t *testing.T) { _, err := os.Open(stringz.Uniq32() + "-non-existing") require.Error(t, err) t.Logf("err: %T %v", err, err) @@ -133,10 +132,13 @@ func TestAs(t *testing.T) { } func TestStackTrace(t *testing.T) { - err := errz.New("huzzah") + e1 := errz.New("inner") + e2 := errz.Wrap(e1, "wrap") - tr := errz.FinalStack(err) - t.Logf("stack trace:\n%+v", tr.Frames) + gotStacks := errz.Stacks(e2) + require.Len(t, gotStacks, 2) - debug.PrintStack() + gotFinalStack := errz.LastStack(e2) + require.NotNil(t, gotFinalStack) + require.Equal(t, gotStacks[len(gotStacks)-1], gotFinalStack) } diff --git a/libsq/core/errz/internal_test.go b/libsq/core/errz/internal_test.go index 5177fb5dd..da4dc81fd 100644 --- a/libsq/core/errz/internal_test.go +++ b/libsq/core/errz/internal_test.go @@ -9,15 +9,15 @@ import ( func TestForeignCause(t *testing.T) { err := New("boo") - cause := err.(*withStack).foreignCause() + cause := err.(*errz).foreignCause() require.Nil(t, cause) err = Err(context.DeadlineExceeded) - cause = err.(*withStack).foreignCause() + cause = err.(*errz).foreignCause() require.Equal(t, context.DeadlineExceeded, cause) err = Err(context.DeadlineExceeded) err = Wrap(err, "wrap") - cause = err.(*withStack).foreignCause() + cause = 
err.(*errz).foreignCause() require.Equal(t, context.DeadlineExceeded, cause) } diff --git a/libsq/core/errz/stack.go b/libsq/core/errz/stack.go index bc5ad344b..66dd8b7db 100644 --- a/libsq/core/errz/stack.go +++ b/libsq/core/errz/stack.go @@ -205,10 +205,12 @@ func funcName(name string) string { return name[i+1:] } -// Stacks returns any stack trace(s) attached to err. If err -// has been wrapped more than once, there may be multiple stack traces. -// Generally speaking, the final stack trace is the most interesting. -// The returned StackTrace.Frames can be printed using fmt "%+v". +// Stacks returns all stack trace(s) attached to err. If err has been wrapped +// more than once, there may be multiple stack traces. Generally speaking, the +// final stack trace is the most interesting; you can use [errz.LastStack] if +// you're just interested in that one. +// +// The returned [StackTrace.Frames] can be printed using fmt "%+v". func Stacks(err error) []*StackTrace { if err == nil { return nil @@ -216,8 +218,11 @@ func Stacks(err error) []*StackTrace { var stacks []*StackTrace for err != nil { - if ez, ok := err.(*withStack); ok { - stacks = append(stacks, ez.stackTrace()) + if ez, ok := err.(*errz); ok { + st := ez.stackTrace() + if st != nil { + stacks = append(stacks, st) + } } err = errors.Unwrap(err) @@ -226,18 +231,36 @@ func Stacks(err error) []*StackTrace { return stacks } -// FinalStack returns the last of any stack trace(s) attached to err. -// It is a convenience function to return the last element of errz.Stacks(err). -func FinalStack(err error) *StackTrace { +// LastStack returns the last of any stack trace(s) attached to err, or nil. +// Contrast with [errz.Stacks], which returns all stack traces attached +// to any error in the chain. But if you only want to examine one stack, +// the final stack trace is usually the most interesting, which is why this +// function exists. +// +// The returned StackTrace.Frames can be printed using fmt "%+v". 
+func LastStack(err error) *StackTrace { if err == nil { return nil } - stacks := Stacks(err) - if len(stacks) == 0 { - return nil + var ez *errz + var ok bool + for err != nil { + ez, ok = err.(*errz) + if !ok || ez == nil { + return nil + } + + if ez.error == nil { + return ez.stackTrace() + } + + if _, ok = ez.error.(*errz); !ok { + return ez.stackTrace() + } + + err = ez.error } - // Return the final element of the slice - return stacks[len(stacks)-1] + return nil } From 0b88ff892a3526c60e8bb2dedefd2cdb902c154f Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 16 Dec 2023 08:17:58 -0700 Subject: [PATCH 135/195] errz: almost done --- libsq/core/errz/errors.go | 10 ++-------- libsq/core/errz/errz_test.go | 28 +++++++++++++++++++++------- libsq/core/lg/devlog/tint/handler.go | 19 ++++--------------- 3 files changed, 27 insertions(+), 30 deletions(-) diff --git a/libsq/core/errz/errors.go b/libsq/core/errz/errors.go index 75ccc5f2f..c8ff5a7c2 100644 --- a/libsq/core/errz/errors.go +++ b/libsq/core/errz/errors.go @@ -19,19 +19,13 @@ func New(message string) error { // Errorf works like [fmt.Errorf], but it also records the stack trace // at the point it was called. If the format string includes the %w verb, -// [fmt.Errorf] is first called to construct the error, which is then wrapped. +// [fmt.Errorf] is first called to construct the error, and then the +// returned error is again wrapped to record the stack trace. func Errorf(format string, args ...any) error { if strings.Contains(format, "%w") { return &errz{stack: callers(), error: fmt.Errorf(format, args...)} } return &errz{stack: callers(), msg: fmt.Sprintf(format, args...)} - - //return Err(fmt.Errorf(format, args...)) - //return &errz{ - // // REVISIT: should we delegate to fmt.Errorf instead? - // msg: fmt.Sprintf(format, args...), - // stack: callers(), - //} } // errz is the error implementation used by this package. 
diff --git a/libsq/core/errz/errz_test.go b/libsq/core/errz/errz_test.go index 7644c8d94..77727b57f 100644 --- a/libsq/core/errz/errz_test.go +++ b/libsq/core/errz/errz_test.go @@ -22,15 +22,23 @@ func TestErrEmpty(t *testing.T) { require.Equal(t, "", gotMsg) gotCause := errz.UnwrapChain(err) require.NotNil(t, gotCause) - - t.Log(gotMsg) } -func TestIs(t *testing.T) { - err := errz.Wrap(sql.ErrNoRows, "wrap") - - require.Equal(t, "wrap: "+sql.ErrNoRows.Error(), err.Error()) - require.True(t, errors.Is(err, sql.ErrNoRows)) +func TestErrorf(t *testing.T) { + err := errz.Errorf("hello %s", "world") + require.Equal(t, "hello world", err.Error()) + chain := errz.Chain(err) + require.Len(t, chain, 1) + + err2 := errz.Errorf("wrap %w", err) + require.Equal(t, "wrap hello world", err2.Error()) + chain2 := errz.Chain(err2) + + // chain2 should have length 3: + // - the original "hello world"; + // - the wrapping error from fmt.Errorf to handle the %w verb; + // - the final outer wrapper that errz.Errorf added to the fmt.Errorf error. 
+ require.Len(t, chain2, 3) } func TestUnwrapChain(t *testing.T) { @@ -131,6 +139,12 @@ func TestAs(t *testing.T) { require.Equal(t, fp, pathErr.Path) } +func TestIs(t *testing.T) { + err := errz.Wrap(sql.ErrNoRows, "wrap") + require.Equal(t, "wrap: "+sql.ErrNoRows.Error(), err.Error()) + require.True(t, errors.Is(err, sql.ErrNoRows)) +} + func TestStackTrace(t *testing.T) { e1 := errz.New("inner") e2 := errz.Wrap(e1, "wrap") diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go index 679ccc874..d38d0fab2 100644 --- a/libsq/core/lg/devlog/tint/handler.go +++ b/libsq/core/lg/devlog/tint/handler.go @@ -295,10 +295,11 @@ func (h *handler) handleStackAttrs(buf *buffer, attrs []slog.Attr) { } var stacks []*errz.StackTrace for _, attr := range attrs { - v := attr.Value.Any() - switch v := v.(type) { + switch v := attr.Value.Any().(type) { case *errz.StackTrace: - stacks = append(stacks, v) + if v != nil { + stacks = append(stacks, v) + } case []*errz.StackTrace: stacks = append(stacks, v...) } @@ -338,18 +339,6 @@ func (h *handler) handleStackAttrs(buf *buffer, attrs []slog.Attr) { buf.WriteStringIf(!h.noColor, ansiStackErr) buf.WriteString(stack.Error.Error()) buf.WriteStringIf(!h.noColor, ansiReset) - buf.WriteByte(' ') - //buf.WriteStringIf(!h.noColor, ansiFaint) - //// Now we'll print the type of the error. - //buf.WriteString(fmt.Sprintf("%T", stack.Error)) - //if i == len(stacks)-1 { - // // If we're on the final stack, and there's a cause underneath, - // // then we print that type too. 
- // if cause := errors.Unwrap(stack.Error); cause != nil { - // buf.WriteString(fmt.Sprintf(" -> %T", cause)) - // } - //} - //buf.WriteStringIf(!h.noColor, ansiResetFaint) buf.WriteByte('\n') } lines := strings.Split(stackPrint, "\n") From 70925ba78d19ca3ae8998b46ba35443597ebe178 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 16 Dec 2023 09:21:45 -0700 Subject: [PATCH 136/195] errz: multi errs basic stuff working --- cli/cmd_x.go | 4 +- drivers/json/json.go | 2 +- drivers/sqlserver/sqlserver.go | 2 +- libsq/core/errz/errors.go | 6 +- libsq/core/errz/errz.go | 12 +--- libsq/core/errz/internal_test.go | 8 +-- libsq/core/errz/multi.go | 118 +++++++++++++++++++++++++++++++ libsq/core/errz/multi_test.go | 48 +++++++++++++ libsq/core/ioz/download/cache.go | 2 +- 9 files changed, 181 insertions(+), 21 deletions(-) create mode 100644 libsq/core/errz/multi.go create mode 100644 libsq/core/errz/multi_test.go diff --git a/cli/cmd_x.go b/cli/cmd_x.go index f6eddeb5d..1a1755312 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -188,7 +188,9 @@ func execXDownloadCmd(cmd *cobra.Command, args []string) error { switch { case len(h.Errors) > 0: - return errz.Wrap(h.Errors[0], "huzzah") + err1 := h.Errors[0] + err2 := errz.New("another err") + return errz.Combine(err1, err2) case len(h.WriteErrors) > 0: return h.WriteErrors[0] case len(h.CachedFiles) > 0: diff --git a/drivers/json/json.go b/drivers/json/json.go index a2f229720..f3879e3a5 100644 --- a/drivers/json/json.go +++ b/drivers/json/json.go @@ -217,5 +217,5 @@ func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou func (g *grip) Close() error { g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) - return errz.Combine(g.impl.Close(), g.clnup.Run()) + return errz.Append(g.impl.Close(), g.clnup.Run()) } diff --git a/drivers/sqlserver/sqlserver.go b/drivers/sqlserver/sqlserver.go index 6e7ac72d9..52d29f0a5 100644 --- a/drivers/sqlserver/sqlserver.go +++ b/drivers/sqlserver/sqlserver.go @@ -742,7 
+742,7 @@ func newStmtExecFunc(stmt *sql.Stmt, db sqlz.DB, tbl string) driver.StmtExecFunc idErr := setIdentityInsert(ctx, db, tbl, true) if idErr != nil { - return 0, errz.Combine(errw(err), idErr) + return 0, errz.Append(errw(err), idErr) } res, err = stmt.ExecContext(ctx, args...) diff --git a/libsq/core/errz/errors.go b/libsq/core/errz/errors.go index c8ff5a7c2..02f0958a5 100644 --- a/libsq/core/errz/errors.go +++ b/libsq/core/errz/errors.go @@ -74,16 +74,16 @@ func (e *errz) LogValue() slog.Value { attrs[0] = slog.String("msg", e.Error()) attrs[1] = slog.String("type", fmt.Sprintf("%T", e)) - if cause := e.foreignCause(); cause != nil { + if cause := e.alienCause(); cause != nil { attrs = append(attrs, slog.String("cause", fmt.Sprintf("%T", cause))) } return slog.GroupValue(attrs...) } -// foreignCause returns the first error in the chain that is +// alienCause returns the first error in the chain that is // not of type *errz, or returns nil if no such error. -func (e *errz) foreignCause() error { +func (e *errz) alienCause() error { if e == nil { return nil } diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go index a4a22d6f0..5ad94546f 100644 --- a/libsq/core/errz/errz.go +++ b/libsq/core/errz/errz.go @@ -11,7 +11,6 @@ package errz import ( "context" "errors" - "go.uber.org/multierr" ) // Err annotates err with a stack trace at the point Err was called. @@ -27,21 +26,14 @@ func Err(err error) error { } } -// Append is documented by multierr.Append. -var Append = multierr.Append - -// Combine is documented by multierr.Combine. -var Combine = multierr.Combine - -// Errors is documented by multierr.Errors. -var Errors = multierr.Errors - // IsErrContext returns true if err is context.Canceled or context.DeadlineExceeded. 
func IsErrContext(err error) bool { if err == nil { return false } + errors.Join() + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return true } diff --git a/libsq/core/errz/internal_test.go b/libsq/core/errz/internal_test.go index da4dc81fd..f7ba89f83 100644 --- a/libsq/core/errz/internal_test.go +++ b/libsq/core/errz/internal_test.go @@ -6,18 +6,18 @@ import ( "testing" ) -func TestForeignCause(t *testing.T) { +func TestAlienCause(t *testing.T) { err := New("boo") - cause := err.(*errz).foreignCause() + cause := err.(*errz).alienCause() require.Nil(t, cause) err = Err(context.DeadlineExceeded) - cause = err.(*errz).foreignCause() + cause = err.(*errz).alienCause() require.Equal(t, context.DeadlineExceeded, cause) err = Err(context.DeadlineExceeded) err = Wrap(err, "wrap") - cause = err.(*errz).foreignCause() + cause = err.(*errz).alienCause() require.Equal(t, context.DeadlineExceeded, cause) } diff --git a/libsq/core/errz/multi.go b/libsq/core/errz/multi.go new file mode 100644 index 000000000..ffba76909 --- /dev/null +++ b/libsq/core/errz/multi.go @@ -0,0 +1,118 @@ +package errz + +// These multi-error functions delegate to go.uber.org/multierr. That +// package was in use before stdlib introduced the errors.Join function. +// It's possible, maybe even desirable, to refactor these functions +// to use stdlib errors instead. + +import "go.uber.org/multierr" + +// Append appends the given errors together. Either value may be nil. +// +// This function is a specialization of Combine for the common case where +// there are only two errors. +// +// err = errz.Append(reader.Close(), writer.Close()) +// +// The following pattern may also be used to record failure of deferred +// operations without losing information about the original error. +// +// func doSomething(..) 
(err error) { +// f := acquireResource() +// defer func() { +// err = errz.Append(err, f.Close()) +// }() +// +// Note that the variable MUST be a named return to append an error to it from +// the defer statement. +func Append(left error, right error) error { + switch { + case left == nil: + return Err(right) + case right == nil: + return Err(left) + } + + return Err(multierr.Append(left, right)) +} + +// Combine combines the passed errors into a single error. +// +// If zero arguments were passed or if all items are nil, a nil error is +// returned. +// +// Combine(nil, nil) // == nil +// +// If only a single error was passed, it is returned as-is. +// +// Combine(err) // == err +// +// Combine skips over nil arguments so this function may be used to combine +// together errors from operations that fail independently of each other. +// +// errz.Combine( +// reader.Close(), +// writer.Close(), +// pipe.Close(), +// ) +// +// If any of the passed errors is a multierr error, it will be flattened along +// with the other errors. +// +// errz.Combine(errz.Combine(err1, err2), err3) +// // is the same as +// errz.Combine(err1, err2, err3) +// +// The returned error formats into a readable multi-line error message if +// formatted with %+v. +// +// fmt.Sprintf("%+v", errz.Combine(err1, err2)) +func Combine(errors ...error) error { + switch len(errors) { + case 0: + return nil + case 1: + return Err(errors[0]) + } + return Err(multierr.Combine(errors...)) +} + +// Errors returns a slice containing zero or more errors that the supplied +// error is composed of. If the error is nil, a nil slice is returned. +// +// err := errz.Append(r.Close(), w.Close()) +// errors := errz.Errors(err) +// +// If the error is not composed of other errors, the returned slice contains +// just the error that was passed in. +// +// Callers of this function are free to modify the returned slice. 
+func Errors(err error) []error { + if err == nil { + return nil + } + + if me, ok := err.(multipleErrors); ok { + return me.Unwrap() + } + + ez, ok := err.(*errz) + if !ok { + return multierr.Errors(err) + } + + // It's an errz, so let's see what's underneath. + alien := ez.alienCause() + if alien == nil { + // It's not an alien error, it's just a pure errz error. + // It can't be a multi error. + return []error{err} + } + + // It's a foreign error, so we let multierr take care of it. + return multierr.Errors(alien) +} + +type multipleErrors interface { + Unwrap() []error +} diff --git a/libsq/core/errz/multi_test.go b/libsq/core/errz/multi_test.go new file mode 100644 index 000000000..62b9471ba --- /dev/null +++ b/libsq/core/errz/multi_test.go @@ -0,0 +1,48 @@ +package errz + +import ( + "errors" + "github.com/stretchr/testify/require" + "testing" +) + +func TestMultiErrors_stdlib_errors(t *testing.T) { + err1 := errors.New("err1") + err2 := errors.New("err2") + errs := Errors(err1) + require.Equal(t, []error{err1}, errs) + + appendErr := Append(err1, err2) + errs = Errors(appendErr) + require.Len(t, errs, 2) + require.Equal(t, []error{err1, err2}, errs) + t.Logf("%v", appendErr) + t.Logf("%+v", appendErr) + + stacks := Stacks(appendErr) + require.NotNil(t, stacks) + require.Len(t, stacks, 1) + st := stacks[0] + require.NotNil(t, st) + t.Logf("%+v", st.Frames) +} + +func TestMultiErrors_errz(t *testing.T) { + err1 := New("err1") + err2 := New("err2") + errs := Errors(err1) + require.Equal(t, []error{err1}, errs) + + appendErr := Append(err1, err2) + errs = Errors(appendErr) + require.Equal(t, []error{err1, err2}, errs) + t.Logf("%v", appendErr) + t.Logf("%+v", appendErr) + + stacks := Stacks(appendErr) + require.NotNil(t, stacks) + require.Len(t, stacks, 1) + st := stacks[0] + require.NotNil(t, st) + t.Logf("%+v", st.Frames) +} diff --git a/libsq/core/ioz/download/cache.go b/libsq/core/ioz/download/cache.go index e59ca8297..458e839ce 100644 --- 
a/libsq/core/ioz/download/cache.go +++ b/libsq/core/ioz/download/cache.go @@ -241,7 +241,7 @@ func (c *cache) clear(ctx context.Context) error { func (c *cache) doClear(ctx context.Context) error { deleteErr := errz.Wrap(os.RemoveAll(c.dir), "delete cache dir") recreateErr := ioz.RequireDir(c.dir) - err := errz.Combine(deleteErr, recreateErr) + err := errz.Append(deleteErr, recreateErr) if err != nil { lg.FromContext(ctx).Error(msgDeleteCache, lga.Dir, c.dir, lga.Err, err) From 686a9dae556ebff27892ba64937d3a28d34a0ffd Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 16 Dec 2023 10:34:28 -0700 Subject: [PATCH 137/195] errz: switching to integrating multierr --- cli/cmd_x.go | 4 + libsq/core/errz/errors.go | 179 -------------------------------- libsq/core/errz/errz.go | 190 ++++++++++++++++++++++++++++++---- libsq/core/errz/errz_types.go | 14 +++ libsq/core/errz/multi.go | 42 ++++++-- libsq/core/errz/multi_test.go | 16 ++- libsq/core/errz/stack.go | 62 +++++++++-- 7 files changed, 291 insertions(+), 216 deletions(-) diff --git a/cli/cmd_x.go b/cli/cmd_x.go index 1a1755312..eccdb8f29 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -188,9 +188,13 @@ func execXDownloadCmd(cmd *cobra.Command, args []string) error { switch { case len(h.Errors) > 0: + //err1 := errz.Err(h.Errors[0]) + //return err1 + err1 := h.Errors[0] err2 := errz.New("another err") return errz.Combine(err1, err2) + case len(h.WriteErrors) > 0: return h.WriteErrors[0] case len(h.CachedFiles) > 0: diff --git a/libsq/core/errz/errors.go b/libsq/core/errz/errors.go index 02f0958a5..f1ed88ea9 100644 --- a/libsq/core/errz/errors.go +++ b/libsq/core/errz/errors.go @@ -1,180 +1 @@ package errz - -// ACKNOWLEDGEMENT: The code in this file has its origins -// in Dave Cheney's pkg/errors. - -import ( - "errors" - "fmt" - "io" - "log/slog" - "strings" -) - -// New returns an error with the supplied message, recording the -// stack trace at the point it was called. 
-func New(message string) error { - return &errz{stack: callers(), msg: message} -} - -// Errorf works like [fmt.Errorf], but it also records the stack trace -// at the point it was called. If the format string includes the %w verb, -// [fmt.Errorf] is first called to construct the error, and then the -// returned error is again wrapped to record the stack trace. -func Errorf(format string, args ...any) error { - if strings.Contains(format, "%w") { - return &errz{stack: callers(), error: fmt.Errorf(format, args...)} - } - return &errz{stack: callers(), msg: fmt.Sprintf(format, args...)} -} - -// errz is the error implementation used by this package. -type errz struct { - error - msg string - *stack -} - -func (e *errz) stackTrace() *StackTrace { - if e == nil || e.stack == nil { - return nil - } - - st := e.stack.stackTrace() - if st != nil { - st.Error = e - } - return st -} - -// Error implements stdlib error interface. -func (e *errz) Error() string { - if e.msg == "" { - if e.error == nil { - return "" - } - return e.error.Error() - } - if e.error == nil { - return e.msg - } - return e.msg + ": " + e.error.Error() -} - -// LogValue implements slog.LogValuer. It returns a slog.GroupValue, -// having attributes "msg" and "type". If the error has a cause that -// from outside this package, the cause's type is included in a -// "cause" attribute. -func (e *errz) LogValue() slog.Value { - if e == nil { - return slog.Value{} - } - - attrs := make([]slog.Attr, 2, 3) - attrs[0] = slog.String("msg", e.Error()) - attrs[1] = slog.String("type", fmt.Sprintf("%T", e)) - - if cause := e.alienCause(); cause != nil { - attrs = append(attrs, slog.String("cause", fmt.Sprintf("%T", cause))) - } - - return slog.GroupValue(attrs...) -} - -// alienCause returns the first error in the chain that is -// not of type *errz, or returns nil if no such error. 
-func (e *errz) alienCause() error { - if e == nil { - return nil - } - - inner := e.error - for inner != nil { - // Note: don't use errors.As here; we want the direct type assertion. - if v, ok := inner.(*errz); ok { - inner = v.error - continue - } - return inner - } - return nil -} - -// Unwrap provides compatibility for Go 1.13 error chains. -func (e *errz) Unwrap() error { return e.error } - -// Format implements fmt.Formatter. -func (e *errz) Format(s fmt.State, verb rune) { - switch verb { - case 'v': - if s.Flag('+') { - if e.error == nil { - _, _ = io.WriteString(s, e.msg) - e.stack.Format(s, verb) - return - } else { - _, _ = fmt.Fprintf(s, "%+v", e.error) - e.stack.Format(s, verb) - } - return - } - fallthrough - case 's': - if e.error == nil { - _, _ = io.WriteString(s, e.msg) - return - } - _, _ = io.WriteString(s, e.Error()) - case 'q': - if e.error == nil { - _, _ = fmt.Fprintf(s, "{%s}", e.msg) - return - } - _, _ = fmt.Fprintf(s, "{%s}", e.Error()) - } -} - -// Wrap returns an error annotating err with a stack trace -// at the point Wrap is called, and the supplied message. -// If err is nil, Wrap returns nil. See also: Wrapf. -func Wrap(err error, message string) error { - if err == nil { - return nil - } - - return &errz{stack: callers(), error: err, msg: message} -} - -// Wrapf returns an error annotating err with a stack trace -// at the point Wrapf is called. Wrapf will panic if format -// includes the %w verb: use errz.Errorf for that situation. -// If err is nil, Wrapf returns nil. See also: Wrap, Errorf. -func Wrapf(err error, format string, args ...any) error { - if err == nil { - return nil - } - - if strings.Contains(format, "%w") { - panic("errz.Wrapf does not support %w verb: use errz.Errorf instead") - } - - return &errz{error: err, msg: fmt.Sprintf(format, args...), stack: callers()} -} - -// UnwrapChain returns the underlying *root* cause of the error. 
That is -// to say, UnwrapChain returns the final non-nil error in the error chain. -// UnwrapChain returns nil if err is nil. -func UnwrapChain(err error) error { - if err == nil { - return nil - } - - var cause error - for { - if cause = errors.Unwrap(err); cause == nil { - break - } - err = cause - } - return err -} diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go index 5ad94546f..ab83ab356 100644 --- a/libsq/core/errz/errz.go +++ b/libsq/core/errz/errz.go @@ -1,44 +1,196 @@ // Package errz is sq's error package. It exists to combine // functionality from several error packages, including -// annotating errors with stack trace. Most of it comes -// from pkg/errors. -// -// At some point this package may become redundant, particularly in -// light of the proposed stdlib multiple error support: -// https://github.com/golang/go/issues/53435 +// annotating errors with stack trace. Much of the code originated +// in Dave Cheney's pkg/errors. package errz import ( - "context" "errors" + "fmt" + "io" + "log/slog" + "strings" ) // Err annotates err with a stack trace at the point Err was called. -// If err is nil, Err returns nil. +// It is equivalent to Wrap(err, ""). If err is nil, Err returns nil. func Err(err error) error { if err == nil { return nil } - return &errz{ - err, - "", - callers(), + return &errz{stack: callers(), error: err} +} + +// New returns an error with the supplied message, recording the +// stack trace at the point it was called. +func New(message string) error { + return &errz{stack: callers(), msg: message} +} + +// Errorf works like [fmt.Errorf], but it also records the stack trace +// at the point it was called. If the format string includes the %w verb, +// [fmt.Errorf] is first called to construct the error, and then the +// returned error is again wrapped to record the stack trace. 
+func Errorf(format string, args ...any) error { + if strings.Contains(format, "%w") { + return &errz{stack: callers(), error: fmt.Errorf(format, args...)} } + return &errz{stack: callers(), msg: fmt.Sprintf(format, args...)} +} + +// errz is our error type that does the magic. +type errz struct { + error + msg string + *stack +} + +// inner implements stackTracer. +func (e *errz) inner() error { return e.error } + +// stackTrace implements stackTracer. +func (e *errz) stackTrace() *StackTrace { + if e == nil || e.stack == nil { + return nil + } + + st := e.stack.stackTrace() + if st != nil { + st.Error = e + } + return st +} + +// Error implements stdlib error interface. +func (e *errz) Error() string { + if e.msg == "" { + if e.error == nil { + return "" + } + return e.error.Error() + } + if e.error == nil { + return e.msg + } + return e.msg + ": " + e.error.Error() +} + +// LogValue implements [slog.LogValuer]. It returns a [slog.GroupValue], +// having attributes "msg" and "type". If the error has a cause that +// from outside this package, the cause's type is included in a +// "cause" attribute. +func (e *errz) LogValue() slog.Value { + if e == nil { + return slog.Value{} + } + + attrs := make([]slog.Attr, 2, 3) + attrs[0] = slog.String("msg", e.Error()) + attrs[1] = slog.String("type", fmt.Sprintf("%T", e)) + + if cause := e.alienCause(); cause != nil { + attrs = append(attrs, slog.String("cause", fmt.Sprintf("%T", cause))) + } + + return slog.GroupValue(attrs...) } -// IsErrContext returns true if err is context.Canceled or context.DeadlineExceeded. -func IsErrContext(err error) bool { +// alienCause returns the first error in the chain that is +// not of type *errz, or returns nil if no such error. +func (e *errz) alienCause() error { + if e == nil { + return nil + } + + inner := e.error + for inner != nil { + // Note: don't use errors.As here; we want the direct type assertion. 
+ if v, ok := inner.(*errz); ok { + inner = v.error + continue + } + return inner + } + return nil +} + +// Unwrap provides compatibility for Go 1.13 error chains. +func (e *errz) Unwrap() error { return e.error } + +// Format implements fmt.Formatter. +func (e *errz) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + if e.error == nil { + _, _ = io.WriteString(s, e.msg) + e.stack.Format(s, verb) + return + } else { + _, _ = fmt.Fprintf(s, "%+v", e.error) + e.stack.Format(s, verb) + } + return + } + fallthrough + case 's': + if e.error == nil { + _, _ = io.WriteString(s, e.msg) + return + } + _, _ = io.WriteString(s, e.Error()) + case 'q': + if e.error == nil { + _, _ = fmt.Fprintf(s, "{%s}", e.msg) + return + } + _, _ = fmt.Fprintf(s, "{%s}", e.Error()) + } +} + +// Wrap returns an error annotating err with a stack trace +// at the point Wrap is called, and the supplied message. +// If err is nil, Wrap returns nil. See also: Wrapf. +func Wrap(err error, message string) error { if err == nil { - return false + return nil } - errors.Join() + return &errz{stack: callers(), error: err, msg: message} +} + +// Wrapf returns an error annotating err with a stack trace +// at the point Wrapf is called. Wrapf will panic if format +// includes the %w verb: use errz.Errorf for that situation. +// If err is nil, Wrapf returns nil. See also: Wrap, Errorf. +func Wrapf(err error, format string, args ...any) error { + if err == nil { + return nil + } + + if strings.Contains(format, "%w") { + panic("errz.Wrapf does not support %w verb: use errz.Errorf instead") + } + + return &errz{error: err, msg: fmt.Sprintf(format, args...), stack: callers()} +} - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return true +// UnwrapChain returns the underlying *root* cause of the error. That is +// to say, UnwrapChain returns the final non-nil error in the error chain. +// UnwrapChain returns nil if err is nil. 
+func UnwrapChain(err error) error { + if err == nil { + return nil } - return false + var cause error + for { + if cause = errors.Unwrap(err); cause == nil { + break + } + err = cause + } + return err } // Return returns t with err wrapped via [errz.Err]. diff --git a/libsq/core/errz/errz_types.go b/libsq/core/errz/errz_types.go index 82cc7c4f0..ac2463c5b 100644 --- a/libsq/core/errz/errz_types.go +++ b/libsq/core/errz/errz_types.go @@ -1,6 +1,7 @@ package errz import ( + "context" "errors" ) @@ -68,3 +69,16 @@ func IsErrNoData(err error) bool { var e *NoDataError return errors.As(err, &e) } + +// IsErrContext returns true if err is context.Canceled or context.DeadlineExceeded. +func IsErrContext(err error) bool { + if err == nil { + return false + } + + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return true + } + + return false +} diff --git a/libsq/core/errz/multi.go b/libsq/core/errz/multi.go index ffba76909..0a83e26c2 100644 --- a/libsq/core/errz/multi.go +++ b/libsq/core/errz/multi.go @@ -27,13 +27,28 @@ import "go.uber.org/multierr" // the defer statement. func Append(left error, right error) error { switch { + case left == nil && right == nil: + return nil case left == nil: - return Err(right) + if _, ok := right.(*errz); !ok { + // It's not an errz, so we need to wrap it. + return &errz{stack: callers(), error: right} + } + return right case right == nil: - return Err(left) + if _, ok := left.(*errz); !ok { + // It's not an errz, so we need to wrap it. + return &errz{stack: callers(), error: left} + } + return left + } + + if me := multierr.Append(left, right); me == nil { + return nil + } else { + return &errz{stack: callers(), error: me} } - return Err(multierr.Append(left, right)) } // Combine combines the passed errors into a single error. @@ -43,7 +58,8 @@ func Append(left error, right error) error { // // Combine(nil, nil) // == nil // -// If only a single error was passed, it is returned as-is. 
+// If only a single error was passed, it is returned as-is if it's already +// an errz error; otherwise, it is wrapped before return. // // Combine(err) // == err // @@ -72,9 +88,23 @@ func Combine(errors ...error) error { case 0: return nil case 1: - return Err(errors[0]) + if errors[0] == nil { + return nil + } + + if _, ok := errors[0].(*errz); ok { + // It's already an errz, so we don't need to wrap it. + return errors[0] + } + + return &errz{stack: callers(), error: errors[0]} + } + + if me := multierr.Combine(errors...); me == nil { + return nil + } else { + return &errz{stack: callers(), error: me} } - return Err(multierr.Combine(errors...)) } // Errors returns a slice containing zero or more errors that the supplied diff --git a/libsq/core/errz/multi_test.go b/libsq/core/errz/multi_test.go index 62b9471ba..7f9cc07b1 100644 --- a/libsq/core/errz/multi_test.go +++ b/libsq/core/errz/multi_test.go @@ -6,7 +6,7 @@ import ( "testing" ) -func TestMultiErrors_stdlib_errors(t *testing.T) { +func TestAppend_stdlib_errors(t *testing.T) { err1 := errors.New("err1") err2 := errors.New("err2") errs := Errors(err1) @@ -25,6 +25,20 @@ func TestMultiErrors_stdlib_errors(t *testing.T) { st := stacks[0] require.NotNil(t, st) t.Logf("%+v", st.Frames) + + appendErr = Append(nil, nil) + require.Nil(t, appendErr) + + appendErr = Append(err1, nil) + require.NotNil(t, appendErr) + errs = Errors(appendErr) + require.Len(t, errs, 1) + gotErr1 := errs[0] + _, ok := gotErr1.(*errz) + require.True(t, ok) + gotErr1Unwrap := errors.Unwrap(gotErr1) + require.Equal(t, err1, gotErr1Unwrap) + } func TestMultiErrors_errz(t *testing.T) { diff --git a/libsq/core/errz/stack.go b/libsq/core/errz/stack.go index 66dd8b7db..ab4c63518 100644 --- a/libsq/core/errz/stack.go +++ b/libsq/core/errz/stack.go @@ -181,6 +181,11 @@ func (s *stack) Format(st fmt.State, verb rune) { } } +type stackTracer interface { + stackTrace() *StackTrace + inner() error +} + func (s *stack) stackTrace() *StackTrace { f := 
make([]Frame, len(*s)) for i := 0; i < len(f); i++ { @@ -218,13 +223,14 @@ func Stacks(err error) []*StackTrace { var stacks []*StackTrace for err != nil { - if ez, ok := err.(*errz); ok { - st := ez.stackTrace() + if tracer, ok := err.(stackTracer); ok { + st := tracer.stackTrace() if st != nil { stacks = append(stacks, st) } } + //err = errors.Unwrap(err) err = errors.Unwrap(err) } @@ -243,24 +249,58 @@ func LastStack(err error) *StackTrace { return nil } - var ez *errz - var ok bool + var ( + //var ez *errz + ok bool + tracer stackTracer + inner error + ) for err != nil { - ez, ok = err.(*errz) - if !ok || ez == nil { + tracer, ok = err.(stackTracer) + if !ok || tracer == nil { return nil } - if ez.error == nil { - return ez.stackTrace() + inner = tracer.inner() + if inner == nil { + return tracer.stackTrace() } - if _, ok = ez.error.(*errz); !ok { - return ez.stackTrace() + if _, ok = inner.(stackTracer); !ok { + return tracer.stackTrace() } - err = ez.error + err = inner } return nil } + +// +//func LastStackLegacy(err error) *StackTrace { // FIXME: delete this +// if err == nil { +// return nil +// } +// +// var ez *errz +// var ok bool +// var tracer stackTracer +// for err != nil { +// tracer, ok = err.(stackTracer) +// if !ok || ez == nil { +// return nil +// } +// +// if ez.error == nil { +// return ez.stackTrace() +// } +// +// if _, ok = ez.error.(*errz); !ok { +// return ez.stackTrace() +// } +// +// err = ez.error +// } +// +// return nil +//} From 3c078b2fdd8adb63dd27aa04ef7fc37673e8f6f2 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 16 Dec 2023 10:34:40 -0700 Subject: [PATCH 138/195] errz: switching to integrating multierr --- libsq/core/errz/errors.go | 1 - 1 file changed, 1 deletion(-) delete mode 100644 libsq/core/errz/errors.go diff --git a/libsq/core/errz/errors.go b/libsq/core/errz/errors.go deleted file mode 100644 index f1ed88ea9..000000000 --- a/libsq/core/errz/errors.go +++ /dev/null @@ -1 +0,0 @@ -package errz From 
25dc1627067729e8d63edf3bed9ebbbc27c24fa0 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 16 Dec 2023 11:06:33 -0700 Subject: [PATCH 139/195] multierr integrated --- libsq/core/errz/errz.go | 23 +- libsq/core/errz/multi.go | 45 ++- libsq/core/errz/multi_test.go | 4 +- libsq/core/errz/multierr.go | 497 +++++++++++++++++++++++++ libsq/core/errz/multierr_post_go120.go | 23 ++ libsq/core/errz/stack.go | 5 +- 6 files changed, 569 insertions(+), 28 deletions(-) create mode 100644 libsq/core/errz/multierr.go create mode 100644 libsq/core/errz/multierr_post_go120.go diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go index ab83ab356..a6f591652 100644 --- a/libsq/core/errz/errz.go +++ b/libsq/core/errz/errz.go @@ -1,7 +1,10 @@ -// Package errz is sq's error package. It exists to combine -// functionality from several error packages, including -// annotating errors with stack trace. Much of the code originated -// in Dave Cheney's pkg/errors. +// Package errz is sq's error package. It annotates errors with stack traces, +// and provides functionality for working with multiple errors, and error +// chains. +// +// This package is the lovechild of Dave Cheney's pkg/errors and +// Uber's go.uber.org/multierr, and much of the code is borrowed +// from those packages. package errz import ( @@ -18,13 +21,13 @@ func Err(err error) error { if err == nil { return nil } - return &errz{stack: callers(), error: err} + return &errz{stack: callers(0), error: err} } // New returns an error with the supplied message, recording the // stack trace at the point it was called. func New(message string) error { - return &errz{stack: callers(), msg: message} + return &errz{stack: callers(0), msg: message} } // Errorf works like [fmt.Errorf], but it also records the stack trace @@ -33,9 +36,9 @@ func New(message string) error { // returned error is again wrapped to record the stack trace. 
func Errorf(format string, args ...any) error { if strings.Contains(format, "%w") { - return &errz{stack: callers(), error: fmt.Errorf(format, args...)} + return &errz{stack: callers(0), error: fmt.Errorf(format, args...)} } - return &errz{stack: callers(), msg: fmt.Sprintf(format, args...)} + return &errz{stack: callers(0), msg: fmt.Sprintf(format, args...)} } // errz is our error type that does the magic. @@ -156,7 +159,7 @@ func Wrap(err error, message string) error { return nil } - return &errz{stack: callers(), error: err, msg: message} + return &errz{stack: callers(0), error: err, msg: message} } // Wrapf returns an error annotating err with a stack trace @@ -172,7 +175,7 @@ func Wrapf(err error, format string, args ...any) error { panic("errz.Wrapf does not support %w verb: use errz.Errorf instead") } - return &errz{error: err, msg: fmt.Sprintf(format, args...), stack: callers()} + return &errz{error: err, msg: fmt.Sprintf(format, args...), stack: callers(0)} } // UnwrapChain returns the underlying *root* cause of the error. That is diff --git a/libsq/core/errz/multi.go b/libsq/core/errz/multi.go index 0a83e26c2..b3c452a2c 100644 --- a/libsq/core/errz/multi.go +++ b/libsq/core/errz/multi.go @@ -4,10 +4,10 @@ package errz // package was in use before stdlib introduced the errors.Join function. // It's possible, maybe even desirable, to refactor these functions // to use stdlib errors instead. - +// import "go.uber.org/multierr" -// Append appends the given errors together. Either value may be nil. +// Append2 appends the given errors together. Either value may be nil. // // This function is a specialization of Combine for the common case where // there are only two errors. @@ -25,20 +25,20 @@ import "go.uber.org/multierr" // // Note that the variable MUST be a named return to append an error to it from // the defer statement. 
-func Append(left error, right error) error { +func Append2(left error, right error) error { switch { case left == nil && right == nil: return nil case left == nil: if _, ok := right.(*errz); !ok { // It's not an errz, so we need to wrap it. - return &errz{stack: callers(), error: right} + return &errz{stack: callers(0), error: right} } return right case right == nil: if _, ok := left.(*errz); !ok { // It's not an errz, so we need to wrap it. - return &errz{stack: callers(), error: left} + return &errz{stack: callers(0), error: left} } return left } @@ -46,12 +46,12 @@ func Append(left error, right error) error { if me := multierr.Append(left, right); me == nil { return nil } else { - return &errz{stack: callers(), error: me} + return &errz{stack: callers(0), error: me} } } -// Combine combines the passed errors into a single error. +// Combine2 combines the passed errors into a single error. // // If zero arguments were passed or if all items are nil, a nil error is // returned. @@ -83,7 +83,7 @@ func Append(left error, right error) error { // formatted with %+v. // // fmt.Sprintf("%+v", errz.Combine(err1, err2)) -func Combine(errors ...error) error { +func Combine2(errors ...error) error { switch len(errors) { case 0: return nil @@ -97,17 +97,17 @@ func Combine(errors ...error) error { return errors[0] } - return &errz{stack: callers(), error: errors[0]} + return &errz{stack: callers(0), error: errors[0]} } if me := multierr.Combine(errors...); me == nil { return nil } else { - return &errz{stack: callers(), error: me} + return &errz{stack: callers(0), error: me} } } -// Errors returns a slice containing zero or more errors that the supplied +// Errors2 returns a slice containing zero or more errors that the supplied // error is composed of. If the error is nil, a nil slice is returned. // // err := errz.Append(r.Close(), w.Close()) @@ -117,7 +117,7 @@ func Combine(errors ...error) error { // just the error that was passed in. 
// // Callers of this function are free to modify the returned slice. -func Errors(err error) []error { +func Errors2(err error) []error { if err == nil { return nil } @@ -143,6 +143,23 @@ func Errors(err error) []error { return multierr.Errors(alien) } -type multipleErrors interface { - Unwrap() []error +// +//type multipleErrors interface { +// Unwrap() []error +//} + +// inner implements stackTracer. +func (merr *multiError) inner() error { return nil } + +// stackTrace implements stackTracer. +func (merr *multiError) stackTrace() *StackTrace { + if merr == nil || merr.stack == nil { + return nil + } + + st := merr.stack.stackTrace() + if st != nil { + st.Error = merr + } + return st } diff --git a/libsq/core/errz/multi_test.go b/libsq/core/errz/multi_test.go index 7f9cc07b1..51026cffe 100644 --- a/libsq/core/errz/multi_test.go +++ b/libsq/core/errz/multi_test.go @@ -50,8 +50,8 @@ func TestMultiErrors_errz(t *testing.T) { appendErr := Append(err1, err2) errs = Errors(appendErr) require.Equal(t, []error{err1, err2}, errs) - t.Logf("%v", appendErr) - t.Logf("%+v", appendErr) + //t.Logf("%v", appendErr) + //t.Logf("%+v", appendErr) stacks := Stacks(appendErr) require.NotNil(t, stacks) diff --git a/libsq/core/errz/multierr.go b/libsq/core/errz/multierr.go new file mode 100644 index 000000000..a4cbc5220 --- /dev/null +++ b/libsq/core/errz/multierr.go @@ -0,0 +1,497 @@ +package errz + +// ACKNOWLEDGEMENT: This code is lifted from uber's multierr package. + +import ( + "bytes" + "errors" + "fmt" + "io" + "strings" + "sync" + "sync/atomic" +) + +// Copyright (c) 2017-2023 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +// Package multierr allows combining one or more errors together. +// +// # Overview +// +// Errors can be combined with the use of the Combine function. +// +// multierr.Combine( +// reader.Close(), +// writer.Close(), +// conn.Close(), +// ) +// +// If only two errors are being combined, the Append function may be used +// instead. +// +// err = multierr.Append(reader.Close(), writer.Close()) +// +// The underlying list of errors for a returned error object may be retrieved +// with the Errors function. +// +// errors := multierr.Errors(err) +// if len(errors) > 0 { +// fmt.Println("The following errors occurred:", errors) +// } +// +// # Appending from a loop +// +// You sometimes need to append into an error from a loop. 
+// +// var err error +// for _, item := range items { +// err = multierr.Append(err, process(item)) +// } +// +// Cases like this may require knowledge of whether an individual instance +// failed. This usually requires introduction of a new variable. +// +// var err error +// for _, item := range items { +// if perr := process(item); perr != nil { +// log.Warn("skipping item", item) +// err = multierr.Append(err, perr) +// } +// } +// +// multierr includes AppendInto to simplify cases like this. +// +// var err error +// for _, item := range items { +// if multierr.AppendInto(&err, process(item)) { +// log.Warn("skipping item", item) +// } +// } +// +// This will append the error into the err variable, and return true if that +// individual error was non-nil. +// +// See [AppendInto] for more information. +// +// # Deferred Functions +// +// Go makes it possible to modify the return value of a function in a defer +// block if the function was using named returns. This makes it possible to +// record resource cleanup failures from deferred blocks. +// +// func sendRequest(req Request) (err error) { +// conn, err := openConnection() +// if err != nil { +// return err +// } +// defer func() { +// err = multierr.Append(err, conn.Close()) +// }() +// // ... +// } +// +// multierr provides the Invoker type and AppendInvoke function to make cases +// like the above simpler and obviate the need for a closure. The following is +// roughly equivalent to the example above. +// +// func sendRequest(req Request) (err error) { +// conn, err := openConnection() +// if err != nil { +// return err +// } +// defer multierr.AppendInvoke(&err, multierr.Close(conn)) +// // ... +// } +// +// See [AppendInvoke] and [Invoker] for more information. +// +// NOTE: If you're modifying an error from inside a defer, you MUST use a named +// return value for that function. +// +// # Advanced Usage +// +// Errors returned by Combine and Append MAY implement the following +// interface. 
+// +// type errorGroup interface { +// // Returns a slice containing the underlying list of errors. +// // +// // This slice MUST NOT be modified by the caller. +// Errors() []error +// } +// +// Note that if you need access to list of errors behind a multierr error, you +// should prefer using the Errors function. That said, if you need cheap +// read-only access to the underlying errors slice, you can attempt to cast +// the error to this interface. You MUST handle the failure case gracefully +// because errors returned by Combine and Append are not guaranteed to +// implement this interface. +// +// var errors []error +// group, ok := err.(errorGroup) +// if ok { +// errors = group.Errors() +// } else { +// errors = []error{err} +// } + +var ( + // Separator for single-line error messages. + _singlelineSeparator = []byte("; ") + + // Prefix for multi-line messages + _multilinePrefix = []byte("the following errors occurred:") + + // Prefix for the first and following lines of an item in a list of + // multi-line error messages. + // + // For example, if a single item is: + // + // foo + // bar + // + // It will become, + // + // - foo + // bar + _multilineSeparator = []byte("\n - ") + _multilineIndent = []byte(" ") +) + +// _bufferPool is a pool of bytes.Buffers. +var _bufferPool = sync.Pool{ + New: func() interface{} { + return &bytes.Buffer{} + }, +} + +type errorGroup interface { + Errors() []error +} + +// Errors returns a slice containing zero or more errors that the supplied +// error is composed of. If the error is nil, a nil slice is returned. +// +// err := multierr.Append(r.Close(), w.Close()) +// errors := multierr.Errors(err) +// +// If the error is not composed of other errors, the returned slice contains +// just the error that was passed in. +// +// Callers of this function are free to modify the returned slice. +func Errors(err error) []error { + return extractErrors(err) +} + +// multiError is an error that holds one or more errors. 
+// +// An instance of this is guaranteed to be non-empty and flattened. That is, +// none of the errors inside multiError are other multiErrors. +// +// multiError formats to a semi-colon delimited list of error messages with +// %v and with a more readable multi-line format with %+v. +type multiError struct { + copyNeeded atomic.Bool + errors []error + *stack +} + +// Errors returns the list of underlying errors. +// +// This slice MUST NOT be modified. +func (merr *multiError) Errors() []error { + if merr == nil { + return nil + } + return merr.errors +} + +func (merr *multiError) Error() string { + if merr == nil { + return "" + } + + buff := _bufferPool.Get().(*bytes.Buffer) + buff.Reset() + + merr.writeSingleline(buff) + + result := buff.String() + _bufferPool.Put(buff) + return result +} + +// Every compares every error in the given err against the given target error +// using [errors.Is], and returns true only if every comparison returned true. +func Every(err error, target error) bool { + for _, e := range extractErrors(err) { + if !errors.Is(e, target) { + return false + } + } + return true +} + +func (merr *multiError) Format(f fmt.State, c rune) { + if c == 'v' && f.Flag('+') { + merr.writeMultiline(f) + } else { + merr.writeSingleline(f) + } +} + +func (merr *multiError) writeSingleline(w io.Writer) { + first := true + for _, item := range merr.errors { + if first { + first = false + } else { + w.Write(_singlelineSeparator) + } + io.WriteString(w, item.Error()) + } +} + +func (merr *multiError) writeMultiline(w io.Writer) { + w.Write(_multilinePrefix) + for _, item := range merr.errors { + w.Write(_multilineSeparator) + writePrefixLine(w, _multilineIndent, fmt.Sprintf("%+v", item)) + } +} + +// Writes s to the writer with the given prefix added before each line after +// the first. 
+func writePrefixLine(w io.Writer, prefix []byte, s string) { + first := true + for len(s) > 0 { + if first { + first = false + } else { + w.Write(prefix) + } + + idx := strings.IndexByte(s, '\n') + if idx < 0 { + idx = len(s) - 1 + } + + io.WriteString(w, s[:idx+1]) + s = s[idx+1:] + } +} + +type inspectResult struct { + // Number of top-level non-nil errors + Count int + + // Total number of errors including multiErrors + Capacity int + + // Index of the first non-nil error in the list. Value is meaningless if + // Count is zero. + FirstErrorIdx int + + // Whether the list contains at least one multiError + ContainsMultiError bool +} + +// Inspects the given slice of errors so that we can efficiently allocate +// space for it. +func inspect(errors []error) (res inspectResult) { + first := true + for i, err := range errors { + if err == nil { + continue + } + + res.Count++ + if first { + first = false + res.FirstErrorIdx = i + } + + if merr, ok := err.(*multiError); ok { + res.Capacity += len(merr.errors) + res.ContainsMultiError = true + } else { + res.Capacity++ + } + } + return +} + +// fromSlice converts the given list of errors into a single error. +func fromSlice(errors []error) error { + // Don't pay to inspect small slices. + switch len(errors) { + case 0: + return nil + case 1: + return errors[0] + } + + res := inspect(errors) + switch res.Count { + case 0: + return nil + case 1: + // only one non-nil entry + return errors[res.FirstErrorIdx] + case len(errors): + if !res.ContainsMultiError { + // Error list is flat. Make a copy of it + // Otherwise "errors" escapes to the heap + // unconditionally for all other cases. + // This lets us optimize for the "no errors" case. + out := append(([]error)(nil), errors...) 
+ return &multiError{errors: out, stack: callers(1)} + } + } + + nonNilErrs := make([]error, 0, res.Capacity) + for _, err := range errors[res.FirstErrorIdx:] { + if err == nil { + continue + } + + if nested, ok := err.(*multiError); ok { + nonNilErrs = append(nonNilErrs, nested.errors...) + } else { + nonNilErrs = append(nonNilErrs, err) + } + } + + return &multiError{errors: nonNilErrs, stack: callers(0)} +} + +// Combine combines the passed errors into a single error. +// +// If zero arguments were passed or if all items are nil, a nil error is +// returned. +// +// Combine(nil, nil) // == nil +// +// If only a single error was passed, it is returned as-is. +// +// Combine(err) // == err +// +// Combine skips over nil arguments so this function may be used to combine +// together errors from operations that fail independently of each other. +// +// multierr.Combine( +// reader.Close(), +// writer.Close(), +// pipe.Close(), +// ) +// +// If any of the passed errors is a multierr error, it will be flattened along +// with the other errors. +// +// multierr.Combine(multierr.Combine(err1, err2), err3) +// // is the same as +// multierr.Combine(err1, err2, err3) +// +// The returned error formats into a readable multi-line error message if +// formatted with %+v. +// +// fmt.Sprintf("%+v", multierr.Combine(err1, err2)) +func Combine(errors ...error) error { + return fromSlice(errors) +} + +// Append appends the given errors together. Either value may be nil. +// +// This function is a specialization of Combine for the common case where +// there are only two errors. +// +// err = multierr.Append(reader.Close(), writer.Close()) +// +// The following pattern may also be used to record failure of deferred +// operations without losing information about the original error. +// +// func doSomething(..) 
(err error) { +// f := acquireResource() +// defer func() { +// err = multierr.Append(err, f.Close()) +// }() +// +// Note that the variable MUST be a named return to append an error to it from +// the defer statement. +func Append(left error, right error) error { + switch { + case left == nil && right == nil: + return nil + case left == nil: + if _, ok := right.(*errz); !ok { + // It's not an errz, so we need to wrap it. + return &errz{stack: callers(0), error: right} + } + return right + case right == nil: + if _, ok := left.(*errz); !ok { + // It's not an errz, so we need to wrap it. + return &errz{stack: callers(0), error: left} + } + return left + } + + if _, ok := right.(*multiError); !ok { + if l, ok := left.(*multiError); ok && !l.copyNeeded.Swap(true) { + // Common case where the error on the left is constantly being + // appended to. + errs := append(l.errors, right) + return &multiError{errors: errs, stack: callers(0)} + } else if !ok { + // Both errors are single errors. + return &multiError{errors: []error{left, right}, stack: callers(0)} + } + } + + // Either right or both, left and right, are multiErrors. Rely on usual + // expensive logic. + errors := [2]error{left, right} + return fromSlice(errors[0:]) +} + +// Unwrap returns a list of errors wrapped by this multierr. +func (merr *multiError) Unwrap() []error { + return merr.Errors() +} + +type multipleErrors interface { + Unwrap() []error +} + +func extractErrors(err error) []error { + if err == nil { + return nil + } + + // check if the given err is an Unwrapable error that + // implements multipleErrors interface. + eg, ok := err.(multipleErrors) + if !ok { + return []error{err} + } + + return append(([]error)(nil), eg.Unwrap()...) 
+} diff --git a/libsq/core/errz/multierr_post_go120.go b/libsq/core/errz/multierr_post_go120.go new file mode 100644 index 000000000..c2736c6b7 --- /dev/null +++ b/libsq/core/errz/multierr_post_go120.go @@ -0,0 +1,23 @@ +// Copyright (c) 2017-2023 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. 
+ +//go:build go1.20 + +package errz diff --git a/libsq/core/errz/stack.go b/libsq/core/errz/stack.go index ab4c63518..c525dd160 100644 --- a/libsq/core/errz/stack.go +++ b/libsq/core/errz/stack.go @@ -194,11 +194,12 @@ func (s *stack) stackTrace() *StackTrace { return &StackTrace{Frames: f} } -func callers() *stack { +func callers(skip int) *stack { const depth = 32 var pcs [depth]uintptr n := runtime.Callers(3, pcs[:]) - var st stack = pcs[0:n] + //var st stack = pcs[0:n] + var st stack = pcs[skip:n] return &st } From a250e9a0533ba567796e0f75387a14c52ed323fb Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 16 Dec 2023 11:07:58 -0700 Subject: [PATCH 140/195] errz cleanup --- libsq/core/errz/multi.go | 548 +++++++++++++++++++++++++++++------- libsq/core/errz/multierr.go | 496 -------------------------------- 2 files changed, 448 insertions(+), 596 deletions(-) diff --git a/libsq/core/errz/multi.go b/libsq/core/errz/multi.go index b3c452a2c..9f32822d2 100644 --- a/libsq/core/errz/multi.go +++ b/libsq/core/errz/multi.go @@ -1,72 +1,419 @@ package errz -// These multi-error functions delegate to go.uber.org/multierr. That -// package was in use before stdlib introduced the errors.Join function. -// It's possible, maybe even desirable, to refactor these functions -// to use stdlib errors instead. +// ACKNOWLEDGEMENT: This code is lifted from uber's multierr package. + +import ( + "bytes" + "errors" + "fmt" + "io" + "strings" + "sync" + "sync/atomic" +) + +// Copyright (c) 2017-2023 Uber Technologies, Inc. 
+// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. // -import "go.uber.org/multierr" +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. -// Append2 appends the given errors together. Either value may be nil. +// Package multierr allows combining one or more errors together. // -// This function is a specialization of Combine for the common case where -// there are only two errors. +// # Overview // -// err = errz.Append(reader.Close(), writer.Close()) +// Errors can be combined with the use of the Combine function. // -// The following pattern may also be used to record failure of deferred -// operations without losing information about the original error. +// multierr.Combine( +// reader.Close(), +// writer.Close(), +// conn.Close(), +// ) // -// func doSomething(..) (err error) { -// f := acquireResource() +// If only two errors are being combined, the Append function may be used +// instead. 
+// +// err = multierr.Append(reader.Close(), writer.Close()) +// +// The underlying list of errors for a returned error object may be retrieved +// with the Errors function. +// +// errors := multierr.Errors(err) +// if len(errors) > 0 { +// fmt.Println("The following errors occurred:", errors) +// } +// +// # Appending from a loop +// +// You sometimes need to append into an error from a loop. +// +// var err error +// for _, item := range items { +// err = multierr.Append(err, process(item)) +// } +// +// Cases like this may require knowledge of whether an individual instance +// failed. This usually requires introduction of a new variable. +// +// var err error +// for _, item := range items { +// if perr := process(item); perr != nil { +// log.Warn("skipping item", item) +// err = multierr.Append(err, perr) +// } +// } +// +// multierr includes AppendInto to simplify cases like this. +// +// var err error +// for _, item := range items { +// if multierr.AppendInto(&err, process(item)) { +// log.Warn("skipping item", item) +// } +// } +// +// This will append the error into the err variable, and return true if that +// individual error was non-nil. +// +// See [AppendInto] for more information. +// +// # Deferred Functions +// +// Go makes it possible to modify the return value of a function in a defer +// block if the function was using named returns. This makes it possible to +// record resource cleanup failures from deferred blocks. +// +// func sendRequest(req Request) (err error) { +// conn, err := openConnection() +// if err != nil { +// return err +// } // defer func() { -// err = errz.Append(err, f.Close()) +// err = multierr.Append(err, conn.Close()) // }() +// // ... +// } // -// Note that the variable MUST be a named return to append an error to it from -// the defer statement. 
-func Append2(left error, right error) error { - switch { - case left == nil && right == nil: +// multierr provides the Invoker type and AppendInvoke function to make cases +// like the above simpler and obviate the need for a closure. The following is +// roughly equivalent to the example above. +// +// func sendRequest(req Request) (err error) { +// conn, err := openConnection() +// if err != nil { +// return err +// } +// defer multierr.AppendInvoke(&err, multierr.Close(conn)) +// // ... +// } +// +// See [AppendInvoke] and [Invoker] for more information. +// +// NOTE: If you're modifying an error from inside a defer, you MUST use a named +// return value for that function. +// +// # Advanced Usage +// +// Errors returned by Combine and Append MAY implement the following +// interface. +// +// type errorGroup interface { +// // Returns a slice containing the underlying list of errors. +// // +// // This slice MUST NOT be modified by the caller. +// Errors() []error +// } +// +// Note that if you need access to list of errors behind a multierr error, you +// should prefer using the Errors function. That said, if you need cheap +// read-only access to the underlying errors slice, you can attempt to cast +// the error to this interface. You MUST handle the failure case gracefully +// because errors returned by Combine and Append are not guaranteed to +// implement this interface. +// +// var errors []error +// group, ok := err.(errorGroup) +// if ok { +// errors = group.Errors() +// } else { +// errors = []error{err} +// } + +var ( + // Separator for single-line error messages. + _singlelineSeparator = []byte("; ") + + // Prefix for multi-line messages + _multilinePrefix = []byte("the following errors occurred:") + + // Prefix for the first and following lines of an item in a list of + // multi-line error messages. 
+ // + // For example, if a single item is: + // + // foo + // bar + // + // It will become, + // + // - foo + // bar + _multilineSeparator = []byte("\n - ") + _multilineIndent = []byte(" ") +) + +// _bufferPool is a pool of bytes.Buffers. +var _bufferPool = sync.Pool{ + New: func() interface{} { + return &bytes.Buffer{} + }, +} + +type errorGroup interface { + Errors() []error +} + +// Errors returns a slice containing zero or more errors that the supplied +// error is composed of. If the error is nil, a nil slice is returned. +// +// err := multierr.Append(r.Close(), w.Close()) +// errors := multierr.Errors(err) +// +// If the error is not composed of other errors, the returned slice contains +// just the error that was passed in. +// +// Callers of this function are free to modify the returned slice. +func Errors(err error) []error { + return extractErrors(err) +} + +// multiError is an error that holds one or more errors. +// +// An instance of this is guaranteed to be non-empty and flattened. That is, +// none of the errors inside multiError are other multiErrors. +// +// multiError formats to a semi-colon delimited list of error messages with +// %v and with a more readable multi-line format with %+v. +type multiError struct { + copyNeeded atomic.Bool + errors []error + *stack +} + +// inner implements stackTracer. +func (merr *multiError) inner() error { return nil } + +// stackTrace implements stackTracer. +func (merr *multiError) stackTrace() *StackTrace { + if merr == nil || merr.stack == nil { return nil - case left == nil: - if _, ok := right.(*errz); !ok { - // It's not an errz, so we need to wrap it. - return &errz{stack: callers(0), error: right} + } + + st := merr.stack.stackTrace() + if st != nil { + st.Error = merr + } + return st +} + +// Errors returns the list of underlying errors. +// +// This slice MUST NOT be modified. 
+func (merr *multiError) Errors() []error { + if merr == nil { + return nil + } + return merr.errors +} + +func (merr *multiError) Error() string { + if merr == nil { + return "" + } + + buff := _bufferPool.Get().(*bytes.Buffer) + buff.Reset() + + merr.writeSingleline(buff) + + result := buff.String() + _bufferPool.Put(buff) + return result +} + +// Every compares every error in the given err against the given target error +// using [errors.Is], and returns true only if every comparison returned true. +func Every(err error, target error) bool { + for _, e := range extractErrors(err) { + if !errors.Is(e, target) { + return false } - return right - case right == nil: - if _, ok := left.(*errz); !ok { - // It's not an errz, so we need to wrap it. - return &errz{stack: callers(0), error: left} + } + return true +} + +func (merr *multiError) Format(f fmt.State, c rune) { + if c == 'v' && f.Flag('+') { + merr.writeMultiline(f) + } else { + merr.writeSingleline(f) + } +} + +func (merr *multiError) writeSingleline(w io.Writer) { + first := true + for _, item := range merr.errors { + if first { + first = false + } else { + w.Write(_singlelineSeparator) } - return left + io.WriteString(w, item.Error()) + } +} + +func (merr *multiError) writeMultiline(w io.Writer) { + w.Write(_multilinePrefix) + for _, item := range merr.errors { + w.Write(_multilineSeparator) + writePrefixLine(w, _multilineIndent, fmt.Sprintf("%+v", item)) + } +} + +// Writes s to the writer with the given prefix added before each line after +// the first. 
+func writePrefixLine(w io.Writer, prefix []byte, s string) { + first := true + for len(s) > 0 { + if first { + first = false + } else { + w.Write(prefix) + } + + idx := strings.IndexByte(s, '\n') + if idx < 0 { + idx = len(s) - 1 + } + + io.WriteString(w, s[:idx+1]) + s = s[idx+1:] } +} + +type inspectResult struct { + // Number of top-level non-nil errors + Count int - if me := multierr.Append(left, right); me == nil { + // Total number of errors including multiErrors + Capacity int + + // Index of the first non-nil error in the list. Value is meaningless if + // Count is zero. + FirstErrorIdx int + + // Whether the list contains at least one multiError + ContainsMultiError bool +} + +// Inspects the given slice of errors so that we can efficiently allocate +// space for it. +func inspect(errors []error) (res inspectResult) { + first := true + for i, err := range errors { + if err == nil { + continue + } + + res.Count++ + if first { + first = false + res.FirstErrorIdx = i + } + + if merr, ok := err.(*multiError); ok { + res.Capacity += len(merr.errors) + res.ContainsMultiError = true + } else { + res.Capacity++ + } + } + return +} + +// fromSlice converts the given list of errors into a single error. +func fromSlice(errors []error) error { + // Don't pay to inspect small slices. + switch len(errors) { + case 0: return nil - } else { - return &errz{stack: callers(0), error: me} + case 1: + return errors[0] } + res := inspect(errors) + switch res.Count { + case 0: + return nil + case 1: + // only one non-nil entry + return errors[res.FirstErrorIdx] + case len(errors): + if !res.ContainsMultiError { + // Error list is flat. Make a copy of it + // Otherwise "errors" escapes to the heap + // unconditionally for all other cases. + // This lets us optimize for the "no errors" case. + out := append(([]error)(nil), errors...) 
+ return &multiError{errors: out, stack: callers(1)} + } + } + + nonNilErrs := make([]error, 0, res.Capacity) + for _, err := range errors[res.FirstErrorIdx:] { + if err == nil { + continue + } + + if nested, ok := err.(*multiError); ok { + nonNilErrs = append(nonNilErrs, nested.errors...) + } else { + nonNilErrs = append(nonNilErrs, err) + } + } + + return &multiError{errors: nonNilErrs, stack: callers(0)} } -// Combine2 combines the passed errors into a single error. +// Combine combines the passed errors into a single error. // // If zero arguments were passed or if all items are nil, a nil error is // returned. // // Combine(nil, nil) // == nil // -// If only a single error was passed, it is returned as-is if it's already -// an errz error; otherwise, it is wrapped before return. +// If only a single error was passed, it is returned as-is. // // Combine(err) // == err // // Combine skips over nil arguments so this function may be used to combine // together errors from operations that fail independently of each other. // -// errz.Combine( +// multierr.Combine( // reader.Close(), // writer.Close(), // pipe.Close(), @@ -75,91 +422,92 @@ func Append2(left error, right error) error { // If any of the passed errors is a multierr error, it will be flattened along // with the other errors. // -// errz.Combine(errz.Combine(err1, err2), err3) +// multierr.Combine(multierr.Combine(err1, err2), err3) // // is the same as -// errz.Combine(err1, err2, err3) +// multierr.Combine(err1, err2, err3) // // The returned error formats into a readable multi-line error message if // formatted with %+v. // -// fmt.Sprintf("%+v", errz.Combine(err1, err2)) -func Combine2(errors ...error) error { - switch len(errors) { - case 0: - return nil - case 1: - if errors[0] == nil { - return nil - } - - if _, ok := errors[0].(*errz); ok { - // It's already an errz, so we don't need to wrap it. 
- return errors[0] - } - - return &errz{stack: callers(0), error: errors[0]} - } - - if me := multierr.Combine(errors...); me == nil { - return nil - } else { - return &errz{stack: callers(0), error: me} - } +// fmt.Sprintf("%+v", multierr.Combine(err1, err2)) +func Combine(errors ...error) error { + return fromSlice(errors) } -// Errors2 returns a slice containing zero or more errors that the supplied -// error is composed of. If the error is nil, a nil slice is returned. +// Append appends the given errors together. Either value may be nil. // -// err := errz.Append(r.Close(), w.Close()) -// errors := errz.Errors(err) +// This function is a specialization of Combine for the common case where +// there are only two errors. // -// If the error is not composed of other errors, the returned slice contains -// just the error that was passed in. +// err = multierr.Append(reader.Close(), writer.Close()) // -// Callers of this function are free to modify the returned slice. -func Errors2(err error) []error { - if err == nil { +// The following pattern may also be used to record failure of deferred +// operations without losing information about the original error. +// +// func doSomething(..) (err error) { +// f := acquireResource() +// defer func() { +// err = multierr.Append(err, f.Close()) +// }() +// +// Note that the variable MUST be a named return to append an error to it from +// the defer statement. +func Append(left error, right error) error { + switch { + case left == nil && right == nil: return nil + case left == nil: + if _, ok := right.(*errz); !ok { + // It's not an errz, so we need to wrap it. + return &errz{stack: callers(0), error: right} + } + return right + case right == nil: + if _, ok := left.(*errz); !ok { + // It's not an errz, so we need to wrap it. 
+ return &errz{stack: callers(0), error: left} + } + return left } - if me, ok := err.(multipleErrors); ok { - return me.Unwrap() - } - - ez, ok := err.(*errz) - if !ok { - return multierr.Errors(err) - } - - // It's an errz, so let's see what's underneath. - alien := ez.alienCause() - if alien == nil { - // It's not an alien error, it's just a pure errz error. - // It can't be a multi error. - return []error{err} + if _, ok := right.(*multiError); !ok { + if l, ok := left.(*multiError); ok && !l.copyNeeded.Swap(true) { + // Common case where the error on the left is constantly being + // appended to. + errs := append(l.errors, right) + return &multiError{errors: errs, stack: callers(0)} + } else if !ok { + // Both errors are single errors. + return &multiError{errors: []error{left, right}, stack: callers(0)} + } } - // It's a foreign error, so we let multierr take care of it. - return multierr.Errors(alien) + // Either right or both, left and right, are multiErrors. Rely on usual + // expensive logic. + errors := [2]error{left, right} + return fromSlice(errors[0:]) } -// -//type multipleErrors interface { -// Unwrap() []error -//} +// Unwrap returns a list of errors wrapped by this multierr. +func (merr *multiError) Unwrap() []error { + return merr.Errors() +} -// inner implements stackTracer. -func (merr *multiError) inner() error { return nil } +type multipleErrors interface { + Unwrap() []error +} -// stackTrace implements stackTracer. -func (merr *multiError) stackTrace() *StackTrace { - if merr == nil || merr.stack == nil { +func extractErrors(err error) []error { + if err == nil { return nil } - st := merr.stack.stackTrace() - if st != nil { - st.Error = merr + // check if the given err is an Unwrapable error that + // implements multipleErrors interface. + eg, ok := err.(multipleErrors) + if !ok { + return []error{err} } - return st + + return append(([]error)(nil), eg.Unwrap()...) 
} diff --git a/libsq/core/errz/multierr.go b/libsq/core/errz/multierr.go index a4cbc5220..f1ed88ea9 100644 --- a/libsq/core/errz/multierr.go +++ b/libsq/core/errz/multierr.go @@ -1,497 +1 @@ package errz - -// ACKNOWLEDGEMENT: This code is lifted from uber's multierr package. - -import ( - "bytes" - "errors" - "fmt" - "io" - "strings" - "sync" - "sync/atomic" -) - -// Copyright (c) 2017-2023 Uber Technologies, Inc. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. - -// Package multierr allows combining one or more errors together. -// -// # Overview -// -// Errors can be combined with the use of the Combine function. -// -// multierr.Combine( -// reader.Close(), -// writer.Close(), -// conn.Close(), -// ) -// -// If only two errors are being combined, the Append function may be used -// instead. 
-// -// err = multierr.Append(reader.Close(), writer.Close()) -// -// The underlying list of errors for a returned error object may be retrieved -// with the Errors function. -// -// errors := multierr.Errors(err) -// if len(errors) > 0 { -// fmt.Println("The following errors occurred:", errors) -// } -// -// # Appending from a loop -// -// You sometimes need to append into an error from a loop. -// -// var err error -// for _, item := range items { -// err = multierr.Append(err, process(item)) -// } -// -// Cases like this may require knowledge of whether an individual instance -// failed. This usually requires introduction of a new variable. -// -// var err error -// for _, item := range items { -// if perr := process(item); perr != nil { -// log.Warn("skipping item", item) -// err = multierr.Append(err, perr) -// } -// } -// -// multierr includes AppendInto to simplify cases like this. -// -// var err error -// for _, item := range items { -// if multierr.AppendInto(&err, process(item)) { -// log.Warn("skipping item", item) -// } -// } -// -// This will append the error into the err variable, and return true if that -// individual error was non-nil. -// -// See [AppendInto] for more information. -// -// # Deferred Functions -// -// Go makes it possible to modify the return value of a function in a defer -// block if the function was using named returns. This makes it possible to -// record resource cleanup failures from deferred blocks. -// -// func sendRequest(req Request) (err error) { -// conn, err := openConnection() -// if err != nil { -// return err -// } -// defer func() { -// err = multierr.Append(err, conn.Close()) -// }() -// // ... -// } -// -// multierr provides the Invoker type and AppendInvoke function to make cases -// like the above simpler and obviate the need for a closure. The following is -// roughly equivalent to the example above. 
-// -// func sendRequest(req Request) (err error) { -// conn, err := openConnection() -// if err != nil { -// return err -// } -// defer multierr.AppendInvoke(&err, multierr.Close(conn)) -// // ... -// } -// -// See [AppendInvoke] and [Invoker] for more information. -// -// NOTE: If you're modifying an error from inside a defer, you MUST use a named -// return value for that function. -// -// # Advanced Usage -// -// Errors returned by Combine and Append MAY implement the following -// interface. -// -// type errorGroup interface { -// // Returns a slice containing the underlying list of errors. -// // -// // This slice MUST NOT be modified by the caller. -// Errors() []error -// } -// -// Note that if you need access to list of errors behind a multierr error, you -// should prefer using the Errors function. That said, if you need cheap -// read-only access to the underlying errors slice, you can attempt to cast -// the error to this interface. You MUST handle the failure case gracefully -// because errors returned by Combine and Append are not guaranteed to -// implement this interface. -// -// var errors []error -// group, ok := err.(errorGroup) -// if ok { -// errors = group.Errors() -// } else { -// errors = []error{err} -// } - -var ( - // Separator for single-line error messages. - _singlelineSeparator = []byte("; ") - - // Prefix for multi-line messages - _multilinePrefix = []byte("the following errors occurred:") - - // Prefix for the first and following lines of an item in a list of - // multi-line error messages. - // - // For example, if a single item is: - // - // foo - // bar - // - // It will become, - // - // - foo - // bar - _multilineSeparator = []byte("\n - ") - _multilineIndent = []byte(" ") -) - -// _bufferPool is a pool of bytes.Buffers. 
-var _bufferPool = sync.Pool{ - New: func() interface{} { - return &bytes.Buffer{} - }, -} - -type errorGroup interface { - Errors() []error -} - -// Errors returns a slice containing zero or more errors that the supplied -// error is composed of. If the error is nil, a nil slice is returned. -// -// err := multierr.Append(r.Close(), w.Close()) -// errors := multierr.Errors(err) -// -// If the error is not composed of other errors, the returned slice contains -// just the error that was passed in. -// -// Callers of this function are free to modify the returned slice. -func Errors(err error) []error { - return extractErrors(err) -} - -// multiError is an error that holds one or more errors. -// -// An instance of this is guaranteed to be non-empty and flattened. That is, -// none of the errors inside multiError are other multiErrors. -// -// multiError formats to a semi-colon delimited list of error messages with -// %v and with a more readable multi-line format with %+v. -type multiError struct { - copyNeeded atomic.Bool - errors []error - *stack -} - -// Errors returns the list of underlying errors. -// -// This slice MUST NOT be modified. -func (merr *multiError) Errors() []error { - if merr == nil { - return nil - } - return merr.errors -} - -func (merr *multiError) Error() string { - if merr == nil { - return "" - } - - buff := _bufferPool.Get().(*bytes.Buffer) - buff.Reset() - - merr.writeSingleline(buff) - - result := buff.String() - _bufferPool.Put(buff) - return result -} - -// Every compares every error in the given err against the given target error -// using [errors.Is], and returns true only if every comparison returned true. 
-func Every(err error, target error) bool { - for _, e := range extractErrors(err) { - if !errors.Is(e, target) { - return false - } - } - return true -} - -func (merr *multiError) Format(f fmt.State, c rune) { - if c == 'v' && f.Flag('+') { - merr.writeMultiline(f) - } else { - merr.writeSingleline(f) - } -} - -func (merr *multiError) writeSingleline(w io.Writer) { - first := true - for _, item := range merr.errors { - if first { - first = false - } else { - w.Write(_singlelineSeparator) - } - io.WriteString(w, item.Error()) - } -} - -func (merr *multiError) writeMultiline(w io.Writer) { - w.Write(_multilinePrefix) - for _, item := range merr.errors { - w.Write(_multilineSeparator) - writePrefixLine(w, _multilineIndent, fmt.Sprintf("%+v", item)) - } -} - -// Writes s to the writer with the given prefix added before each line after -// the first. -func writePrefixLine(w io.Writer, prefix []byte, s string) { - first := true - for len(s) > 0 { - if first { - first = false - } else { - w.Write(prefix) - } - - idx := strings.IndexByte(s, '\n') - if idx < 0 { - idx = len(s) - 1 - } - - io.WriteString(w, s[:idx+1]) - s = s[idx+1:] - } -} - -type inspectResult struct { - // Number of top-level non-nil errors - Count int - - // Total number of errors including multiErrors - Capacity int - - // Index of the first non-nil error in the list. Value is meaningless if - // Count is zero. - FirstErrorIdx int - - // Whether the list contains at least one multiError - ContainsMultiError bool -} - -// Inspects the given slice of errors so that we can efficiently allocate -// space for it. 
-func inspect(errors []error) (res inspectResult) { - first := true - for i, err := range errors { - if err == nil { - continue - } - - res.Count++ - if first { - first = false - res.FirstErrorIdx = i - } - - if merr, ok := err.(*multiError); ok { - res.Capacity += len(merr.errors) - res.ContainsMultiError = true - } else { - res.Capacity++ - } - } - return -} - -// fromSlice converts the given list of errors into a single error. -func fromSlice(errors []error) error { - // Don't pay to inspect small slices. - switch len(errors) { - case 0: - return nil - case 1: - return errors[0] - } - - res := inspect(errors) - switch res.Count { - case 0: - return nil - case 1: - // only one non-nil entry - return errors[res.FirstErrorIdx] - case len(errors): - if !res.ContainsMultiError { - // Error list is flat. Make a copy of it - // Otherwise "errors" escapes to the heap - // unconditionally for all other cases. - // This lets us optimize for the "no errors" case. - out := append(([]error)(nil), errors...) - return &multiError{errors: out, stack: callers(1)} - } - } - - nonNilErrs := make([]error, 0, res.Capacity) - for _, err := range errors[res.FirstErrorIdx:] { - if err == nil { - continue - } - - if nested, ok := err.(*multiError); ok { - nonNilErrs = append(nonNilErrs, nested.errors...) - } else { - nonNilErrs = append(nonNilErrs, err) - } - } - - return &multiError{errors: nonNilErrs, stack: callers(0)} -} - -// Combine combines the passed errors into a single error. -// -// If zero arguments were passed or if all items are nil, a nil error is -// returned. -// -// Combine(nil, nil) // == nil -// -// If only a single error was passed, it is returned as-is. -// -// Combine(err) // == err -// -// Combine skips over nil arguments so this function may be used to combine -// together errors from operations that fail independently of each other. 
-// -// multierr.Combine( -// reader.Close(), -// writer.Close(), -// pipe.Close(), -// ) -// -// If any of the passed errors is a multierr error, it will be flattened along -// with the other errors. -// -// multierr.Combine(multierr.Combine(err1, err2), err3) -// // is the same as -// multierr.Combine(err1, err2, err3) -// -// The returned error formats into a readable multi-line error message if -// formatted with %+v. -// -// fmt.Sprintf("%+v", multierr.Combine(err1, err2)) -func Combine(errors ...error) error { - return fromSlice(errors) -} - -// Append appends the given errors together. Either value may be nil. -// -// This function is a specialization of Combine for the common case where -// there are only two errors. -// -// err = multierr.Append(reader.Close(), writer.Close()) -// -// The following pattern may also be used to record failure of deferred -// operations without losing information about the original error. -// -// func doSomething(..) (err error) { -// f := acquireResource() -// defer func() { -// err = multierr.Append(err, f.Close()) -// }() -// -// Note that the variable MUST be a named return to append an error to it from -// the defer statement. -func Append(left error, right error) error { - switch { - case left == nil && right == nil: - return nil - case left == nil: - if _, ok := right.(*errz); !ok { - // It's not an errz, so we need to wrap it. - return &errz{stack: callers(0), error: right} - } - return right - case right == nil: - if _, ok := left.(*errz); !ok { - // It's not an errz, so we need to wrap it. - return &errz{stack: callers(0), error: left} - } - return left - } - - if _, ok := right.(*multiError); !ok { - if l, ok := left.(*multiError); ok && !l.copyNeeded.Swap(true) { - // Common case where the error on the left is constantly being - // appended to. - errs := append(l.errors, right) - return &multiError{errors: errs, stack: callers(0)} - } else if !ok { - // Both errors are single errors. 
- return &multiError{errors: []error{left, right}, stack: callers(0)} - } - } - - // Either right or both, left and right, are multiErrors. Rely on usual - // expensive logic. - errors := [2]error{left, right} - return fromSlice(errors[0:]) -} - -// Unwrap returns a list of errors wrapped by this multierr. -func (merr *multiError) Unwrap() []error { - return merr.Errors() -} - -type multipleErrors interface { - Unwrap() []error -} - -func extractErrors(err error) []error { - if err == nil { - return nil - } - - // check if the given err is an Unwrapable error that - // implements multipleErrors interface. - eg, ok := err.(multipleErrors) - if !ok { - return []error{err} - } - - return append(([]error)(nil), eg.Unwrap()...) -} From bb9153c6313ae700596f6014deb1e7562d6e6d31 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 16 Dec 2023 12:41:17 -0700 Subject: [PATCH 141/195] errz.Skip --- drivers/xlsx/ingest.go | 3 +- libsq/core/errz/errz.go | 21 ++++-- libsq/core/errz/errz_test.go | 74 +++++++++++++++++++-- libsq/core/errz/multierr.go | 1 - libsq/core/errz/multierr_post_go120.go | 23 ------- libsq/core/errz/stack.go | 18 +++++ libsq/core/errz/{errz_types.go => types.go} | 43 +++++------- libsq/core/errz/types_test.go | 19 ++++++ libsq/core/ioz/httpz/opts.go | 2 +- 9 files changed, 139 insertions(+), 65 deletions(-) delete mode 100644 libsq/core/errz/multierr.go delete mode 100644 libsq/core/errz/multierr_post_go120.go rename libsq/core/errz/{errz_types.go => types.go} (63%) create mode 100644 libsq/core/errz/types_test.go diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index e265b2984..0a196b85d 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -284,7 +284,8 @@ func buildSheetTables(ctx context.Context, srcIngestHeader *bool, sheets []*xShe sheetTbl, err := buildSheetTable(gCtx, srcIngestHeader, sheets[i]) if err != nil { - if errz.IsErrNoData(err) { + if errz.Has[*errz.NoDataError](err) { + //if errz.IsErrNoData(err) { // FIXME: 
remove after testing // If the sheet has no data, we log it and skip it. lg.FromContext(ctx).Warn("Excel sheet has no data", laSheet, sheets[i].name, diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go index a6f591652..e574ebf9e 100644 --- a/libsq/core/errz/errz.go +++ b/libsq/core/errz/errz.go @@ -15,19 +15,32 @@ import ( "strings" ) +// Opt is a functional option. Use with [Err] or [New]. +type Opt interface { + apply(*errz) +} + // Err annotates err with a stack trace at the point Err was called. // It is equivalent to Wrap(err, ""). If err is nil, Err returns nil. -func Err(err error) error { +func Err(err error, opts ...Opt) error { if err == nil { return nil } - return &errz{stack: callers(0), error: err} + ez := &errz{stack: callers(0), error: err} + for _, opt := range opts { + opt.apply(ez) + } + return ez } // New returns an error with the supplied message, recording the // stack trace at the point it was called. -func New(message string) error { - return &errz{stack: callers(0), msg: message} +func New(message string, opts ...Opt) error { + ez := &errz{stack: callers(0), msg: message} + for _, opt := range opts { + opt.apply(ez) + } + return ez } // Errorf works like [fmt.Errorf], but it also records the stack trace diff --git a/libsq/core/errz/errz_test.go b/libsq/core/errz/errz_test.go index 77727b57f..2c6057b6a 100644 --- a/libsq/core/errz/errz_test.go +++ b/libsq/core/errz/errz_test.go @@ -96,22 +96,22 @@ func TestIsErrNotExist(t *testing.T) { func TestIsErrNoData(t *testing.T) { var err error - require.False(t, errz.IsErrNoData(err)) - require.False(t, errz.IsErrNoData(errz.New("huzzah"))) + require.False(t, errz.Has[*errz.NoDataError](err)) + require.False(t, errz.Has[*errz.NoDataError](errz.New("huzzah"))) var nde1 *errz.NoDataError - require.True(t, errz.IsErrNoData(nde1)) + require.True(t, errz.Has[*errz.NoDataError](nde1)) var nde2 *errz.NoDataError require.True(t, errors.As(nde1, &nde2)) err = errz.NoData(errz.New("huzzah")) - 
require.True(t, errz.IsErrNoData(err)) + require.True(t, errz.Has[*errz.NoDataError](err)) err = fmt.Errorf("wrap me: %w", err) - require.True(t, errz.IsErrNoData(err)) + require.True(t, errz.Has[*errz.NoDataError](err)) err = errz.NoDataf("%s doesn't exist", "me") - require.True(t, errz.IsErrNoData(err)) + require.True(t, errz.Has[*errz.NoDataError](err)) require.Equal(t, "me doesn't exist", err.Error()) } @@ -144,7 +144,6 @@ func TestIs(t *testing.T) { require.Equal(t, "wrap: "+sql.ErrNoRows.Error(), err.Error()) require.True(t, errors.Is(err, sql.ErrNoRows)) } - func TestStackTrace(t *testing.T) { e1 := errz.New("inner") e2 := errz.Wrap(e1, "wrap") @@ -156,3 +155,64 @@ func TestStackTrace(t *testing.T) { require.NotNil(t, gotFinalStack) require.Equal(t, gotStacks[len(gotStacks)-1], gotFinalStack) } + +func TestOptSkip(t *testing.T) { + err := errz.Wrap(errz.New("inner"), "wrap1") + chain := errz.Chain(err) + require.Len(t, chain, 2) + //t.Logf("\n%+v", errz.LastStack(err).Frames) + + errSkip0 := errz.Err(err, errz.Skip(0)) + errSkip1 := errz.Err(err, errz.Skip(1)) + errSkip2 := errz.Err(err, errz.Skip(2)) + + require.NotNil(t, errSkip0) + require.NotNil(t, errSkip1) + require.NotNil(t, errSkip2) + //chain2 := errz.Chain(err) + //require.Len(t, chain2, 2) + stacks0 := errz.Stacks(errSkip0) + stacks1 := errz.Stacks(errSkip1) + _ = stacks1 + + t.Logf("========== stacks0 ==========") + for _, st := range stacks0 { + t.Logf("\n\n\n\n%+v", st.Frames) + } + t.Logf("========== stacks1 ==========") + for _, st := range stacks1 { + t.Logf("\n\n\n\n%+v", st.Frames) + } + require.Len(t, stacks1[0].Frames, 2) + + lastStack1 := errz.LastStack(errSkip1) + t.Logf("========== lastStack1 ==========") + t.Logf("\n\n\n\n%+v", lastStack1.Frames) + + //t.Logf("\n%+v", errz.LastStack(err).Frames) +} + +type FooErr struct { + msg string +} + +func (e *FooErr) Error() string { + return e.msg +} + +func NewFooError(msg string) error { + //return &FooErr{error: errz.New(msg, 
errz.Skip(1))} + return errz.Err(&FooErr{msg: msg}, errz.Skip(1)) + +} + +func TestCustomError(t *testing.T) { + err := NewFooError("bah") + t.Logf("err: %v", err) + //st := errz.LastStack(err) + //require.NotNil(t, st) + stacks := errz.Stacks(err) + require.Len(t, stacks, 1) + st := stacks[0] + t.Logf("\n\n\n\n%+v", st.Frames) +} diff --git a/libsq/core/errz/multierr.go b/libsq/core/errz/multierr.go deleted file mode 100644 index f1ed88ea9..000000000 --- a/libsq/core/errz/multierr.go +++ /dev/null @@ -1 +0,0 @@ -package errz diff --git a/libsq/core/errz/multierr_post_go120.go b/libsq/core/errz/multierr_post_go120.go deleted file mode 100644 index c2736c6b7..000000000 --- a/libsq/core/errz/multierr_post_go120.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) 2017-2023 Uber Technologies, Inc. -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -// THE SOFTWARE. 
- -//go:build go1.20 - -package errz diff --git a/libsq/core/errz/stack.go b/libsq/core/errz/stack.go index c525dd160..1c82695f3 100644 --- a/libsq/core/errz/stack.go +++ b/libsq/core/errz/stack.go @@ -11,6 +11,24 @@ import ( "strings" ) +var _ Opt = (*Skip)(nil) + +// Skip is an Opt that can be passed to Err or New that +// indicates how many frames to skip when recording the stack trace. +// This is useful when wrapping errors in helper functions. +// +// func handleErr(err error) error { +// slog.Default().Error("Oh noes", "err", err) +// return errz.Err(err, errz.Skip(1)) +// } +// +// Skipping too many frames will panic. +type Skip int + +func (s Skip) apply(e *errz) { + *(e.stack) = (*e.stack)[int(s):] +} + const unknown = "unknown" // Frame represents a program counter inside a stack frame. diff --git a/libsq/core/errz/errz_types.go b/libsq/core/errz/types.go similarity index 63% rename from libsq/core/errz/errz_types.go rename to libsq/core/errz/types.go index ac2463c5b..c98c74e1c 100644 --- a/libsq/core/errz/errz_types.go +++ b/libsq/core/errz/types.go @@ -1,8 +1,8 @@ package errz import ( - "context" "errors" + "fmt" ) // NotExistError indicates that a DB object, such @@ -41,44 +41,31 @@ func IsErrNotExist(err error) bool { // REVISIT: Consider moving NoDataError to libsq/driver? // REVISIT: Consider renaming NoDataError to EmptyDataError? type NoDataError struct { - error + errz } -// Unwrap satisfies the stdlib errors.Unwrap function. -func (e *NoDataError) Unwrap() error { return e.error } +//// Unwrap satisfies the stdlib errors.Unwrap function. +//func (e *NoDataError) Unwrap() error { return e.error } // NoData returns a NoDataError, or nil. func NoData(err error) error { if err == nil { return nil } - return &NoDataError{error: Err(err)} + return &NoDataError{errz{stack: callers(0), error: err}} } // NoDataf returns a NoDataError. 
func NoDataf(format string, args ...any) error { - return &NoDataError{error: Errorf(format, args...)} -} - -// IsErrNoData returns true if err is non-nil and -// err is or contains NoDataError. -func IsErrNoData(err error) bool { - if err == nil { - return false - } - var e *NoDataError - return errors.As(err, &e) + return &NoDataError{errz: errz{stack: callers(0), msg: fmt.Sprintf(format, args...)}} } -// IsErrContext returns true if err is context.Canceled or context.DeadlineExceeded. -func IsErrContext(err error) bool { - if err == nil { - return false - } - - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return true - } - - return false -} +//// IsErrNoData returns true if err is non-nil and +//// err is or contains NoDataError. +//func IsErrNoData(err error) bool { +// if err == nil { +// return false +// } +// var e *NoDataError +// return errors.As(err, &e) +//} diff --git a/libsq/core/errz/types_test.go b/libsq/core/errz/types_test.go new file mode 100644 index 000000000..7728cac26 --- /dev/null +++ b/libsq/core/errz/types_test.go @@ -0,0 +1,19 @@ +package errz_test + +import ( + "errors" + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/stretchr/testify/require" + "testing" +) + +func TestNoData(t *testing.T) { + booErr := errors.New("boo") + ndErr := errz.NoData(booErr) + st := errz.Stacks(ndErr) + require.NotNil(t, st) + + unwrap1 := errors.Unwrap(ndErr) + require.NotNil(t, unwrap1) + require.Equal(t, booErr, unwrap1) +} diff --git a/libsq/core/ioz/httpz/opts.go b/libsq/core/ioz/httpz/opts.go index 2cf7aafa4..93c23795f 100644 --- a/libsq/core/ioz/httpz/opts.go +++ b/libsq/core/ioz/httpz/opts.go @@ -161,7 +161,7 @@ func OptHeaderTimeout(timeout time.Duration) TripFunc { resp, err := errz.Return(next.RoundTrip(req.WithContext(ctx))) - if errz.IsErrContext(err) { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { if loz.Take(ctx.Done()) { // The lower-down RoundTripper 
probably returned ctx.Err(), // not context.Cause(), so we swap it around here. From 38ddcbab9532333817e1595427c7c26d0d6730c4 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 16 Dec 2023 18:56:44 -0700 Subject: [PATCH 142/195] errz functionality should be done --- cli/diff/data_naive.go | 4 +- cli/diff/table.go | 3 +- cli/output/tablew/errorwriter.go | 11 +---- drivers/mysql/errors.go | 3 +- drivers/mysql/metadata.go | 2 +- drivers/postgres/errors.go | 3 +- drivers/sqlite3/errors.go | 3 +- drivers/sqlserver/errors.go | 3 +- drivers/xlsx/ingest.go | 8 ++-- drivers/xlsx/xlsx_test.go | 2 +- libsq/core/errz/errz.go | 49 +++++++++++++++---- libsq/core/errz/errz_test.go | 59 ++++++++--------------- libsq/core/errz/multi.go | 15 ++++-- libsq/core/errz/types.go | 70 ---------------------------- libsq/core/errz/types_test.go | 19 -------- libsq/core/lg/devlog/tint/handler.go | 16 ++----- libsq/driver/driver.go | 31 ++++++++++++ libsq/driver/driver_test.go | 40 +++++++++++++++- 18 files changed, 163 insertions(+), 178 deletions(-) delete mode 100644 libsq/core/errz/types_test.go diff --git a/cli/diff/data_naive.go b/cli/diff/data_naive.go index 673b60ec3..3c2f8ef56 100644 --- a/cli/diff/data_naive.go +++ b/cli/diff/data_naive.go @@ -52,7 +52,7 @@ func buildTableDataDiff(ctx context.Context, ru *run.Run, cfg *Config, g, gCtx := errgroup.WithContext(ctx) g.Go(func() error { if err := libsq.ExecuteSLQ(gCtx, qc, query1, recw1); err != nil { - if errz.IsErrNotExist(err) { + if errz.Has[*driver.NotExistError](err) { // It's totally ok if a table is not found. 
log.Debug("Diff: table not found", lga.Src, td1.src, lga.Table, td1.tblName) return nil @@ -64,7 +64,7 @@ func buildTableDataDiff(ctx context.Context, ru *run.Run, cfg *Config, }) g.Go(func() error { if err := libsq.ExecuteSLQ(gCtx, qc, query2, recw2); err != nil { - if errz.IsErrNotExist(err) { + if errz.Has[*driver.NotExistError](err) { log.Debug("Diff: table not found", lga.Src, td2.src, lga.Table, td2.tblName) return nil } diff --git a/cli/diff/table.go b/cli/diff/table.go index 320a97983..2dbf71b8e 100644 --- a/cli/diff/table.go +++ b/cli/diff/table.go @@ -3,6 +3,7 @@ package diff import ( "context" "fmt" + "github.com/neilotoole/sq/libsq/driver" "golang.org/x/sync/errgroup" @@ -120,7 +121,7 @@ func fetchTableMeta(ctx context.Context, ru *run.Run, src *source.Source, table } md, err := grip.TableMetadata(ctx, table) if err != nil { - if errz.IsErrNotExist(err) { + if errz.Has[*driver.NotExistError](err) { return nil, nil //nolint:nilnil } return nil, err diff --git a/cli/output/tablew/errorwriter.go b/cli/output/tablew/errorwriter.go index d1568c63a..9df2d582a 100644 --- a/cli/output/tablew/errorwriter.go +++ b/cli/output/tablew/errorwriter.go @@ -3,7 +3,6 @@ package tablew import ( "bytes" "fmt" - "github.com/neilotoole/sq/libsq/core/stringz" "io" "strings" @@ -53,15 +52,7 @@ func (w *errorWriter) Error(systemErr error, humanErr error) { } if stack.Error != nil { - errTypes := stringz.TypeNames(errz.Chain(stack.Error)...) 
- for i, typ := range errTypes { - w.pr.StackErrorType.Fprint(buf, typ) - if i < len(errTypes)-1 { - w.pr.Faint.Fprint(buf, ":") - buf.WriteByte(' ') - } - } - buf.WriteByte('\n') + w.pr.StackErrorType.Fprintln(buf, errz.SprintTreeTypes(stack.Error)) w.pr.StackError.Fprintln(buf, stack.Error.Error()) } diff --git a/drivers/mysql/errors.go b/drivers/mysql/errors.go index e46adf78d..fdf629969 100644 --- a/drivers/mysql/errors.go +++ b/drivers/mysql/errors.go @@ -2,6 +2,7 @@ package mysql import ( "errors" + "github.com/neilotoole/sq/libsq/driver" "github.com/go-sql-driver/mysql" @@ -17,7 +18,7 @@ func errw(err error) error { case err == nil: return nil case hasErrCode(err, errNumTableNotExist): - return errz.NotExist(err) + return driver.NewNotExistError(err) default: return errz.Err(err) } diff --git a/drivers/mysql/metadata.go b/drivers/mysql/metadata.go index fcbc3111a..6379765a8 100644 --- a/drivers/mysql/metadata.go +++ b/drivers/mysql/metadata.go @@ -591,7 +591,7 @@ func getTableRowCounts(ctx context.Context, db sqlz.DB, tblNames []string) (map[ } err = errw(err) - if errz.IsErrNotExist(err) { + if errz.Has[*driver.NotExistError](err) { // Sometimes a table can get deleted during the operation. If so, // we just remove that table from the list, and try again. 
// We could also do this entire thing in a transaction, but where's diff --git a/drivers/postgres/errors.go b/drivers/postgres/errors.go index 3871aab0f..ce278ed18 100644 --- a/drivers/postgres/errors.go +++ b/drivers/postgres/errors.go @@ -2,6 +2,7 @@ package postgres import ( "errors" + "github.com/neilotoole/sq/libsq/driver" "github.com/jackc/pgx/v5/pgconn" @@ -17,7 +18,7 @@ func errw(err error) error { case err == nil: return nil case hasErrCode(err, errCodeRelationNotExist): - return errz.NotExist(err) + return driver.NewNotExistError(err) default: return errz.Err(err) } diff --git a/drivers/sqlite3/errors.go b/drivers/sqlite3/errors.go index 06d83e211..60def3f6f 100644 --- a/drivers/sqlite3/errors.go +++ b/drivers/sqlite3/errors.go @@ -1,6 +1,7 @@ package sqlite3 import ( + "github.com/neilotoole/sq/libsq/driver" "strings" "github.com/neilotoole/sq/libsq/core/errz" @@ -17,7 +18,7 @@ func errw(err error) error { case strings.HasPrefix(err.Error(), "no such table:"): // The sqlite driver always returns sqlite3.ErrError(1), so // we need to search by string. Needs further investigation. 
- return errz.NotExist(err) + return driver.NewNotExistError(err) default: return errz.Err(err) } diff --git a/drivers/sqlserver/errors.go b/drivers/sqlserver/errors.go index 68855146c..673ae40dc 100644 --- a/drivers/sqlserver/errors.go +++ b/drivers/sqlserver/errors.go @@ -2,6 +2,7 @@ package sqlserver import ( "errors" + "github.com/neilotoole/sq/libsq/driver" mssql "github.com/microsoft/go-mssqldb" @@ -42,7 +43,7 @@ func errw(err error) error { case err == nil: return nil case hasErrCode(err, errCodeBadObject): - return errz.NotExist(err) + return driver.NewNotExistError(err) default: return errz.Err(err) } diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index 0a196b85d..717c0dc38 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -284,7 +284,7 @@ func buildSheetTables(ctx context.Context, srcIngestHeader *bool, sheets []*xShe sheetTbl, err := buildSheetTable(gCtx, srcIngestHeader, sheets[i]) if err != nil { - if errz.Has[*errz.NoDataError](err) { + if errz.Has[*driver.EmptyDataError](err) { //if errz.IsErrNoData(err) { // FIXME: remove after testing // If the sheet has no data, we log it and skip it. lg.FromContext(ctx).Warn("Excel sheet has no data", @@ -324,7 +324,7 @@ func getSrcIngestHeader(o options.Options) *bool { // a model of the table, or an error. If the sheet is empty, (nil,nil) // is returned. If srcIngestHeader is nil, the function attempts // to detect if the sheet has a header row. -// If the sheet has no data, errz.NoDataError is returned. +// If the sheet has no data, errz.EmptyDataError is returned. 
func buildSheetTable(ctx context.Context, srcIngestHeader *bool, sheet *xSheet) (*sheetTable, error) { log := lg.FromContext(ctx) @@ -334,11 +334,11 @@ func buildSheetTable(ctx context.Context, srcIngestHeader *bool, sheet *xSheet) } if len(sheet.sampleRows) == 0 { - return nil, errz.NoDataf("excel: sheet {%s} has no row data", sheet.name) + return nil, driver.NewEmptyDataError("excel: sheet {%s} has no row data", sheet.name) } if sheet.sampleRowsMaxWidth == 0 { - return nil, errz.NoDataf("excel: sheet {%s} has no column data", sheet.name) + return nil, driver.NewEmptyDataError("excel: sheet {%s} has no column data", sheet.name) } var hasHeader bool diff --git a/drivers/xlsx/xlsx_test.go b/drivers/xlsx/xlsx_test.go index 50ef6b069..ceb2e9ab1 100644 --- a/drivers/xlsx/xlsx_test.go +++ b/drivers/xlsx/xlsx_test.go @@ -294,7 +294,7 @@ func TestHandleSomeSheetsEmpty(t *testing.T) { for _, tblName := range []string{"Sheet2Empty, Sheet3Empty"} { _, err = th.TableMetadata(src, tblName) require.Error(t, err) - require.True(t, errz.IsErrNotExist(err)) + require.True(t, errz.Has[*driver.NotExistError](err)) } } diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go index e574ebf9e..037f3671e 100644 --- a/libsq/core/errz/errz.go +++ b/libsq/core/errz/errz.go @@ -225,7 +225,7 @@ func Return[T any](t T, err error) (T, error) { // require.True(t, ok) // require.Equal(t, "non-existing", pathErr.Path) // -// If err is nil, As returns false. +// If err is nil, As returns false. See also: [errz.Has]. func As[E error](err error) (E, bool) { var target E if err == nil { @@ -238,18 +238,17 @@ func As[E error](err error) (E, bool) { return target, false } -// Has returns true if err, or an error in its error tree, if of type E. +// Has returns true if err, or an error in its error tree, matches error type E. 
+// An error is considered a match by the rules of [errors.As] // -// _, err := os.Open("non-existing") -// isPathErr := errz.Has[*fs.PathError](err) +// f, err := os.Open("non-existing") +// if errz.Has[*fs.PathError](err) { +// // Do something +// } // -// If err is nil, Has returns false. +// If err is nil, Has returns false. See also: [errz.As]. func Has[E error](err error) bool { - if err == nil { - return false - } - var target E - return errors.As(err, &target) + return errors.As(err, new(E)) } // Chain returns a slice of all the errors in err's tree. @@ -266,3 +265,33 @@ func Chain(err error) []error { return errs } + +// SprintTreeTypes returns a string representation of err's type tree. +// A multi-error is represented as a slice of its children. +func SprintTreeTypes(err error) string { + if err == nil { + return "" + } + errChain := Chain(err) + var sb strings.Builder + for i, e := range errChain { + sb.WriteString(fmt.Sprintf("%T", e)) + if me, ok := e.(multipleErrorer); ok { + children := me.Unwrap() + childText := make([]string, len(children)) + for j := range children { + childText[j] = SprintTreeTypes(children[j]) + } + joined := strings.Join(childText, ", ") + sb.WriteRune('[') + sb.WriteString(joined) + sb.WriteRune(']') + } + + if i < len(errChain)-1 { + sb.WriteString(": ") + } + } + + return sb.String() +} diff --git a/libsq/core/errz/errz_test.go b/libsq/core/errz/errz_test.go index 2c6057b6a..058814dc1 100644 --- a/libsq/core/errz/errz_test.go +++ b/libsq/core/errz/errz_test.go @@ -1,9 +1,9 @@ package errz_test import ( + "context" "database/sql" "errors" - "fmt" "io/fs" "net/url" "os" @@ -77,44 +77,6 @@ func TestLogValue(t *testing.T) { log.Debug("via errz.Wrap", "err", wrapErr) } -func TestIsErrNotExist(t *testing.T) { - var err error - require.False(t, errz.IsErrNotExist(err)) - require.False(t, errz.IsErrNotExist(errz.New("huzzah"))) - - var nee1 *errz.NotExistError - require.True(t, errz.IsErrNotExist(nee1)) - - var nee2 
*errz.NotExistError - require.True(t, errors.As(nee1, &nee2)) - - err = errz.NotExist(errz.New("huzzah")) - require.True(t, errz.IsErrNotExist(err)) - err = fmt.Errorf("wrap me: %w", err) - require.True(t, errz.IsErrNotExist(err)) -} - -func TestIsErrNoData(t *testing.T) { - var err error - require.False(t, errz.Has[*errz.NoDataError](err)) - require.False(t, errz.Has[*errz.NoDataError](errz.New("huzzah"))) - - var nde1 *errz.NoDataError - require.True(t, errz.Has[*errz.NoDataError](nde1)) - - var nde2 *errz.NoDataError - require.True(t, errors.As(nde1, &nde2)) - - err = errz.NoData(errz.New("huzzah")) - require.True(t, errz.Has[*errz.NoDataError](err)) - err = fmt.Errorf("wrap me: %w", err) - require.True(t, errz.Has[*errz.NoDataError](err)) - - err = errz.NoDataf("%s doesn't exist", "me") - require.True(t, errz.Has[*errz.NoDataError](err)) - require.Equal(t, "me doesn't exist", err.Error()) -} - func TestHas(t *testing.T) { _, err := os.Open(stringz.Uniq32() + "-non-existing") require.Error(t, err) @@ -125,6 +87,10 @@ func TestHas(t *testing.T) { got = errz.Has[*url.Error](err) require.False(t, got) + + got = errz.Has[*url.Error](nil) + require.False(t, got) + } func TestAs(t *testing.T) { @@ -216,3 +182,18 @@ func TestCustomError(t *testing.T) { st := stacks[0] t.Logf("\n\n\n\n%+v", st.Frames) } + +//nolint:lll +func TestSprintTreeTypes(t *testing.T) { + err := errz.Wrap(errz.Wrap(errz.New("inner"), "wrap1"), "") + require.Equal(t, "wrap1: inner", err.Error()) + + got := errz.SprintTreeTypes(err) + require.Equal(t, "*errz.errz: *errz.errz: *errz.errz", got) + + me := errz.Combine(context.DeadlineExceeded, err, sql.ErrNoRows) + err = errz.Wrap(me, "wrap3") + got = errz.SprintTreeTypes(err) + + require.Equal(t, "*errz.errz: *errz.multiError[context.deadlineExceededError, *errz.errz: *errz.errz: *errz.errz, *errors.errorString]", got) +} diff --git a/libsq/core/errz/multi.go b/libsq/core/errz/multi.go index 9f32822d2..e5753876a 100644 --- a/libsq/core/errz/multi.go 
+++ b/libsq/core/errz/multi.go @@ -493,7 +493,7 @@ func (merr *multiError) Unwrap() []error { return merr.Errors() } -type multipleErrors interface { +type multipleErrorer interface { Unwrap() []error } @@ -503,11 +503,20 @@ func extractErrors(err error) []error { } // check if the given err is an Unwrapable error that - // implements multipleErrors interface. - eg, ok := err.(multipleErrors) + // implements multipleErrorer interface. + eg, ok := err.(multipleErrorer) if !ok { return []error{err} } return append(([]error)(nil), eg.Unwrap()...) } + +func IsMulti(err error) bool { + if err == nil { + return false + } + + _, ok := err.(*multiError) + return ok +} diff --git a/libsq/core/errz/types.go b/libsq/core/errz/types.go index c98c74e1c..f1ed88ea9 100644 --- a/libsq/core/errz/types.go +++ b/libsq/core/errz/types.go @@ -1,71 +1 @@ package errz - -import ( - "errors" - "fmt" -) - -// NotExistError indicates that a DB object, such -// as a table, does not exist. -// -// REVISIT: Consider moving NotExistError to libsq/driver? -type NotExistError struct { - error -} - -// Unwrap satisfies the stdlib errors.Unwrap function. -func (e *NotExistError) Unwrap() error { return e.error } - -// NotExist returns a NotExistError, or nil. -func NotExist(err error) error { - if err == nil { - return nil - } - return &NotExistError{error: Err(err)} -} - -// IsErrNotExist returns true if err is non-nil and -// err is or contains NotExistError. -func IsErrNotExist(err error) bool { - if err == nil { - return false - } - var e *NotExistError - return errors.As(err, &e) -} - -// NoDataError indicates that there's no data, e.g. an empty document. -// This is subtly different to NotExistError, which would indicate that -// the document doesn't exist. -// -// REVISIT: Consider moving NoDataError to libsq/driver? -// REVISIT: Consider renaming NoDataError to EmptyDataError? -type NoDataError struct { - errz -} - -//// Unwrap satisfies the stdlib errors.Unwrap function. 
-//func (e *NoDataError) Unwrap() error { return e.error } - -// NoData returns a NoDataError, or nil. -func NoData(err error) error { - if err == nil { - return nil - } - return &NoDataError{errz{stack: callers(0), error: err}} -} - -// NoDataf returns a NoDataError. -func NoDataf(format string, args ...any) error { - return &NoDataError{errz: errz{stack: callers(0), msg: fmt.Sprintf(format, args...)}} -} - -//// IsErrNoData returns true if err is non-nil and -//// err is or contains NoDataError. -//func IsErrNoData(err error) bool { -// if err == nil { -// return false -// } -// var e *NoDataError -// return errors.As(err, &e) -//} diff --git a/libsq/core/errz/types_test.go b/libsq/core/errz/types_test.go deleted file mode 100644 index 7728cac26..000000000 --- a/libsq/core/errz/types_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package errz_test - -import ( - "errors" - "github.com/neilotoole/sq/libsq/core/errz" - "github.com/stretchr/testify/require" - "testing" -) - -func TestNoData(t *testing.T) { - booErr := errors.New("boo") - ndErr := errz.NoData(booErr) - st := errz.Stacks(ndErr) - require.NotNil(t, st) - - unwrap1 := errors.Unwrap(ndErr) - require.NotNil(t, unwrap1) - require.Equal(t, booErr, unwrap1) -} diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go index d38d0fab2..8c5d2bedc 100644 --- a/libsq/core/lg/devlog/tint/handler.go +++ b/libsq/core/lg/devlog/tint/handler.go @@ -56,7 +56,6 @@ import ( "context" "encoding" "fmt" - "github.com/neilotoole/sq/libsq/core/stringz" "io" "log/slog" "path/filepath" @@ -322,18 +321,9 @@ func (h *handler) handleStackAttrs(buf *buffer, attrs []slog.Attr) { } if stack.Error != nil { - errTypes := stringz.TypeNames(errz.Chain(stack.Error)...) 
- for j, typ := range errTypes { - buf.WriteStringIf(!h.noColor, ansiStackErrType) - buf.WriteString(typ) - buf.WriteStringIf(!h.noColor, ansiReset) - if j < len(errTypes)-1 { - buf.WriteStringIf(!h.noColor, ansiFaint) - buf.WriteByte(':') - buf.WriteStringIf(!h.noColor, ansiResetFaint) - buf.WriteByte(' ') - } - } + buf.WriteStringIf(!h.noColor, ansiStackErrType) + buf.WriteString(errz.SprintTreeTypes(stack.Error)) + buf.WriteStringIf(!h.noColor, ansiReset) buf.WriteByte('\n') buf.WriteStringIf(!h.noColor, ansiStackErr) diff --git a/libsq/driver/driver.go b/libsq/driver/driver.go index 0bc1f5a56..fc6951ba5 100644 --- a/libsq/driver/driver.go +++ b/libsq/driver/driver.go @@ -3,6 +3,7 @@ package driver import ( "context" "database/sql" + "fmt" "github.com/neilotoole/sq/libsq/ast/render" "github.com/neilotoole/sq/libsq/core/errz" @@ -231,3 +232,33 @@ func OpeningPing(ctx context.Context, src *source.Source, db *sql.DB) error { return nil } + +// EmptyDataError indicates that there's no data, e.g. an empty document. +// This is subtly different to NotExistError, which would indicate that +// the document doesn't exist. +type EmptyDataError string + +// Error satisfies the stdlib error interface. +func (e EmptyDataError) Error() string { return string(e) } + +// NewEmptyDataError returns a EmptyDataError. +func NewEmptyDataError(format string, args ...any) error { + return errz.Err(EmptyDataError(fmt.Sprintf(format, args...)), errz.Skip(1)) +} + +// NotExistError indicates that a DB object, such +// as a table, does not exist. +type NotExistError struct { + error +} + +// Unwrap satisfies the stdlib errors.Unwrap function. +func (e *NotExistError) Unwrap() error { return e.error } + +// NewNotExistError returns a NotExistError, or nil. 
+func NewNotExistError(err error) error { + if err == nil { + return nil + } + return errz.Err(&NotExistError{error: err}) +} diff --git a/libsq/driver/driver_test.go b/libsq/driver/driver_test.go index 26a25c17e..b610ffddc 100644 --- a/libsq/driver/driver_test.go +++ b/libsq/driver/driver_test.go @@ -2,6 +2,7 @@ package driver_test import ( "context" + "errors" "fmt" "testing" @@ -729,7 +730,7 @@ func TestSQLDriver_ErrWrap_IsErrNotExist(t *testing.T) { th, _, _, _, _ := testh.NewWith(t, h) _, err := th.QuerySLQ(h+".does_not_exist", nil) require.Error(t, err) - require.True(t, errz.IsErrNotExist(err)) + require.True(t, errz.Has[*driver.NotExistError](err)) }) } } @@ -754,3 +755,40 @@ func TestMungeColNames(t *testing.T) { }) } } + +func TestEmptyDataError(t *testing.T) { + var err error + require.False(t, errz.Has[driver.EmptyDataError](err)) + require.False(t, errz.Has[driver.EmptyDataError](errz.New("huzzah"))) + + var ede1 driver.EmptyDataError + require.True(t, errz.Has[driver.EmptyDataError](ede1)) + + var ede2 driver.EmptyDataError + require.True(t, errors.As(ede1, &ede2)) + + err = driver.NewEmptyDataError("huzzah") + require.True(t, errz.Has[driver.EmptyDataError](err)) + err = fmt.Errorf("wrap me: %w", err) + require.True(t, errz.Has[driver.EmptyDataError](err)) + + err = driver.NewEmptyDataError("%s doesn't exist", "me") + require.True(t, errz.Has[driver.EmptyDataError](err)) + require.Equal(t, "me doesn't exist", err.Error()) +} + +func TestNotExistError(t *testing.T) { + var err error + require.False(t, errz.Has[*driver.NotExistError](err)) + + var nee1 *driver.NotExistError + require.True(t, errz.Has[*driver.NotExistError](nee1)) + + var nee2 *driver.NotExistError + require.True(t, errors.As(nee1, &nee2)) + + err = driver.NewNotExistError(errz.New("huzzah")) + require.True(t, errz.Has[*driver.NotExistError](err)) + err = fmt.Errorf("wrap me: %w", err) + require.True(t, errz.Has[*driver.NotExistError](err)) +} From 
585a157d1a63ab199afd2dd3be64ace260bd921d Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 17 Dec 2023 06:45:26 -0700 Subject: [PATCH 143/195] start errz wrapup --- cli/cmd_x.go | 6 ++++-- cli/logging.go | 5 ----- cli/output/jsonw/errorwriter.go | 11 ++++++----- libsq/core/errz/multi.go | 16 ++++++++++++++++ libsq/core/errz/stack.go | 29 ----------------------------- 5 files changed, 26 insertions(+), 41 deletions(-) diff --git a/cli/cmd_x.go b/cli/cmd_x.go index eccdb8f29..265bffe7f 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -193,8 +193,10 @@ func execXDownloadCmd(cmd *cobra.Command, args []string) error { err1 := h.Errors[0] err2 := errz.New("another err") - return errz.Combine(err1, err2) - + err3 := errz.Combine(err1, err2) + //lg.FromContext(ctx).Error("OH NO", lga.Err, err3) + return err3 + //return nil case len(h.WriteErrors) > 0: return h.WriteErrors[0] case len(h.CachedFiles) > 0: diff --git a/cli/logging.go b/cli/logging.go index ed8f11e56..c59514c72 100644 --- a/cli/logging.go +++ b/cli/logging.go @@ -146,11 +146,6 @@ func slogReplaceSource(_ []string, a slog.Attr) slog.Attr { val += ":" + strconv.Itoa(source.Line) a.Value = slog.StringValue(val) } - // source.File = filepath.Base(source.File) - - // src, ok := a.Value. 
- // fp := a.Value.String() - // a.Value = slog.StringValue(filepath.Join(filepath.Base(filepath.Dir(fp)), filepath.Base(fp))) } return a } diff --git a/cli/output/jsonw/errorwriter.go b/cli/output/jsonw/errorwriter.go index 6fb7a3386..36ec3e03c 100644 --- a/cli/output/jsonw/errorwriter.go +++ b/cli/output/jsonw/errorwriter.go @@ -2,7 +2,6 @@ package jsonw import ( "fmt" - "github.com/neilotoole/sq/libsq/core/stringz" "io" "log/slog" "strings" @@ -26,12 +25,13 @@ func NewErrorWriter(log *slog.Logger, out io.Writer, pr *output.Printing) output type errorDetail struct { Error string `json:"error,"` BaseError string `json:"base_error,omitempty"` + Tree string `json:"tree,omitempty"` Stack []*stack `json:"stack,omitempty"` } type stackError struct { - Message string `json:"msg"` - Tree []string `json:"tree,omitempty"` + Message string `json:"msg"` + Tree string `json:"tree,omitempty"` } type stack struct { @@ -53,6 +53,7 @@ func (w *errorWriter) Error(systemErr error, humanErr error) { ed := errorDetail{ Error: humanErr.Error(), BaseError: systemErr.Error(), + Tree: errz.SprintTreeTypes(systemErr), } stacks := errz.Stacks(systemErr) @@ -63,10 +64,10 @@ func (w *errorWriter) Error(systemErr error, humanErr error) { } st := &stack{ - Trace: strings.ReplaceAll(fmt.Sprintf("%+v", sysStack), "\n\t", "\n "), + Trace: strings.ReplaceAll(fmt.Sprintf("%+v", sysStack.Frames), "\n\t", "\n "), Error: &stackError{ Message: sysStack.Error.Error(), - Tree: stringz.TypeNames(errz.Chain(sysStack.Error)...), + Tree: errz.SprintTreeTypes(sysStack.Error), }} ed.Stack = append(ed.Stack, st) diff --git a/libsq/core/errz/multi.go b/libsq/core/errz/multi.go index e5753876a..09ab48ee7 100644 --- a/libsq/core/errz/multi.go +++ b/libsq/core/errz/multi.go @@ -6,7 +6,9 @@ import ( "bytes" "errors" "fmt" + "github.com/samber/lo" "io" + "log/slog" "strings" "sync" "sync/atomic" @@ -266,6 +268,20 @@ func Every(err error, target error) bool { return true } +func (merr *multiError) LogValue() 
slog.Value { + if merr == nil { + return slog.Value{} + } + + attrs := make([]slog.Attr, 4) + attrs[0] = slog.String("msg", merr.Error()) + attrs[1] = slog.String("type", fmt.Sprintf("%T", merr)) + attrs[2] = slog.String("tree", SprintTreeTypes(merr)) + errs := lo.Map(merr.Errors(), func(err error, i int) string { return err.Error() }) + attrs[3] = slog.Any("errors", errs) + return slog.GroupValue(attrs...) +} + func (merr *multiError) Format(f fmt.State, c rune) { if c == 'v' && f.Flag('+') { merr.writeMultiline(f) diff --git a/libsq/core/errz/stack.go b/libsq/core/errz/stack.go index 1c82695f3..f13410982 100644 --- a/libsq/core/errz/stack.go +++ b/libsq/core/errz/stack.go @@ -294,32 +294,3 @@ func LastStack(err error) *StackTrace { return nil } - -// -//func LastStackLegacy(err error) *StackTrace { // FIXME: delete this -// if err == nil { -// return nil -// } -// -// var ez *errz -// var ok bool -// var tracer stackTracer -// for err != nil { -// tracer, ok = err.(stackTracer) -// if !ok || ez == nil { -// return nil -// } -// -// if ez.error == nil { -// return ez.stackTrace() -// } -// -// if _, ok = ez.error.(*errz); !ok { -// return ez.stackTrace() -// } -// -// err = ez.error -// } -// -// return nil -//} From 80cfb9cb90a285accd0976e008dce11ca6164892 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 17 Dec 2023 06:57:33 -0700 Subject: [PATCH 144/195] errz is linted --- libsq/core/errz/errz.go | 10 +-- libsq/core/errz/errz_test.go | 38 +++-------- libsq/core/errz/internal_test.go | 4 +- libsq/core/errz/multi.go | 105 ++++++++++++++----------------- libsq/core/errz/multi_test.go | 10 ++- libsq/core/errz/stack.go | 11 ++-- 6 files changed, 73 insertions(+), 105 deletions(-) diff --git a/libsq/core/errz/errz.go b/libsq/core/errz/errz.go index 037f3671e..819178c3a 100644 --- a/libsq/core/errz/errz.go +++ b/libsq/core/errz/errz.go @@ -121,7 +121,7 @@ func (e *errz) alienCause() error { inner := e.error for inner != nil { // Note: don't use errors.As here; we 
want the direct type assertion. - if v, ok := inner.(*errz); ok { + if v, ok := inner.(*errz); ok { //nolint:errorlint inner = v.error continue } @@ -142,10 +142,10 @@ func (e *errz) Format(s fmt.State, verb rune) { _, _ = io.WriteString(s, e.msg) e.stack.Format(s, verb) return - } else { - _, _ = fmt.Fprintf(s, "%+v", e.error) - e.stack.Format(s, verb) } + _, _ = fmt.Fprintf(s, "%+v", e.error) + e.stack.Format(s, verb) + return } fallthrough @@ -276,7 +276,7 @@ func SprintTreeTypes(err error) string { var sb strings.Builder for i, e := range errChain { sb.WriteString(fmt.Sprintf("%T", e)) - if me, ok := e.(multipleErrorer); ok { + if me, ok := e.(multipleErrorer); ok { //nolint:errorlint children := me.Unwrap() childText := make([]string, len(children)) for j := range children { diff --git a/libsq/core/errz/errz_test.go b/libsq/core/errz/errz_test.go index 058814dc1..f32f7ec84 100644 --- a/libsq/core/errz/errz_test.go +++ b/libsq/core/errz/errz_test.go @@ -90,7 +90,6 @@ func TestHas(t *testing.T) { got = errz.Has[*url.Error](nil) require.False(t, got) - } func TestAs(t *testing.T) { @@ -110,6 +109,7 @@ func TestIs(t *testing.T) { require.Equal(t, "wrap: "+sql.ErrNoRows.Error(), err.Error()) require.True(t, errors.Is(err, sql.ErrNoRows)) } + func TestStackTrace(t *testing.T) { e1 := errz.New("inner") e2 := errz.Wrap(e1, "wrap") @@ -126,7 +126,6 @@ func TestOptSkip(t *testing.T) { err := errz.Wrap(errz.New("inner"), "wrap1") chain := errz.Chain(err) require.Len(t, chain, 2) - //t.Logf("\n%+v", errz.LastStack(err).Frames) errSkip0 := errz.Err(err, errz.Skip(0)) errSkip1 := errz.Err(err, errz.Skip(1)) @@ -135,52 +134,29 @@ func TestOptSkip(t *testing.T) { require.NotNil(t, errSkip0) require.NotNil(t, errSkip1) require.NotNil(t, errSkip2) - //chain2 := errz.Chain(err) - //require.Len(t, chain2, 2) - stacks0 := errz.Stacks(errSkip0) stacks1 := errz.Stacks(errSkip1) - _ = stacks1 - - t.Logf("========== stacks0 ==========") - for _, st := range stacks0 { - 
t.Logf("\n\n\n\n%+v", st.Frames) - } - t.Logf("========== stacks1 ==========") - for _, st := range stacks1 { - t.Logf("\n\n\n\n%+v", st.Frames) - } require.Len(t, stacks1[0].Frames, 2) - - lastStack1 := errz.LastStack(errSkip1) - t.Logf("========== lastStack1 ==========") - t.Logf("\n\n\n\n%+v", lastStack1.Frames) - - //t.Logf("\n%+v", errz.LastStack(err).Frames) } -type FooErr struct { +type FooError struct { msg string } -func (e *FooErr) Error() string { +func (e *FooError) Error() string { return e.msg } func NewFooError(msg string) error { - //return &FooErr{error: errz.New(msg, errz.Skip(1))} - return errz.Err(&FooErr{msg: msg}, errz.Skip(1)) - + return errz.Err(&FooError{msg: msg}, errz.Skip(1)) } -func TestCustomError(t *testing.T) { +func TestFooError(t *testing.T) { err := NewFooError("bah") t.Logf("err: %v", err) - //st := errz.LastStack(err) - //require.NotNil(t, st) stacks := errz.Stacks(err) require.Len(t, stacks, 1) st := stacks[0] - t.Logf("\n\n\n\n%+v", st.Frames) + t.Logf("\n%+v", st.Frames) } //nolint:lll @@ -195,5 +171,5 @@ func TestSprintTreeTypes(t *testing.T) { err = errz.Wrap(me, "wrap3") got = errz.SprintTreeTypes(err) - require.Equal(t, "*errz.errz: *errz.multiError[context.deadlineExceededError, *errz.errz: *errz.errz: *errz.errz, *errors.errorString]", got) + require.Equal(t, "*errz.errz: *errz.multiErr[context.deadlineExceededError, *errz.errz: *errz.errz: *errz.errz, *errors.errorString]", got) } diff --git a/libsq/core/errz/internal_test.go b/libsq/core/errz/internal_test.go index f7ba89f83..27dadd73c 100644 --- a/libsq/core/errz/internal_test.go +++ b/libsq/core/errz/internal_test.go @@ -2,10 +2,12 @@ package errz import ( "context" - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) +//nolint:errorlint func TestAlienCause(t *testing.T) { err := New("boo") diff --git a/libsq/core/errz/multi.go b/libsq/core/errz/multi.go index 09ab48ee7..b7a45819e 100644 --- a/libsq/core/errz/multi.go 
+++ b/libsq/core/errz/multi.go @@ -6,12 +6,13 @@ import ( "bytes" "errors" "fmt" - "github.com/samber/lo" "io" "log/slog" "strings" "sync" "sync/atomic" + + "github.com/samber/lo" ) // Copyright (c) 2017-2023 Uber Technologies, Inc. @@ -159,7 +160,7 @@ var ( // Separator for single-line error messages. _singlelineSeparator = []byte("; ") - // Prefix for multi-line messages + // Prefix for multi-line messages. _multilinePrefix = []byte("the following errors occurred:") // Prefix for the first and following lines of an item in a list of @@ -185,10 +186,6 @@ var _bufferPool = sync.Pool{ }, } -type errorGroup interface { - Errors() []error -} - // Errors returns a slice containing zero or more errors that the supplied // error is composed of. If the error is nil, a nil slice is returned. // @@ -203,24 +200,24 @@ func Errors(err error) []error { return extractErrors(err) } -// multiError is an error that holds one or more errors. +// multiErr is an error that holds one or more errors. // // An instance of this is guaranteed to be non-empty and flattened. That is, -// none of the errors inside multiError are other multiErrors. +// none of the errors inside multiErr are other multiErrors. // -// multiError formats to a semi-colon delimited list of error messages with +// multiErr formats to a semicolon delimited list of error messages with // %v and with a more readable multi-line format with %+v. -type multiError struct { +type multiErr struct { //nolint:errname copyNeeded atomic.Bool errors []error *stack } // inner implements stackTracer. -func (merr *multiError) inner() error { return nil } +func (merr *multiErr) inner() error { return nil } // stackTrace implements stackTracer. -func (merr *multiError) stackTrace() *StackTrace { +func (merr *multiErr) stackTrace() *StackTrace { if merr == nil || merr.stack == nil { return nil } @@ -235,19 +232,19 @@ func (merr *multiError) stackTrace() *StackTrace { // Errors returns the list of underlying errors. 
// // This slice MUST NOT be modified. -func (merr *multiError) Errors() []error { +func (merr *multiErr) Errors() []error { if merr == nil { return nil } return merr.errors } -func (merr *multiError) Error() string { +func (merr *multiErr) Error() string { if merr == nil { return "" } - buff := _bufferPool.Get().(*bytes.Buffer) + buff, _ := _bufferPool.Get().(*bytes.Buffer) buff.Reset() merr.writeSingleline(buff) @@ -257,18 +254,8 @@ func (merr *multiError) Error() string { return result } -// Every compares every error in the given err against the given target error -// using [errors.Is], and returns true only if every comparison returned true. -func Every(err error, target error) bool { - for _, e := range extractErrors(err) { - if !errors.Is(e, target) { - return false - } - } - return true -} - -func (merr *multiError) LogValue() slog.Value { +// LogValue implements [slog.LogValuer]. +func (merr *multiErr) LogValue() slog.Value { if merr == nil { return slog.Value{} } @@ -282,7 +269,7 @@ func (merr *multiError) LogValue() slog.Value { return slog.GroupValue(attrs...) 
} -func (merr *multiError) Format(f fmt.State, c rune) { +func (merr *multiErr) Format(f fmt.State, c rune) { if c == 'v' && f.Flag('+') { merr.writeMultiline(f) } else { @@ -290,22 +277,22 @@ func (merr *multiError) Format(f fmt.State, c rune) { } } -func (merr *multiError) writeSingleline(w io.Writer) { +func (merr *multiErr) writeSingleline(w io.Writer) { first := true for _, item := range merr.errors { if first { first = false } else { - w.Write(_singlelineSeparator) + _, _ = w.Write(_singlelineSeparator) } - io.WriteString(w, item.Error()) + _, _ = io.WriteString(w, item.Error()) } } -func (merr *multiError) writeMultiline(w io.Writer) { - w.Write(_multilinePrefix) +func (merr *multiErr) writeMultiline(w io.Writer) { + _, _ = w.Write(_multilinePrefix) for _, item := range merr.errors { - w.Write(_multilineSeparator) + _, _ = w.Write(_multilineSeparator) writePrefixLine(w, _multilineIndent, fmt.Sprintf("%+v", item)) } } @@ -318,7 +305,7 @@ func writePrefixLine(w io.Writer, prefix []byte, s string) { if first { first = false } else { - w.Write(prefix) + _, _ = w.Write(prefix) } idx := strings.IndexByte(s, '\n') @@ -326,7 +313,7 @@ func writePrefixLine(w io.Writer, prefix []byte, s string) { idx = len(s) - 1 } - io.WriteString(w, s[:idx+1]) + _, _ = io.WriteString(w, s[:idx+1]) s = s[idx+1:] } } @@ -342,7 +329,7 @@ type inspectResult struct { // Count is zero. FirstErrorIdx int - // Whether the list contains at least one multiError + // Whether the list contains at least one multiErr ContainsMultiError bool } @@ -361,14 +348,14 @@ func inspect(errors []error) (res inspectResult) { res.FirstErrorIdx = i } - if merr, ok := err.(*multiError); ok { + if merr, ok := err.(*multiErr); ok { //nolint:errorlint res.Capacity += len(merr.errors) res.ContainsMultiError = true } else { res.Capacity++ } } - return + return res } // fromSlice converts the given list of errors into a single error. 
@@ -395,7 +382,7 @@ func fromSlice(errors []error) error { // unconditionally for all other cases. // This lets us optimize for the "no errors" case. out := append(([]error)(nil), errors...) - return &multiError{errors: out, stack: callers(1)} + return &multiErr{errors: out, stack: callers(1)} } } @@ -405,14 +392,14 @@ func fromSlice(errors []error) error { continue } - if nested, ok := err.(*multiError); ok { + if nested, ok := err.(*multiErr); ok { //nolint:errorlint nonNilErrs = append(nonNilErrs, nested.errors...) } else { nonNilErrs = append(nonNilErrs, err) } } - return &multiError{errors: nonNilErrs, stack: callers(0)} + return &multiErr{errors: nonNilErrs, stack: callers(0)} } // Combine combines the passed errors into a single error. @@ -468,7 +455,9 @@ func Combine(errors ...error) error { // // Note that the variable MUST be a named return to append an error to it from // the defer statement. -func Append(left error, right error) error { +// +//nolint:errorlint +func Append(left, right error) error { switch { case left == nil && right == nil: return nil @@ -486,15 +475,15 @@ func Append(left error, right error) error { return left } - if _, ok := right.(*multiError); !ok { - if l, ok := left.(*multiError); ok && !l.copyNeeded.Swap(true) { + if _, ok := right.(*multiErr); !ok { + if l, ok := left.(*multiErr); ok && !l.copyNeeded.Swap(true) { // Common case where the error on the left is constantly being // appended to. - errs := append(l.errors, right) - return &multiError{errors: errs, stack: callers(0)} + errs := append(l.errors, right) //nolint:gocritic + return &multiErr{errors: errs, stack: callers(0)} } else if !ok { // Both errors are single errors. - return &multiError{errors: []error{left, right}, stack: callers(0)} + return &multiErr{errors: []error{left, right}, stack: callers(0)} } } @@ -505,7 +494,7 @@ func Append(left error, right error) error { } // Unwrap returns a list of errors wrapped by this multierr. 
-func (merr *multiError) Unwrap() []error { +func (merr *multiErr) Unwrap() []error { return merr.Errors() } @@ -518,9 +507,9 @@ func extractErrors(err error) []error { return nil } - // check if the given err is an Unwrapable error that + // check if the given err is an unwrappable error that // implements multipleErrorer interface. - eg, ok := err.(multipleErrorer) + eg, ok := err.(multipleErrorer) //nolint:errorlint if !ok { return []error{err} } @@ -528,11 +517,13 @@ func extractErrors(err error) []error { return append(([]error)(nil), eg.Unwrap()...) } -func IsMulti(err error) bool { - if err == nil { - return false +// Every compares every error in the given err against the given target error +// using [errors.Is], and returns true only if every comparison returned true. +func Every(err, target error) bool { + for _, e := range extractErrors(err) { + if !errors.Is(e, target) { + return false + } } - - _, ok := err.(*multiError) - return ok + return true } diff --git a/libsq/core/errz/multi_test.go b/libsq/core/errz/multi_test.go index 51026cffe..b3f90dc1e 100644 --- a/libsq/core/errz/multi_test.go +++ b/libsq/core/errz/multi_test.go @@ -2,8 +2,9 @@ package errz import ( "errors" - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestAppend_stdlib_errors(t *testing.T) { @@ -34,14 +35,13 @@ func TestAppend_stdlib_errors(t *testing.T) { errs = Errors(appendErr) require.Len(t, errs, 1) gotErr1 := errs[0] - _, ok := gotErr1.(*errz) + _, ok := gotErr1.(*errz) //nolint:errorlint require.True(t, ok) gotErr1Unwrap := errors.Unwrap(gotErr1) require.Equal(t, err1, gotErr1Unwrap) - } -func TestMultiErrors_errz(t *testing.T) { +func TestAppend_errz(t *testing.T) { err1 := New("err1") err2 := New("err2") errs := Errors(err1) @@ -50,8 +50,6 @@ func TestMultiErrors_errz(t *testing.T) { appendErr := Append(err1, err2) errs = Errors(appendErr) require.Equal(t, []error{err1, err2}, errs) - //t.Logf("%v", appendErr) - 
//t.Logf("%+v", appendErr) stacks := Stacks(appendErr) require.NotNil(t, stacks) diff --git a/libsq/core/errz/stack.go b/libsq/core/errz/stack.go index f13410982..343c830af 100644 --- a/libsq/core/errz/stack.go +++ b/libsq/core/errz/stack.go @@ -216,7 +216,7 @@ func callers(skip int) *stack { const depth = 32 var pcs [depth]uintptr n := runtime.Callers(3, pcs[:]) - //var st stack = pcs[0:n] + // var st stack = pcs[0:n] var st stack = pcs[skip:n] return &st } @@ -242,14 +242,14 @@ func Stacks(err error) []*StackTrace { var stacks []*StackTrace for err != nil { - if tracer, ok := err.(stackTracer); ok { + if tracer, ok := err.(stackTracer); ok { //nolint:errorlint st := tracer.stackTrace() if st != nil { stacks = append(stacks, st) } } - //err = errors.Unwrap(err) + // err = errors.Unwrap(err) err = errors.Unwrap(err) } @@ -269,13 +269,13 @@ func LastStack(err error) *StackTrace { } var ( - //var ez *errz + // var ez *errz ok bool tracer stackTracer inner error ) for err != nil { - tracer, ok = err.(stackTracer) + tracer, ok = err.(stackTracer) //nolint:errorlint if !ok || tracer == nil { return nil } @@ -285,6 +285,7 @@ func LastStack(err error) *StackTrace { return tracer.stackTrace() } + //nolint:errorlint if _, ok = inner.(stackTracer); !ok { return tracer.stackTrace() } From 20ed41e2e8ed05436884ad9a5af5f2fb1226c273 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 17 Dec 2023 06:58:19 -0700 Subject: [PATCH 145/195] errz cleanup --- libsq/core/errz/types.go | 1 - 1 file changed, 1 deletion(-) delete mode 100644 libsq/core/errz/types.go diff --git a/libsq/core/errz/types.go b/libsq/core/errz/types.go deleted file mode 100644 index f1ed88ea9..000000000 --- a/libsq/core/errz/types.go +++ /dev/null @@ -1 +0,0 @@ -package errz From 27543efd0fdd25347297f2642f30e044e35cda4f Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 7 Jan 2024 18:19:08 -0700 Subject: [PATCH 146/195] wip: linting --- cli/cli.go | 3 --- cli/cmd_x.go | 26 ++++++++++--------- cli/diff/table.go 
| 1 + cli/error.go | 7 +++--- cli/output/jsonw/errorwriter.go | 5 ++-- cli/output/tablew/errorwriter.go | 6 ++--- cli/output/writers.go | 2 +- drivers/mysql/errors.go | 1 + drivers/postgres/errors.go | 1 + drivers/sqlite3/errors.go | 3 ++- drivers/sqlserver/errors.go | 1 + drivers/xlsx/ingest.go | 2 +- libsq/core/ioz/checksum/checksum.go | 5 ---- libsq/core/ioz/contextio/contextio.go | 2 +- libsq/core/ioz/download/README.md | 24 +++++++++--------- libsq/core/ioz/download/download.go | 2 ++ libsq/core/ioz/httpz/httpz.go | 6 ++--- libsq/core/ioz/httpz/httpz_test.go | 12 +++++---- libsq/core/ioz/httpz/opts.go | 25 +++++++++++++++++++ libsq/core/progress/bars.go | 36 +++++++++++++++++++++------ libsq/core/progress/style.go | 20 ++++++++++++++- libsq/core/stringz/stringz_test.go | 3 ++- testh/tu/tu.go | 6 +++-- 23 files changed, 135 insertions(+), 64 deletions(-) diff --git a/cli/cli.go b/cli/cli.go index 5f00510fc..8222095aa 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -175,9 +175,6 @@ func ExecuteWith(ctx context.Context, ru *run.Run, args []string) error { err = rootCmd.ExecuteContext(ctx) lg.WarnIfCloseError(log, "Problem closing run", ru) if err != nil { - //ctx2 := rootCmd.Context() // FIXME: delete - //_ = ctx2 - printError(ctx, ru, err) } diff --git a/cli/cmd_x.go b/cli/cmd_x.go index 265bffe7f..93a4864f7 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -3,12 +3,13 @@ package cli import ( "bufio" "fmt" - "github.com/neilotoole/sq/cli/flag" - "github.com/neilotoole/sq/libsq/core/errz" "net/url" "os" "time" + "github.com/neilotoole/sq/cli/flag" + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/run" @@ -175,8 +176,9 @@ func execXDownloadCmd(cmd *cobra.Command, args []string) error { c := httpz.NewClient( httpz.DefaultUserAgent, - //httpz.OptRequestTimeout(time.Second*2), - httpz.OptHeaderTimeout(time.Millisecond), + httpz.OptRequestTimeout(time.Second*15), + // 
httpz.OptHeaderTimeout(time.Second*2), + httpz.OptRequestDelay(time.Second*5), ) dl, err := download.New(fakeSrc.Handle, c, u.String(), cacheDir) if err != nil { @@ -188,14 +190,14 @@ func execXDownloadCmd(cmd *cobra.Command, args []string) error { switch { case len(h.Errors) > 0: - //err1 := errz.Err(h.Errors[0]) - //return err1 - - err1 := h.Errors[0] - err2 := errz.New("another err") - err3 := errz.Combine(err1, err2) - //lg.FromContext(ctx).Error("OH NO", lga.Err, err3) - return err3 + err1 := errz.Err(h.Errors[0]) + return err1 + + //err1 := h.Errors[0] + //err2 := errz.New("another err") + //err3 := errz.Combine(err1, err2) + ////lg.FromContext(ctx).Error("OH NO", lga.Err, err3) + //return err3 //return nil case len(h.WriteErrors) > 0: return h.WriteErrors[0] diff --git a/cli/diff/table.go b/cli/diff/table.go index 2dbf71b8e..d7d823893 100644 --- a/cli/diff/table.go +++ b/cli/diff/table.go @@ -3,6 +3,7 @@ package diff import ( "context" "fmt" + "github.com/neilotoole/sq/libsq/driver" "golang.org/x/sync/errgroup" diff --git a/cli/error.go b/cli/error.go index 44b2f1d90..71ee57d3a 100644 --- a/cli/error.go +++ b/cli/error.go @@ -4,11 +4,12 @@ import ( "context" "errors" "fmt" - "github.com/spf13/cobra" - "github.com/spf13/pflag" "os" "strings" + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "github.com/neilotoole/sq/cli/flag" "github.com/neilotoole/sq/cli/output/format" "github.com/neilotoole/sq/cli/output/jsonw" @@ -37,7 +38,7 @@ func printError(ctx context.Context, ru *run.Run, err error) { return } - var cmdName = "unknown" + cmdName := "unknown" var cmd *cobra.Command if ru != nil && ru.Cmd != nil { cmd = ru.Cmd diff --git a/cli/output/jsonw/errorwriter.go b/cli/output/jsonw/errorwriter.go index 36ec3e03c..b98b672a4 100644 --- a/cli/output/jsonw/errorwriter.go +++ b/cli/output/jsonw/errorwriter.go @@ -40,7 +40,7 @@ type stack struct { } // Error implements output.ErrorWriter. 
-func (w *errorWriter) Error(systemErr error, humanErr error) { +func (w *errorWriter) Error(systemErr, humanErr error) { pr := w.pr.Clone() pr.String = pr.Warning @@ -68,7 +68,8 @@ func (w *errorWriter) Error(systemErr error, humanErr error) { Error: &stackError{ Message: sysStack.Error.Error(), Tree: errz.SprintTreeTypes(sysStack.Error), - }} + }, + } ed.Stack = append(ed.Stack, st) } diff --git a/cli/output/tablew/errorwriter.go b/cli/output/tablew/errorwriter.go index 9df2d582a..131bf6d5e 100644 --- a/cli/output/tablew/errorwriter.go +++ b/cli/output/tablew/errorwriter.go @@ -23,7 +23,7 @@ func NewErrorWriter(w io.Writer, pr *output.Printing) output.ErrorWriter { } // Error implements output.ErrorWriter. -func (w *errorWriter) Error(systemErr error, humanErr error) { +func (w *errorWriter) Error(systemErr, humanErr error) { fmt.Fprintln(w.w, w.pr.Error.Sprintf("sq: %v", humanErr)) if !w.pr.Verbose { return @@ -34,7 +34,7 @@ func (w *errorWriter) Error(systemErr error, humanErr error) { return } - var buf = &bytes.Buffer{} + buf := &bytes.Buffer{} var count int for _, stack := range stacks { if stack == nil { @@ -64,5 +64,5 @@ func (w *errorWriter) Error(systemErr error, humanErr error) { count++ } - buf.WriteTo(w.w) + _, _ = buf.WriteTo(w.w) } diff --git a/cli/output/writers.go b/cli/output/writers.go index 138b5ea6e..1579c81ae 100644 --- a/cli/output/writers.go +++ b/cli/output/writers.go @@ -100,7 +100,7 @@ type ErrorWriter interface { // Error outputs error conditions. It's possible that systemErr and // humanErr differ; systemErr is the error that occurred, and humanErr // is the error that should be presented to the user. - Error(systemErr error, humanErr error) + Error(systemErr, humanErr error) } // PingWriter writes ping results. 
diff --git a/drivers/mysql/errors.go b/drivers/mysql/errors.go index fdf629969..93e75659d 100644 --- a/drivers/mysql/errors.go +++ b/drivers/mysql/errors.go @@ -2,6 +2,7 @@ package mysql import ( "errors" + "github.com/neilotoole/sq/libsq/driver" "github.com/go-sql-driver/mysql" diff --git a/drivers/postgres/errors.go b/drivers/postgres/errors.go index ce278ed18..1baa1674c 100644 --- a/drivers/postgres/errors.go +++ b/drivers/postgres/errors.go @@ -2,6 +2,7 @@ package postgres import ( "errors" + "github.com/neilotoole/sq/libsq/driver" "github.com/jackc/pgx/v5/pgconn" diff --git a/drivers/sqlite3/errors.go b/drivers/sqlite3/errors.go index 60def3f6f..8ea505655 100644 --- a/drivers/sqlite3/errors.go +++ b/drivers/sqlite3/errors.go @@ -1,9 +1,10 @@ package sqlite3 import ( - "github.com/neilotoole/sq/libsq/driver" "strings" + "github.com/neilotoole/sq/libsq/driver" + "github.com/neilotoole/sq/libsq/core/errz" ) diff --git a/drivers/sqlserver/errors.go b/drivers/sqlserver/errors.go index 673ae40dc..50c174bd9 100644 --- a/drivers/sqlserver/errors.go +++ b/drivers/sqlserver/errors.go @@ -2,6 +2,7 @@ package sqlserver import ( "errors" + "github.com/neilotoole/sq/libsq/driver" mssql "github.com/microsoft/go-mssqldb" diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index 717c0dc38..c68fa42ce 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -285,7 +285,7 @@ func buildSheetTables(ctx context.Context, srcIngestHeader *bool, sheets []*xShe sheetTbl, err := buildSheetTable(gCtx, srcIngestHeader, sheets[i]) if err != nil { if errz.Has[*driver.EmptyDataError](err) { - //if errz.IsErrNoData(err) { // FIXME: remove after testing + // if errz.IsErrNoData(err) { // FIXME: remove after testing // If the sheet has no data, we log it and skip it. 
lg.FromContext(ctx).Warn("Excel sheet has no data", laSheet, sheets[i].name, diff --git a/libsq/core/ioz/checksum/checksum.go b/libsq/core/ioz/checksum/checksum.go index 961b03218..1bfaf8b09 100644 --- a/libsq/core/ioz/checksum/checksum.go +++ b/libsq/core/ioz/checksum/checksum.go @@ -201,10 +201,5 @@ func ForHTTPResponse(resp *http.Response) Checksum { } } - s := buf.String() - _ = s - - fmt.Printf("\n\n%s\n\n", s) - return Checksum(Sum(buf.Bytes())) } diff --git a/libsq/core/ioz/contextio/contextio.go b/libsq/core/ioz/contextio/contextio.go index 82c37d428..b5555e241 100644 --- a/libsq/core/ioz/contextio/contextio.go +++ b/libsq/core/ioz/contextio/contextio.go @@ -215,7 +215,7 @@ func cause(ctx context.Context, err error) error { } // err is non-nil - if ctx.Err() != err { + if ctx.Err() != err { //nolint:errorlint // err is not the context error, so err takes precedence. return err } diff --git a/libsq/core/ioz/download/README.md b/libsq/core/ioz/download/README.md index dac847ca4..c2f313104 100644 --- a/libsq/core/ioz/download/README.md +++ b/libsq/core/ioz/download/README.md @@ -22,21 +22,21 @@ API-client and not for a shared proxy). - The built-in 'memory' cache stores responses in an in-memory map. - [`github.com/bitcomplete/httpcache/diskcache`](https://github.com/bitcomplete/httpcache/tree/master/diskcache) provides a filesystem-backed cache using the - [diskv](https://github.com/peterbourgon/diskv) library. - - [`github.com/bitcomplete/httpcache/memcache`](https://github.com/bitcomplete/httpcache/tree/master/memcache) + [diskv](https://github.com/peterbourgon/diskv) library. +- [`github.com/bitcomplete/httpcache/memcache`](https://github.com/bitcomplete/httpcache/tree/master/memcache) provides memcache implementations, for both App Engine and 'normal' memcache - servers. - - [`sourcegraph.com/sourcegraph/s3cache`](https://sourcegraph.com/github.com/sourcegraph/s3cache) - uses Amazon S3 for storage. 
- - [`github.com/bitcomplete/httpcache/leveldbcache`](https://github.com/bitcomplete/httpcache/tree/master/leveldbcache) + servers. +- [`sourcegraph.com/sourcegraph/s3cache`](https://sourcegraph.com/github.com/sourcegraph/s3cache) + uses Amazon S3 for storage. +- [`github.com/bitcomplete/httpcache/leveldbcache`](https://github.com/bitcomplete/httpcache/tree/master/leveldbcache) provides a filesystem-backed cache using - [leveldb](https://github.com/syndtr/goleveldb/leveldb). - - [`github.com/die-net/lrucache`](https://github.com/die-net/lrucache) provides an - in-memory cache that will evict least-recently used entries. - - [`github.com/die-net/lrucache/twotier`](https://github.com/die-net/lrucache/tree/master/twotier) + [leveldb](https://github.com/syndtr/goleveldb/leveldb). +- [`github.com/die-net/lrucache`](https://github.com/die-net/lrucache) provides an + in-memory cache that will evict least-recently used entries. +- [`github.com/die-net/lrucache/twotier`](https://github.com/die-net/lrucache/tree/master/twotier) allows caches to be combined, for example to use lrucache above with a - persistent disk-cache. - - [`github.com/birkelund/boltdbcache`](https://github.com/birkelund/boltdbcache) + persistent disk-cache. +- [`github.com/birkelund/boltdbcache`](https://github.com/birkelund/boltdbcache) provides a BoltDB implementation (based on the [bbolt](https://github.com/coreos/bbolt) fork). diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index 2d5651a12..2a7aa5a04 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -302,7 +302,9 @@ func (dl *Download) get(req *http.Request, h Handler) { // do executes the request. 
func (dl *Download) do(req *http.Request) (*http.Response, error) { + bar := progress.FromContext(req.Context()).NewWaiter(dl.name+": start download", true) resp, err := dl.c.Do(req) + bar.Stop() if err != nil { // Download timeout errors are typically wrapped in an url.Error, resulting // in a message like: diff --git a/libsq/core/ioz/httpz/httpz.go b/libsq/core/ioz/httpz/httpz.go index 20e7730b1..1d8ef84de 100644 --- a/libsq/core/ioz/httpz/httpz.go +++ b/libsq/core/ioz/httpz/httpz.go @@ -51,13 +51,13 @@ func NewClient(opts ...Opt) *http.Client { } c.Transport = tr - c.Transport = RoundTrip(c.Transport, contextCause()) - - for i := range opts { + // Apply the round trip functions in reverse order. + for i := len(opts) - 1; i >= 0; i-- { if tf, ok := opts[i].(TripFunc); ok { c.Transport = RoundTrip(c.Transport, tf) } } + c.Transport = RoundTrip(c.Transport, contextCause()) return &c } diff --git a/libsq/core/ioz/httpz/httpz_test.go b/libsq/core/ioz/httpz/httpz_test.go index e1fe9f09c..146e000a1 100644 --- a/libsq/core/ioz/httpz/httpz_test.go +++ b/libsq/core/ioz/httpz/httpz_test.go @@ -20,6 +20,7 @@ import ( func TestOptRequestTimeout(t *testing.T) { t.Parallel() + const srvrBody = `Hello World!` serverDelay := time.Millisecond * 200 srvr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -34,16 +35,15 @@ func TestOptRequestTimeout(t *testing.T) { })) t.Cleanup(srvr.Close) - clientRequestTimeout := time.Millisecond * 100 - c := httpz.NewClient(httpz.OptRequestTimeout(clientRequestTimeout)) - req, err := http.NewRequest(http.MethodGet, srvr.URL, nil) + ctx := lg.NewContext(context.Background(), lgt.New(t)) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, srvr.URL, nil) require.NoError(t, err) + clientRequestTimeout := time.Millisecond * 100 + c := httpz.NewClient(httpz.OptRequestTimeout(clientRequestTimeout)) resp, err := c.Do(req) - t.Log(err) require.Error(t, err) require.Nil(t, resp) - require.Contains(t, 
err.Error(), "http request not completed within") require.True(t, errors.Is(err, context.DeadlineExceeded)) } @@ -51,6 +51,7 @@ func TestOptRequestTimeout(t *testing.T) { // that fails via OptHeaderTimeout returns the correct error. func TestOptHeaderTimeout_correct_error(t *testing.T) { t.Parallel() + ctx := lg.NewContext(context.Background(), lgt.New(t)) const srvrBody = `Hello World!` @@ -93,6 +94,7 @@ func TestOptHeaderTimeout_correct_error(t *testing.T) { // works as expected when compared to stdlib. func TestOptHeaderTimeout_vs_stdlib(t *testing.T) { t.Parallel() + const ( headerTimeout = time.Millisecond * 200 numLines = 7 diff --git a/libsq/core/ioz/httpz/opts.go b/libsq/core/ioz/httpz/opts.go index 93c23795f..835da4f7e 100644 --- a/libsq/core/ioz/httpz/opts.go +++ b/libsq/core/ioz/httpz/opts.go @@ -197,3 +197,28 @@ func OptHeaderTimeout(timeout time.Duration) TripFunc { // DefaultHeaderTimeout is the default header timeout as used // by [NewDefaultClient]. var DefaultHeaderTimeout = OptHeaderTimeout(time.Second * 5) + +// OptRequestDelay is passed to [NewClient] to delay the request by the +// specified duration. This is useful for testing. 
+func OptRequestDelay(delay time.Duration) TripFunc { + if delay <= 0 { + return NopTripFunc + } + + return func(next http.RoundTripper, req *http.Request) (*http.Response, error) { + ctx := req.Context() + log := lg.FromContext(ctx) + log.Debug("HTTP request delay: started", lga.Val, delay, lga.URL, req.URL.String()) + t := time.NewTimer(delay) + defer t.Stop() + select { + case <-ctx.Done(): + return nil, context.Cause(ctx) + case <-t.C: + // Continue below + } + + log.Debug("HTTP request delay: done", lga.Val, delay, lga.URL, req.URL.String()) + return next.RoundTrip(req) + } +} diff --git a/libsq/core/progress/bars.go b/libsq/core/progress/bars.go index e53dee90e..44501ecce 100644 --- a/libsq/core/progress/bars.go +++ b/libsq/core/progress/bars.go @@ -9,6 +9,9 @@ import ( "github.com/vbauerster/mpb/v8/decor" ) +// NewByteCounter returns a new determinate bar whose label +// metric is the size in bytes of the data being processed. The caller is +// ultimately responsible for calling [Bar.Stop] on the returned Bar. func (p *Progress) NewByteCounter(msg string, size int64) *Bar { if p == nil { return nil @@ -36,10 +39,7 @@ func (p *Progress) NewByteCounter(msg string, size int64) *Bar { // NewUnitCounter returns a new indeterminate bar whose label // metric is the plural of the provided unit. The caller is ultimately -// responsible for calling [Bar.Stop] on the returned Bar. However, -// the returned Bar is also added to the Progress's cleanup list, so -// it will be called automatically when the Progress is shut down, but that -// may be later than the actual conclusion of the spinner's work. +// responsible for calling [Bar.Stop] on the returned Bar. // // bar := p.NewUnitCounter("Ingest records", "rec") // defer bar.Stop() @@ -76,12 +76,32 @@ func (p *Progress) NewUnitCounter(msg, unit string) *Bar { return p.newBar(msg, -1, style, decorator) } +// NewWaiter returns a generic indeterminate spinner. If arg clock +// is true, a timer is shown. 
This produces output similar to: +// +// @excel/remote: start download ●∙∙ 4s +// +// The caller is ultimately responsible for calling [Bar.Stop] on the +// returned Bar. +func (p *Progress) NewWaiter(msg string, clock bool) *Bar { + if p == nil { + return nil + } + + p.mu.Lock() + defer p.mu.Unlock() + + var d []decor.Decorator + if clock { + d = append(d, newElapsedSeconds(p.colors.Size, time.Now(), decor.WCSyncSpace)) + } + style := spinnerStyle(p.colors.Filler) + return p.newBar(msg, -1, style, d...) +} + // NewUnitTotalCounter returns a new determinate bar whose label // metric is the plural of the provided unit. The caller is ultimately -// responsible for calling [Bar.Stop] on the returned Bar. However, -// the returned Bar is also added to the Progress's cleanup list, so -// it will be called automatically when the Progress is shut down, but that -// may be later than the actual conclusion of the Bar's work. +// responsible for calling [Bar.Stop] on the returned Bar. // // This produces output similar to: // diff --git a/libsq/core/progress/style.go b/libsq/core/progress/style.go index bbae96fde..5cd6fbf3b 100644 --- a/libsq/core/progress/style.go +++ b/libsq/core/progress/style.go @@ -9,12 +9,15 @@ import ( ) const ( - msgLength = 28 + msgLength = 36 barWidth = 28 boxWidth = 64 refreshRate = 150 * time.Millisecond ) +// @download_16b8a3b1: http start ∙●∙ +// @download_16b8a3b1: download ∙ 14.4 MiB / 427.6 MiB 3.4 + // DefaultColors returns the default colors used for the progress bars. func DefaultColors() *Colors { return &Colors{ @@ -96,3 +99,18 @@ func barStyle(c *color.Color) mpb.BarStyleComposer { Padding(" "). 
Tip(frames...).TipMeta(clr) } + +func newElapsedSeconds(c *color.Color, startTime time.Time, wcc ...decor.WC) decor.Decorator { + var msg string + producer := func(d time.Duration) string { + return " " + d.Round(time.Second).String() + } + fn := func(s decor.Statistics) string { + if !s.Completed { + msg = producer(time.Since(startTime)) + msg = c.Sprint(msg) + } + return msg + } + return decor.Any(fn, wcc...) +} diff --git a/libsq/core/stringz/stringz_test.go b/libsq/core/stringz/stringz_test.go index 4d5e617c6..021b41340 100644 --- a/libsq/core/stringz/stringz_test.go +++ b/libsq/core/stringz/stringz_test.go @@ -2,11 +2,12 @@ package stringz_test import ( "errors" - "github.com/neilotoole/sq/libsq/core/errz" "strconv" "strings" "testing" + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/samber/lo" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" diff --git a/testh/tu/tu.go b/testh/tu/tu.go index e82df0821..5484a8ca0 100644 --- a/testh/tu/tu.go +++ b/testh/tu/tu.go @@ -419,7 +419,7 @@ func OpenFileCount(t testing.TB, log bool) int { count, out := doOpenFileCount(t) msg := fmt.Sprintf("Open files for [%d]: %d", os.Getpid(), count) if log { - msg += "\n\n" + string(out) + msg += "\n\n" + out } t.Log(msg) return count @@ -427,7 +427,9 @@ func OpenFileCount(t testing.TB, log bool) int { func doOpenFileCount(t testing.TB) (count int, out string) { SkipWindows(t, "OpenFileCount not implemented on Windows") - b, err := exec.Command("/bin/sh", "-c", fmt.Sprintf("lsof -p %v", os.Getpid())).Output() + + c := fmt.Sprintf("lsof -p %v", os.Getpid()) + b, err := exec.Command("/bin/sh", "-c", c).Output() require.NoError(t, err) lines := strings.Split(string(b), "\n") count = len(lines) - 1 From 696e021304473d3997ab45cd3d6c9d5bb8e0fe65 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 7 Jan 2024 19:09:07 -0700 Subject: [PATCH 147/195] wip: linting --- drivers/csv/csv.go | 36 ++++++++--------- libsq/core/ioz/download/cache.go | 6 +-- 
libsq/core/ioz/download/download.go | 41 +++++++++++--------- libsq/core/ioz/download/download_test.go | 49 +++++------------------- libsq/core/ioz/download/http.go | 14 ++++--- libsq/core/ioz/httpz/httpz.go | 6 +-- libsq/core/ioz/httpz/opts.go | 2 +- libsq/core/lg/devlog/devlog_test.go | 26 +------------ 8 files changed, 66 insertions(+), 114 deletions(-) diff --git a/drivers/csv/csv.go b/drivers/csv/csv.go index 994ed31b4..0b4bc6871 100644 --- a/drivers/csv/csv.go +++ b/drivers/csv/csv.go @@ -124,28 +124,28 @@ type grip struct { } // DB implements driver.Grip. -func (p *grip) DB(ctx context.Context) (*sql.DB, error) { - return p.impl.DB(ctx) +func (g *grip) DB(ctx context.Context) (*sql.DB, error) { + return g.impl.DB(ctx) } // SQLDriver implements driver.Grip. -func (p *grip) SQLDriver() driver.SQLDriver { - return p.impl.SQLDriver() +func (g *grip) SQLDriver() driver.SQLDriver { + return g.impl.SQLDriver() } // Source implements driver.Grip. -func (p *grip) Source() *source.Source { - return p.src +func (g *grip) Source() *source.Source { + return g.src } // TableMetadata implements driver.Grip. -func (p *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { +func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { if tblName != source.MonotableName { return nil, errz.Errorf("table name should be %s for CSV/TSV etc., but got: %s", source.MonotableName, tblName) } - srcMeta, err := p.SourceMetadata(ctx, false) + srcMeta, err := g.SourceMetadata(ctx, false) if err != nil { return nil, err } @@ -155,22 +155,22 @@ func (p *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Tab } // SourceMetadata implements driver.Grip. 
-func (p *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { - md, err := p.impl.SourceMetadata(ctx, noSchema) +func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { + md, err := g.impl.SourceMetadata(ctx, noSchema) if err != nil { return nil, err } - md.Handle = p.src.Handle - md.Location = p.src.Location - md.Driver = p.src.Type + md.Handle = g.src.Handle + md.Location = g.src.Location + md.Driver = g.src.Type - md.Name, err = source.LocationFileName(p.src) + md.Name, err = source.LocationFileName(g.src) if err != nil { return nil, err } - md.Size, err = p.files.Filesize(ctx, p.src) + md.Size, err = g.files.Filesize(ctx, g.src) if err != nil { return nil, err } @@ -180,8 +180,8 @@ func (p *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou } // Close implements driver.Grip. -func (p *grip) Close() error { - p.log.Debug(lgm.CloseDB, lga.Handle, p.src.Handle) +func (g *grip) Close() error { + g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) - return errz.Err(p.impl.Close()) + return errz.Err(g.impl.Close()) } diff --git a/libsq/core/ioz/download/cache.go b/libsq/core/ioz/download/cache.go index 458e839ce..945353dad 100644 --- a/libsq/core/ioz/download/cache.go +++ b/libsq/core/ioz/download/cache.go @@ -133,14 +133,14 @@ func (c *cache) get(ctx context.Context, req *http.Request) (*http.Response, err fpHeader, fpBody, _ := c.paths(req) if !ioz.FileAccessible(fpHeader) { // If the header file doesn't exist, it's a nil, nil situation. - return nil, nil + return nil, nil //nolint:nilnil } if _, ok := c.checksumsMatch(req); !ok { // If the checksums don't match, it's a nil, nil situation. // REVISIT: should we clear the cache here? 
- return nil, nil + return nil, nil //nolint:nilnil } headerBytes, err := os.ReadFile(fpHeader) @@ -207,7 +207,7 @@ func (c *cache) cachedChecksum(req *http.Request) (sum checksum.Checksum, ok boo // checksumsMatch returns true (and the valid checksum) if there is a cached // checksum file for req, and there is a cached response body file, and a fresh // checksum calculated from that body file matches the cached checksum. -func (c *cache) checksumsMatch(req *http.Request) (sum checksum.Checksum, ok bool) { +func (c *cache) checksumsMatch(req *http.Request) (sum checksum.Checksum, ok bool) { //nolint:unparam sum, ok = c.cachedChecksum(req) if !ok { return "", false diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index 2a7aa5a04..4b015922d 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -54,7 +54,7 @@ const ( Transparent ) -// XFromCache is the header added to responses that are returned from the cache +// XFromCache is the header added to responses that are returned from the cache. const XFromCache = "X-From-Cache" const msgNilDestWriter = "nil dest writer from download handler; returning" @@ -141,7 +141,9 @@ func (dl *Download) Get(ctx context.Context, h Handler) { dl.get(req, h) } -func (dl *Download) get(req *http.Request, h Handler) { +// get contains the main logic for getting the download. It invokes Handler +// as appropriate. +func (dl *Download) get(req *http.Request, h Handler) { //nolint:funlen,gocognit ctx := req.Context() log := lg.FromContext(ctx) log.Debug("Get download", lga.URL, dl.url) @@ -149,15 +151,16 @@ func (dl *Download) get(req *http.Request, h Handler) { state := dl.state(req) if state == Fresh { + // The cached response is fresh, so we can return it. 
h.Cached(fpBody) return } - var err error cacheable := dl.isCacheable(req) + var err error var cachedResp *http.Response if cacheable { - cachedResp, err = dl.cache.get(req.Context(), req) + cachedResp, err = dl.cache.get(req.Context(), req) //nolint:bodyclose } else { // Need to invalidate an existing value if err = dl.cache.clear(req.Context()); err != nil { @@ -167,7 +170,7 @@ func (dl *Download) get(req *http.Request, h Handler) { } var resp *http.Response - if cacheable && cachedResp != nil && err == nil { + if cacheable && cachedResp != nil && err == nil { //nolint:nestif if dl.markCachedResponses { cachedResp.Header.Set(XFromCache, "1") } @@ -201,7 +204,7 @@ func (dl *Download) get(req *http.Request, h Handler) { } } - resp, err = dl.do(req) + resp, err = dl.do(req) //nolint:bodyclose if err == nil && req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified { // Replace the 304 response with the one from cache, but update with some new headers endToEndHeaders := getEndToEndHeaders(resp.Header) @@ -210,8 +213,11 @@ func (dl *Download) get(req *http.Request, h Handler) { } lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) resp = cachedResp - } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) && - req.Method == http.MethodGet && canStaleOnError(cachedResp.Header, req.Header) { + } else if (err != nil || + (cachedResp != nil && + resp.StatusCode >= 500)) && + req.Method == http.MethodGet && + canStaleOnError(cachedResp.Header, req.Header) { // In case of transport failure and stale-if-error activated, returns cached content // when available log.Warn("Returning cached response due to transport failure", lga.Err, err) @@ -229,9 +235,9 @@ func (dl *Download) get(req *http.Request, h Handler) { } else { reqCacheControl := parseCacheControl(req.Header) if _, ok := reqCacheControl["only-if-cached"]; ok { - resp = newGatewayTimeoutResponse(req) + resp = newGatewayTimeoutResponse(req) //nolint:bodyclose } else { - 
resp, err = dl.do(req) + resp, err = dl.do(req) //nolint:bodyclose if err != nil { h.Error(err) return @@ -270,17 +276,19 @@ func (dl *Download) get(req *http.Request, h Handler) { defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) if err = dl.cache.write(req.Context(), resp, false, destWrtr); err != nil { log.Error("Failed to write download cache", lga.Dir, dl.cache.dir, lga.Err, err) - // destWrtr.Error(err) + // We don't need to explicitly call Handler.Error here, because the caller is + // informed via destWrtr.Error, which has already been invoked by cache.write. } return - } else { - lg.WarnIfError(log, "Delete resp cache", dl.cache.clear(req.Context())) } + lg.WarnIfError(log, "Delete resp cache", dl.cache.clear(req.Context())) + // It's not cacheable, so we need to write it to the destWrtr, // and skip the cache. destWrtr := h.Uncached() if destWrtr == nil { + // Shouldn't happen. log.Warn(msgNilDestWriter) return } @@ -296,8 +304,6 @@ func (dl *Download) get(req *http.Request, h Handler) { if err = destWrtr.Close(); err != nil { log.Error("Failed to close dest writer", lga.Err, err) } - - return } // do executes the request. 
@@ -323,7 +329,7 @@ func (dl *Download) do(req *http.Request) (*http.Response, error) { if resp.Body != nil && resp.Body != http.NoBody { r := progress.NewReader(req.Context(), dl.name+": download", resp.ContentLength, resp.Body) - resp.Body = r.(io.ReadCloser) + resp.Body, _ = r.(io.ReadCloser) } return resp, nil } @@ -335,7 +341,6 @@ func (dl *Download) mustRequest(ctx context.Context) *http.Request { if err != nil { lg.FromContext(ctx).Error("Failed to create request", lga.URL, dl.url, lga.Err, err) panic(err) - return nil } return req } @@ -374,7 +379,7 @@ func (dl *Download) state(req *http.Request) State { defer lg.WarnIfCloseError(log, msgCloseCacheHeaderFile, f) - cachedResp, err := httpz.ReadResponseHeader(bufio.NewReader(f), nil) + cachedResp, err := httpz.ReadResponseHeader(bufio.NewReader(f), nil) //nolint:bodyclose if err != nil { log.Error("Failed to read cached response header", lga.Err, err) return Uncached diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index 26d7e866e..a0d7625df 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -5,7 +5,6 @@ import ( "net/http" "net/http/httptest" "os" - "path/filepath" "strconv" "testing" "time" @@ -60,25 +59,23 @@ func TestDownload_redirect(t *testing.T) { const hello = `Hello World!` serveBody := hello lastModified := time.Now().UTC() - // cacheDir := t.TempDir() - // FIXME: switch back to temp dir - cacheDir := filepath.Join("testdata", "download", tu.Name(t.Name())) + cacheDir := t.TempDir() log := lgt.New(t) var srvr *httptest.Server srvr = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log := log.With("origin", "server") - log.Info("Request on /actual", "req", httpz.RequestLogValue(r)) + srvrLog := log.With("origin", "server") + srvrLog.Info("Request on /actual", "req", httpz.RequestLogValue(r)) switch r.URL.Path { case "/redirect": loc := srvr.URL + "/actual" - 
log.Info("Redirecting to", lga.Loc, loc) + srvrLog.Info("Redirecting to", lga.Loc, loc) http.Redirect(w, r, loc, http.StatusFound) case "/actual": if ifm := r.Header.Get("If-Modified-Since"); ifm != "" { tm, err := time.Parse(http.TimeFormat, ifm) if err != nil { - log.Error("Failed to parse If-Modified-Since", lga.Err, err) + srvrLog.Error("Failed to parse If-Modified-Since", lga.Err, err) w.WriteHeader(http.StatusBadRequest) return } @@ -87,19 +84,19 @@ func TestDownload_redirect(t *testing.T) { lastModifiedUnix := lastModified.Unix() if lastModifiedUnix <= ifModifiedSinceUnix { - log.Info("Serving http.StatusNotModified") + srvrLog.Info("Serving http.StatusNotModified") w.WriteHeader(http.StatusNotModified) return } } - log.Info("Serving actual: writing bytes") + srvrLog.Info("Serving actual: writing bytes") b := []byte(serveBody) w.Header().Set("Last-Modified", lastModified.Format(http.TimeFormat)) _, err := w.Write(b) assert.NoError(t, err) default: - log.Info("Serving http.StatusNotFound") + srvrLog.Info("Serving http.StatusNotFound") w.WriteHeader(http.StatusNotFound) } })) @@ -139,40 +136,12 @@ func TestDownload_redirect(t *testing.T) { require.Equal(t, serveBody, gotBody) } -//tr := httpcache.NewTransport(diskcache.New(cacheDir)) -//req, err := http.NewRequestWithContext(ctx, http.MethodGet, loc, nil) -//require.NoError(t, err) -// -//resp, err := tr.RoundTrip(req) -//require.NoError(t, err) -//require.Equal(t, http.StatusOK, resp.StatusCode) -//b, err := io.ReadAll(resp.Body) -//require.NoError(t, err) -//require.Equal(t, serveBody, string(b)) -//t.Logf("b: \n\n%s\n\n", b) -// -//resp2, err := tr.RoundTrip(req) -//require.NoError(t, err) -//require.Equal(t, http.StatusOK, resp2.StatusCode) -// -//b, err = io.ReadAll(resp.Body) -//require.NoError(t, err) -//require.Equal(t, serveBody, string(b)) -//t.Logf("b: \n\n%s\n\n", b) - -// -//ctx := lg.NewContext(context.Background(), log.With("origin", "downloader")) -//loc := srvr.URL + "/redirect" -//loc := 
srvr.URL + "/actual" - func TestDownload_New(t *testing.T) { log := lgt.New(t) ctx := lg.NewContext(context.Background(), log) const dlURL = urlActorCSV - // FIXME: switch to temp dir - cacheDir, err := filepath.Abs(filepath.Join("testdata", "download", tu.Name(t.Name()))) - require.NoError(t, err) + cacheDir := t.TempDir() t.Logf("cacheDir: %s", cacheDir) dl, err := download.New(t.Name(), httpz.NewDefaultClient(), dlURL, cacheDir) diff --git a/libsq/core/ioz/download/http.go b/libsq/core/ioz/download/http.go index 899e055cd..98da338fa 100644 --- a/libsq/core/ioz/download/http.go +++ b/libsq/core/ioz/download/http.go @@ -17,7 +17,7 @@ func getDate(respHeaders http.Header) (date time.Time, err error) { dateHeader := respHeaders.Get("date") if dateHeader == "" { err = errNoDateHeader - return + return date, err } return time.Parse(time.RFC1123, dateHeader) @@ -76,7 +76,8 @@ func getFreshness(respHeaders, reqHeaders http.Header) (freshness State) { } else { expiresHeader := respHeaders.Get("Expires") if expiresHeader != "" { - expires, err := time.Parse(time.RFC1123, expiresHeader) + var expires time.Time + expires, err = time.Parse(time.RFC1123, expiresHeader) if err != nil { lifetime = zeroDuration } else { @@ -94,9 +95,10 @@ func getFreshness(respHeaders, reqHeaders http.Header) (freshness State) { } if minFresh, ok := reqCacheControl["min-fresh"]; ok { // the client wants a response that will still be fresh for at least the specified number of seconds. 
- minFreshDuration, err := time.ParseDuration(minFresh + "s") + var minFreshDuration time.Duration + minFreshDuration, err = time.ParseDuration(minFresh + "s") if err == nil { - currentAge = time.Duration(currentAge + minFreshDuration) + currentAge += minFreshDuration } } @@ -114,7 +116,7 @@ func getFreshness(respHeaders, reqHeaders http.Header) (freshness State) { } maxStaleDuration, err := time.ParseDuration(maxStale + "s") if err == nil { - currentAge = time.Duration(currentAge - maxStaleDuration) + currentAge -= maxStaleDuration } } @@ -274,7 +276,7 @@ func headerAllCommaSepValues(headers http.Header, name string) []string { } // varyMatches will return false unless all the cached values for the -// headers listed in Vary match the new request +// headers listed in Vary match the new request. func varyMatches(cachedResp *http.Response, req *http.Request) bool { for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") { header = http.CanonicalHeaderKey(header) diff --git a/libsq/core/ioz/httpz/httpz.go b/libsq/core/ioz/httpz/httpz.go index 1d8ef84de..cece6f3bd 100644 --- a/libsq/core/ioz/httpz/httpz.go +++ b/libsq/core/ioz/httpz/httpz.go @@ -67,7 +67,7 @@ var _ Opt = (*TripFunc)(nil) // It is commonly used with RoundTrip to decorate an existing http.RoundTripper. type TripFunc func(next http.RoundTripper, req *http.Request) (*http.Response, error) -func (tf TripFunc) apply(tr *http.Transport) {} +func (tf TripFunc) apply(*http.Transport) {} // RoundTrip adapts a TripFunc to http.RoundTripper. func RoundTrip(next http.RoundTripper, fn TripFunc) http.RoundTripper { @@ -188,7 +188,7 @@ func ReadResponseHeader(r *bufio.Reader, req *http.Request) (resp *http.Response // Parse the first line of the response. 
line, err := tp.ReadLine() if err != nil { - if err == io.EOF { + if err == io.EOF { //nolint:errorlint err = io.ErrUnexpectedEOF } return nil, err @@ -215,7 +215,7 @@ func ReadResponseHeader(r *bufio.Reader, req *http.Request) (resp *http.Response // Parse the response headers. mimeHeader, err := tp.ReadMIMEHeader() if err != nil { - if err == io.EOF { + if err == io.EOF { //nolint:errorlint err = io.ErrUnexpectedEOF } return nil, err diff --git a/libsq/core/ioz/httpz/opts.go b/libsq/core/ioz/httpz/opts.go index 835da4f7e..9f042f8e8 100644 --- a/libsq/core/ioz/httpz/opts.go +++ b/libsq/core/ioz/httpz/opts.go @@ -43,7 +43,7 @@ func (v minTLSVersion) apply(tr *http.Transport) { tr.TLSClientConfig = &tls.Config{MinVersion: uint16(v)} //nolint:gosec } else { tr.TLSClientConfig = tr.TLSClientConfig.Clone() - tr.TLSClientConfig.MinVersion = uint16(v) //nolint:gosec + tr.TLSClientConfig.MinVersion = uint16(v) } } diff --git a/libsq/core/lg/devlog/devlog_test.go b/libsq/core/lg/devlog/devlog_test.go index 0e380e6e0..b47c1a668 100644 --- a/libsq/core/lg/devlog/devlog_test.go +++ b/libsq/core/lg/devlog/devlog_test.go @@ -12,39 +12,15 @@ import ( func TestDevlog(t *testing.T) { log := lgt.New(t) - // log.Debug("huzzah") err := errz.New("oh noes") - // stack := errs.Stacks(err) - // lga.Stack, errz.Stacks(err) log.Error("bah", lga.Err, err) } func TestDevlogTextHandler(t *testing.T) { - o := &slog.HandlerOptions{ - ReplaceAttr: ReplaceAttr, - } + o := &slog.HandlerOptions{} h := slog.NewTextHandler(os.Stdout, o) log := slog.New(h) - // log := lgt.New(t) - // log.Debug("huzzah") err := errz.New("oh noes") - // stack := errs.Stacks(err) - // lga.Stack, errz.Stacks(err) log.Error("bah", lga.Err, err) } - -func ReplaceAttr(groups []string, a slog.Attr) slog.Attr { - switch a.Key { - case "pid": - return slog.Attr{} - case "error": - if _, ok := a.Value.Any().(error); ok { - a.Key = "e" - } - a.Key = "wussah" - return a - default: - return a - } -} From 
0395788b48b0b784c10bd6fa044033dfbec2bb2d Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 7 Jan 2024 22:14:00 -0700 Subject: [PATCH 148/195] download seems to be working --- cli/cmd_add.go | 44 +- cli/options.go | 2 +- cli/run.go | 10 +- libsq/core/ioz/download/httpcache_test.go | 1471 --------------------- libsq/core/lg/devlog/devlog_test.go | 26 - libsq/driver/grips.go | 10 +- libsq/source/cache.go | 7 + libsq/source/detect.go | 10 +- libsq/source/download.go | 12 +- libsq/source/files.go | 150 ++- libsq/source/files_test.go | 29 +- libsq/source/internal_test.go | 11 +- testh/testh.go | 13 +- testh/testh_test.go | 2 +- 14 files changed, 205 insertions(+), 1592 deletions(-) delete mode 100644 libsq/core/ioz/download/httpcache_test.go delete mode 100644 libsq/core/lg/devlog/devlog_test.go diff --git a/cli/cmd_add.go b/cli/cmd_add.go index 6a3416a8f..e5e5e0e68 100644 --- a/cli/cmd_add.go +++ b/cli/cmd_add.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "github.com/neilotoole/sq/libsq/core/stringz" "io" "os" "strings" @@ -19,7 +20,6 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" - "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" ) @@ -186,23 +186,6 @@ func execSrcAdd(cmd *cobra.Command, args []string) error { var err error var typ drivertype.Type - if cmdFlagChanged(cmd, flag.AddDriver) { - val, _ := cmd.Flags().GetString(flag.AddDriver) - typ = drivertype.Type(strings.TrimSpace(val)) - } else { - typ, err = ru.Files.DriverType(ctx, loc) - if err != nil { - return err - } - if typ == drivertype.None { - return errz.Errorf("unable to determine driver type: use --driver flag") - } - } - - if ru.DriverRegistry.ProviderFor(typ) == nil { - return errz.Errorf("unsupported driver type {%s}", typ) - } - var handle string if cmdFlagChanged(cmd, flag.Handle) { handle, 
_ = cmd.Flags().GetString(flag.Handle) @@ -213,18 +196,35 @@ func execSrcAdd(cmd *cobra.Command, args []string) error { } } - if stringz.InSlice(source.ReservedHandles(), handle) { - return errz.Errorf("handle reserved for system use: %s", handle) - } - if err = source.ValidHandle(handle); err != nil { return err } + if stringz.InSlice(source.ReservedHandles(), handle) { + return errz.Errorf("handle reserved for system use: %s", handle) + } + if cfg.Collection.IsExistingSource(handle) { return errz.Errorf("source handle already exists: %s", handle) } + if cmdFlagChanged(cmd, flag.AddDriver) { + val, _ := cmd.Flags().GetString(flag.AddDriver) + typ = drivertype.Type(strings.TrimSpace(val)) + } else { + typ, err = ru.Files.DriverType(ctx, handle, loc) + if err != nil { + return err + } + if typ == drivertype.None { + return errz.Errorf("unable to determine driver type: use --driver flag") + } + } + + if ru.DriverRegistry.ProviderFor(typ) == nil { + return errz.Errorf("unsupported driver type {%s}", typ) + } + if typ == sqlite3.Type { locBefore := loc // Special handling for SQLite, because it's a file-based DB. 
diff --git a/cli/options.go b/cli/options.go index 5cef01174..161780df6 100644 --- a/cli/options.go +++ b/cli/options.go @@ -176,7 +176,7 @@ func RegisterDefaultOpts(reg *options.Registry) { OptDiffNumLines, OptDiffDataFormat, source.OptHTTPPingTimeout, - source.OptHTTPSkipVerify, + source.OptHTTPSInsecureSkipVerify, driver.OptConnMaxOpen, driver.OptConnMaxIdle, driver.OptConnMaxIdleTime, diff --git a/cli/run.go b/cli/run.go index 3554981b5..0dce98c49 100644 --- a/cli/run.go +++ b/cli/run.go @@ -143,8 +143,14 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { var err error if ru.Files == nil { - ru.Files, err = source.NewFiles(ctx, ru.OptionsRegistry, - source.DefaultTempDir(), source.DefaultCacheDir(), true) + ru.Files, err = source.NewFiles( + ctx, + ru.Config.Collection, + ru.OptionsRegistry, + source.DefaultTempDir(), + source.DefaultCacheDir(), + true, + ) if err != nil { lg.WarnIfFuncError(log, lga.Cleanup, ru.Cleanup.Run) return err diff --git a/libsq/core/ioz/download/httpcache_test.go b/libsq/core/ioz/download/httpcache_test.go deleted file mode 100644 index 05f1ba944..000000000 --- a/libsq/core/ioz/download/httpcache_test.go +++ /dev/null @@ -1,1471 +0,0 @@ -package download - -// -//// newTestTransport returns a new Download using the in-memory cache implementation -//func newTestTransport(cacheDir string, opts ...Opt) *Download { -// t := New(cacheDir, opts...) -// return t -//} -// -//var s struct { -// server *httptest.Server -// client http.Client -// transport *Download -// done chan struct{} // Closed to unlock infinite handlers. 
-//} -// -//type fakeClock struct { -// elapsed time.Duration -//} -// -//func (c *fakeClock) since(t time.Time) time.Duration { -// return c.elapsed -//} -// -//func TestMain(m *testing.M) { -// flag.Parse() -// setup() -// code := m.Run() -// teardown() -// os.Exit(code) -//} -// -//func setup() { -// tp := newTestTransport(filepath.Join(os.TempDir(), stringz.Uniq8())) -// client := http.Client{Download: tp} -// s.transport = tp -// s.client = client -// s.done = make(chan struct{}) -// -// mux := http.NewServeMux() -// s.server = httptest.NewServer(mux) -// -// mux.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// w.Header().Set("Cache-Control", "max-age=3600") -// })) -// -// mux.HandleFunc("/method", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// w.Header().Set("Cache-Control", "max-age=3600") -// _, _ = w.Write([]byte(r.Method)) -// })) -// -// mux.HandleFunc("/range", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// lm := "Fri, 14 Dec 2010 01:01:50 GMT" -// if r.Header.Get("if-modified-since") == lm { -// w.WriteHeader(http.StatusNotModified) -// return -// } -// w.Header().Set("last-modified", lm) -// if r.Header.Get("range") == "bytes=4-9" { -// w.WriteHeader(http.StatusPartialContent) -// _, _ = w.Write([]byte(" text ")) -// return -// } -// _, _ = w.Write([]byte("Some text content")) -// })) -// -// mux.HandleFunc("/nostore", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// w.Header().Set("Cache-Control", "no-store") -// })) -// -// mux.HandleFunc("/etag", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// etag := "124567" -// if r.Header.Get("if-none-match") == etag { -// w.WriteHeader(http.StatusNotModified) -// return -// } -// w.Header().Set("etag", etag) -// })) -// -// mux.HandleFunc("/lastmodified", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// lm := "Fri, 14 Dec 2010 01:01:50 GMT" -// if 
r.Header.Get("if-modified-since") == lm { -// w.WriteHeader(http.StatusNotModified) -// return -// } -// w.Header().Set("last-modified", lm) -// })) -// -// mux.HandleFunc("/varyaccept", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// w.Header().Set("Cache-Control", "max-age=3600") -// w.Header().Set("Content-Type", "text/plain") -// w.Header().Set("Vary", "Accept") -// _, _ = w.Write([]byte("Some text content")) -// })) -// -// mux.HandleFunc("/doublevary", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// w.Header().Set("Cache-Control", "max-age=3600") -// w.Header().Set("Content-Type", "text/plain") -// w.Header().Set("Vary", "Accept, Accept-Language") -// _, _ = w.Write([]byte("Some text content")) -// })) -// mux.HandleFunc("/2varyheaders", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// w.Header().Set("Cache-Control", "max-age=3600") -// w.Header().Set("Content-Type", "text/plain") -// w.Header().Add("Vary", "Accept") -// w.Header().Add("Vary", "Accept-Language") -// _, _ = w.Write([]byte("Some text content")) -// })) -// mux.HandleFunc("/varyunused", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// w.Header().Set("Cache-Control", "max-age=3600") -// w.Header().Set("Content-Type", "text/plain") -// w.Header().Set("Vary", "X-Madeup-Header") -// _, _ = w.Write([]byte("Some text content")) -// })) -// -// mux.HandleFunc("/cachederror", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// etag := "abc" -// if r.Header.Get("if-none-match") == etag { -// w.WriteHeader(http.StatusNotModified) -// return -// } -// w.Header().Set("etag", etag) -// w.WriteHeader(http.StatusNotFound) -// _, _ = w.Write([]byte("Not found")) -// })) -// -// updateFieldsCounter := 0 -// mux.HandleFunc("/updatefields", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// w.Header().Set("X-Counter", strconv.Itoa(updateFieldsCounter)) -// w.Header().Set("Etag", `"e"`) -// 
updateFieldsCounter++ -// if r.Header.Get("if-none-match") != "" { -// w.WriteHeader(http.StatusNotModified) -// return -// } -// _, _ = w.Write([]byte("Some text content")) -// })) -// -// // Take 3 seconds to return 200 OK (for testing client timeouts). -// mux.HandleFunc("/3seconds", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// time.Sleep(3 * time.Second) -// })) -// -// mux.HandleFunc("/infinite", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// for { -// select { -// case <-s.done: -// return -// default: -// _, _ = w.Write([]byte{0}) -// } -// } -// })) -//} -// -//func teardown() { -// close(s.done) -// s.server.Close() -//} -// -//func resetTest(t testing.TB) { -// s.transport.Cache = NewRespCache(t.TempDir()) -// //s.transport.Cache.Delete() -// clock = &realClock{} -//} -// -//// TestCacheableMethod ensures that uncacheable method does not get stored -//// in cache and get incorrectly used for a following cacheable method request. -//func TestCacheableMethod(t *testing.T) { -// resetTest(t) -// { -// req, err := http.NewRequest("POST", s.server.URL+"/method", nil) -// if err != nil { -// t.Fatal(err) -// } -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// var buf bytes.Buffer -// _, err = io.Copy(&buf, resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// err = resp.Body.Close() -// if err != nil { -// t.Fatal(err) -// } -// if got, want := buf.String(), "POST"; got != want { -// t.Errorf("got %q, want %q", got, want) -// } -// if resp.StatusCode != http.StatusOK { -// t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) -// } -// } -// { -// req, err := http.NewRequest("GET", s.server.URL+"/method", nil) -// if err != nil { -// t.Fatal(err) -// } -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// var buf bytes.Buffer -// _, err = io.Copy(&buf, resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// err = resp.Body.Close() -// if err 
!= nil { -// t.Fatal(err) -// } -// if got, want := buf.String(), "GET"; got != want { -// t.Errorf("got wrong body %q, want %q", got, want) -// } -// if resp.StatusCode != http.StatusOK { -// t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) -// } -// if resp.Header.Get(XFromCache) != "" { -// t.Errorf("XFromCache header isn't blank") -// } -// } -//} -// -//func TestDontServeHeadResponseToGetRequest(t *testing.T) { -// resetTest(t) -// url := s.server.URL + "/" -// req, err := http.NewRequest(http.MethodHead, url, nil) -// if err != nil { -// t.Fatal(err) -// } -// _, err = s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// req, err = http.NewRequest(http.MethodGet, url, nil) -// if err != nil { -// t.Fatal(err) -// } -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// if resp.Header.Get(XFromCache) != "" { -// t.Errorf("Cache should not match") -// } -//} -// -//func TestDontStorePartialRangeInCache(t *testing.T) { -// resetTest(t) -// { -// req, err := http.NewRequest("GET", s.server.URL+"/range", nil) -// if err != nil { -// t.Fatal(err) -// } -// req.Header.Set("range", "bytes=4-9") -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// var buf bytes.Buffer -// _, err = io.Copy(&buf, resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// err = resp.Body.Close() -// if err != nil { -// t.Fatal(err) -// } -// if got, want := buf.String(), " text "; got != want { -// t.Errorf("got %q, want %q", got, want) -// } -// if resp.StatusCode != http.StatusPartialContent { -// t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) -// } -// } -// { -// req, err := http.NewRequest("GET", s.server.URL+"/range", nil) -// if err != nil { -// t.Fatal(err) -// } -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// var buf bytes.Buffer -// _, err = io.Copy(&buf, resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// err = 
resp.Body.Close() -// if err != nil { -// t.Fatal(err) -// } -// if got, want := buf.String(), "Some text content"; got != want { -// t.Errorf("got %q, want %q", got, want) -// } -// if resp.StatusCode != http.StatusOK { -// t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) -// } -// if resp.Header.Get(XFromCache) != "" { -// t.Error("XFromCache header isn't blank") -// } -// } -// { -// req, err := http.NewRequest("GET", s.server.URL+"/range", nil) -// if err != nil { -// t.Fatal(err) -// } -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// var buf bytes.Buffer -// _, err = io.Copy(&buf, resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// err = resp.Body.Close() -// if err != nil { -// t.Fatal(err) -// } -// if got, want := buf.String(), "Some text content"; got != want { -// t.Errorf("got %q, want %q", got, want) -// } -// if resp.StatusCode != http.StatusOK { -// t.Errorf("response status code isn't 200 OK: %v", resp.StatusCode) -// } -// if resp.Header.Get(XFromCache) != "1" { -// t.Errorf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) -// } -// } -// { -// req, err := http.NewRequest("GET", s.server.URL+"/range", nil) -// if err != nil { -// t.Fatal(err) -// } -// req.Header.Set("range", "bytes=4-9") -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// var buf bytes.Buffer -// _, err = io.Copy(&buf, resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// err = resp.Body.Close() -// if err != nil { -// t.Fatal(err) -// } -// if got, want := buf.String(), " text "; got != want { -// t.Errorf("got %q, want %q", got, want) -// } -// if resp.StatusCode != http.StatusPartialContent { -// t.Errorf("response status code isn't 206 Partial Content: %v", resp.StatusCode) -// } -// } -//} -// -//func TestCacheOnlyIfBodyRead(t *testing.T) { -// resetTest(t) -// { -// req, err := http.NewRequest("GET", s.server.URL, nil) -// if err != nil { -// t.Fatal(err) -// } -// resp, 
err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") -// } -// // We do not read the body -// resp.Body.Close() -// } -// { -// req, err := http.NewRequest("GET", s.server.URL, nil) -// if err != nil { -// t.Fatal(err) -// } -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatalf("XFromCache header isn't blank") -// } -// } -//} -// -//func TestOnlyReadBodyOnDemand(t *testing.T) { -// resetTest(t) -// -// req, err := http.NewRequest("GET", s.server.URL+"/infinite", nil) -// if err != nil { -// t.Fatal(err) -// } -// resp, err := s.client.Do(req) // This shouldn't hang forever. -// if err != nil { -// t.Fatal(err) -// } -// buf := make([]byte, 10) // Only partially read the body. -// _, err = resp.Body.Read(buf) -// if err != nil { -// t.Fatal(err) -// } -// resp.Body.Close() -//} -// -//func TestGetOnlyIfCachedHit(t *testing.T) { -// resetTest(t) -// { -// req, err := http.NewRequest("GET", s.server.URL, nil) -// if err != nil { -// t.Fatal(err) -// } -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") -// } -// _, err = ioutil.ReadAll(resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// } -// { -// req, err := http.NewRequest("GET", s.server.URL, nil) -// if err != nil { -// t.Fatal(err) -// } -// req.Header.Add("cache-control", "only-if-cached") -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "1" { -// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) -// } -// if resp.StatusCode != http.StatusOK { -// t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) -// } -// } -//} -// -//func 
TestGetOnlyIfCachedMiss(t *testing.T) { -// resetTest(t) -// req, err := http.NewRequest("GET", s.server.URL, nil) -// if err != nil { -// t.Fatal(err) -// } -// req.Header.Add("cache-control", "only-if-cached") -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") -// } -// if resp.StatusCode != http.StatusGatewayTimeout { -// t.Fatalf("response status code isn't 504 GatewayTimeout: %v", resp.StatusCode) -// } -//} -// -//func TestGetNoStoreRequest(t *testing.T) { -// resetTest(t) -// req, err := http.NewRequest("GET", s.server.URL, nil) -// if err != nil { -// t.Fatal(err) -// } -// req.Header.Add("Cache-Control", "no-store") -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") -// } -// } -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") -// } -// } -//} -// -//func TestGetNoStoreResponse(t *testing.T) { -// resetTest(t) -// req, err := http.NewRequest("GET", s.server.URL+"/nostore", nil) -// if err != nil { -// t.Fatal(err) -// } -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") -// } -// } -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") -// } -// } -//} -// -//func TestGetWithEtag(t *testing.T) { -// resetTest(t) -// req, err := http.NewRequest("GET", s.server.URL+"/etag", nil) -// if err != nil { -// t.Fatal(err) -// } -// { -// resp, err 
:= s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") -// } -// _, err = ioutil.ReadAll(resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// -// } -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "1" { -// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) -// } -// // additional assertions to verify that 304 response is converted properly -// if resp.StatusCode != http.StatusOK { -// t.Fatalf("response status code isn't 200 OK: %v", resp.StatusCode) -// } -// if _, ok := resp.Header["Connection"]; ok { -// t.Fatalf("Connection header isn't absent") -// } -// } -//} -// -//func TestGetWithLastModified(t *testing.T) { -// resetTest(t) -// req, err := http.NewRequest("GET", s.server.URL+"/lastmodified", nil) -// if err != nil { -// t.Fatal(err) -// } -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") -// } -// _, err = ioutil.ReadAll(resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// } -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "1" { -// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) -// } -// } -//} -// -//func TestGetWithVary(t *testing.T) { -// resetTest(t) -// req, err := http.NewRequest("GET", s.server.URL+"/varyaccept", nil) -// if err != nil { -// t.Fatal(err) -// } -// req.Header.Set("Accept", "text/plain") -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get("Vary") != "Accept" { -// t.Fatalf(`Vary header isn't "Accept": %v`, resp.Header.Get("Vary")) 
-// } -// _, err = ioutil.ReadAll(resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// } -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "1" { -// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) -// } -// } -// req.Header.Set("Accept", "text/html") -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") -// } -// } -// req.Header.Set("Accept", "") -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") -// } -// } -//} -// -//func TestGetWithDoubleVary(t *testing.T) { -// resetTest(t) -// req, err := http.NewRequest("GET", s.server.URL+"/doublevary", nil) -// if err != nil { -// t.Fatal(err) -// } -// req.Header.Set("Accept", "text/plain") -// req.Header.Set("Accept-Language", "da, en-gb;q=0.8, en;q=0.7") -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get("Vary") == "" { -// t.Fatalf(`Vary header is blank`) -// } -// _, err = ioutil.ReadAll(resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// } -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "1" { -// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) -// } -// } -// req.Header.Set("Accept-Language", "") -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") -// } -// } -// req.Header.Set("Accept-Language", "da") -// { -// resp, err := s.client.Do(req) -// 
if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") -// } -// } -//} -// -//func TestGetWith2VaryHeaders(t *testing.T) { -// resetTest(t) -// // Tests that multiple Vary headers' comma-separated lists are -// // merged. See https://github.com/gregjones/httpcache/issues/27. -// const ( -// accept = "text/plain" -// acceptLanguage = "da, en-gb;q=0.8, en;q=0.7" -// ) -// req, err := http.NewRequest("GET", s.server.URL+"/2varyheaders", nil) -// if err != nil { -// t.Fatal(err) -// } -// req.Header.Set("Accept", accept) -// req.Header.Set("Accept-Language", acceptLanguage) -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get("Vary") == "" { -// t.Fatalf(`Vary header is blank`) -// } -// _, err = ioutil.ReadAll(resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// } -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "1" { -// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) -// } -// } -// req.Header.Set("Accept-Language", "") -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") -// } -// } -// req.Header.Set("Accept-Language", "da") -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") -// } -// } -// req.Header.Set("Accept-Language", acceptLanguage) -// req.Header.Set("Accept", "") -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") 
-// } -// } -// req.Header.Set("Accept", "image/png") -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "" { -// t.Fatal("XFromCache header isn't blank") -// } -// _, err = ioutil.ReadAll(resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// } -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "1" { -// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) -// } -// } -//} -// -//func TestGetVaryUnused(t *testing.T) { -// resetTest(t) -// req, err := http.NewRequest("GET", s.server.URL+"/varyunused", nil) -// if err != nil { -// t.Fatal(err) -// } -// req.Header.Set("Accept", "text/plain") -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get("Vary") == "" { -// t.Fatalf(`Vary header is blank`) -// } -// _, err = ioutil.ReadAll(resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// } -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "1" { -// t.Fatalf(`XFromCache header isn't "1": %v`, resp.Header.Get(XFromCache)) -// } -// } -//} -// -//func TestUpdateFields(t *testing.T) { -// resetTest(t) -// req, err := http.NewRequest("GET", s.server.URL+"/updatefields", nil) -// if err != nil { -// t.Fatal(err) -// } -// var counter, counter2 string -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// counter = resp.Header.Get("x-counter") -// _, err = ioutil.ReadAll(resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// } -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.Header.Get(XFromCache) != "1" { -// t.Fatalf(`XFromCache header isn't "1": 
%v`, resp.Header.Get(XFromCache)) -// } -// counter2 = resp.Header.Get("x-counter") -// } -// if counter == counter2 { -// t.Fatalf(`both "x-counter" values are equal: %v %v`, counter, counter2) -// } -//} -// -//// This tests the fix for https://github.com/gregjones/httpcache/issues/74. -//// Previously, after validating a cached response, its StatusCode -//// was incorrectly being replaced. -//func TestCachedErrorsKeepStatus(t *testing.T) { -// resetTest(t) -// req, err := http.NewRequest("GET", s.server.URL+"/cachederror", nil) -// if err != nil { -// t.Fatal(err) -// } -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// _, _ = io.Copy(ioutil.Discard, resp.Body) -// } -// { -// resp, err := s.client.Do(req) -// if err != nil { -// t.Fatal(err) -// } -// defer resp.Body.Close() -// if resp.StatusCode != http.StatusNotFound { -// t.Fatalf("Status code isn't 404: %d", resp.StatusCode) -// } -// } -//} -// -//func TestParseCacheControl(t *testing.T) { -// resetTest(t) -// h := http.Header{} -// for range parseCacheControl(h) { -// t.Fatal("cacheControl should be empty") -// } -// -// h.Set("cache-control", "no-cache") -// { -// cc := parseCacheControl(h) -// if _, ok := cc["foo"]; ok { -// t.Error(`Value "foo" shouldn't exist`) -// } -// noCache, ok := cc["no-cache"] -// if !ok { -// t.Fatalf(`"no-cache" value isn't set`) -// } -// if noCache != "" { -// t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) -// } -// } -// h.Set("cache-control", "no-cache, max-age=3600") -// { -// cc := parseCacheControl(h) -// noCache, ok := cc["no-cache"] -// if !ok { -// t.Fatalf(`"no-cache" value isn't set`) -// } -// if noCache != "" { -// t.Fatalf(`"no-cache" value isn't blank: %v`, noCache) -// } -// if cc["max-age"] != "3600" { -// t.Fatalf(`"max-age" value isn't "3600": %v`, cc["max-age"]) -// } -// } -//} -// -//func TestNoCacheRequestExpiration(t *testing.T) { -// resetTest(t) -// respHeaders := 
http.Header{} -// respHeaders.Set("Cache-Control", "max-age=7200") -// -// reqHeaders := http.Header{} -// reqHeaders.Set("Cache-Control", "no-cache") -// if getFreshness(respHeaders, reqHeaders) != transparent { -// t.Fatal("freshness isn't transparent") -// } -//} -// -//func TestNoCacheResponseExpiration(t *testing.T) { -// resetTest(t) -// respHeaders := http.Header{} -// respHeaders.Set("Cache-Control", "no-cache") -// respHeaders.Set("Expires", "Wed, 19 Apr 3000 11:43:00 GMT") -// -// reqHeaders := http.Header{} -// if getFreshness(respHeaders, reqHeaders) != stale { -// t.Fatal("freshness isn't stale") -// } -//} -// -//func TestReqMustRevalidate(t *testing.T) { -// resetTest(t) -// // not paying attention to request setting max-stale means never returning stale -// // responses, so always acting as if must-revalidate is set -// respHeaders := http.Header{} -// -// reqHeaders := http.Header{} -// reqHeaders.Set("Cache-Control", "must-revalidate") -// if getFreshness(respHeaders, reqHeaders) != stale { -// t.Fatal("freshness isn't stale") -// } -//} -// -//func TestRespMustRevalidate(t *testing.T) { -// resetTest(t) -// respHeaders := http.Header{} -// respHeaders.Set("Cache-Control", "must-revalidate") -// -// reqHeaders := http.Header{} -// if getFreshness(respHeaders, reqHeaders) != stale { -// t.Fatal("freshness isn't stale") -// } -//} -// -//func TestFreshExpiration(t *testing.T) { -// resetTest(t) -// now := time.Now() -// respHeaders := http.Header{} -// respHeaders.Set("date", now.Format(time.RFC1123)) -// respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) -// -// reqHeaders := http.Header{} -// if getFreshness(respHeaders, reqHeaders) != fresh { -// t.Fatal("freshness isn't fresh") -// } -// -// clock = &fakeClock{elapsed: 3 * time.Second} -// if getFreshness(respHeaders, reqHeaders) != stale { -// t.Fatal("freshness isn't stale") -// } -//} -// -//func TestMaxAge(t *testing.T) { -// resetTest(t) -// now := 
time.Now() -// respHeaders := http.Header{} -// respHeaders.Set("date", now.Format(time.RFC1123)) -// respHeaders.Set("cache-control", "max-age=2") -// -// reqHeaders := http.Header{} -// if getFreshness(respHeaders, reqHeaders) != fresh { -// t.Fatal("freshness isn't fresh") -// } -// -// clock = &fakeClock{elapsed: 3 * time.Second} -// if getFreshness(respHeaders, reqHeaders) != stale { -// t.Fatal("freshness isn't stale") -// } -//} -// -//func TestMaxAgeZero(t *testing.T) { -// resetTest(t) -// now := time.Now() -// respHeaders := http.Header{} -// respHeaders.Set("date", now.Format(time.RFC1123)) -// respHeaders.Set("cache-control", "max-age=0") -// -// reqHeaders := http.Header{} -// if getFreshness(respHeaders, reqHeaders) != stale { -// t.Fatal("freshness isn't stale") -// } -//} -// -//func TestBothMaxAge(t *testing.T) { -// resetTest(t) -// now := time.Now() -// respHeaders := http.Header{} -// respHeaders.Set("date", now.Format(time.RFC1123)) -// respHeaders.Set("cache-control", "max-age=2") -// -// reqHeaders := http.Header{} -// reqHeaders.Set("cache-control", "max-age=0") -// if getFreshness(respHeaders, reqHeaders) != stale { -// t.Fatal("freshness isn't stale") -// } -//} -// -//func TestMinFreshWithExpires(t *testing.T) { -// resetTest(t) -// now := time.Now() -// respHeaders := http.Header{} -// respHeaders.Set("date", now.Format(time.RFC1123)) -// respHeaders.Set("expires", now.Add(time.Duration(2)*time.Second).Format(time.RFC1123)) -// -// reqHeaders := http.Header{} -// reqHeaders.Set("cache-control", "min-fresh=1") -// if getFreshness(respHeaders, reqHeaders) != fresh { -// t.Fatal("freshness isn't fresh") -// } -// -// reqHeaders = http.Header{} -// reqHeaders.Set("cache-control", "min-fresh=2") -// if getFreshness(respHeaders, reqHeaders) != stale { -// t.Fatal("freshness isn't stale") -// } -//} -// -//func TestEmptyMaxStale(t *testing.T) { -// resetTest(t) -// now := time.Now() -// respHeaders := http.Header{} -// respHeaders.Set("date", 
now.Format(time.RFC1123)) -// respHeaders.Set("cache-control", "max-age=20") -// -// reqHeaders := http.Header{} -// reqHeaders.Set("cache-control", "max-stale") -// clock = &fakeClock{elapsed: 10 * time.Second} -// if getFreshness(respHeaders, reqHeaders) != fresh { -// t.Fatal("freshness isn't fresh") -// } -// -// clock = &fakeClock{elapsed: 60 * time.Second} -// if getFreshness(respHeaders, reqHeaders) != fresh { -// t.Fatal("freshness isn't fresh") -// } -//} -// -//func TestMaxStaleValue(t *testing.T) { -// resetTest(t) -// now := time.Now() -// respHeaders := http.Header{} -// respHeaders.Set("date", now.Format(time.RFC1123)) -// respHeaders.Set("cache-control", "max-age=10") -// -// reqHeaders := http.Header{} -// reqHeaders.Set("cache-control", "max-stale=20") -// clock = &fakeClock{elapsed: 5 * time.Second} -// if getFreshness(respHeaders, reqHeaders) != fresh { -// t.Fatal("freshness isn't fresh") -// } -// -// clock = &fakeClock{elapsed: 15 * time.Second} -// if getFreshness(respHeaders, reqHeaders) != fresh { -// t.Fatal("freshness isn't fresh") -// } -// -// clock = &fakeClock{elapsed: 30 * time.Second} -// if getFreshness(respHeaders, reqHeaders) != stale { -// t.Fatal("freshness isn't stale") -// } -//} -// -//func containsHeader(headers []string, header string) bool { -// for _, v := range headers { -// if http.CanonicalHeaderKey(v) == http.CanonicalHeaderKey(header) { -// return true -// } -// } -// return false -//} -// -//func TestGetEndToEndHeaders(t *testing.T) { -// resetTest(t) -// var ( -// headers http.Header -// end2end []string -// ) -// -// headers = http.Header{} -// headers.Set("content-type", "text/html") -// headers.Set("te", "deflate") -// -// end2end = getEndToEndHeaders(headers) -// if !containsHeader(end2end, "content-type") { -// t.Fatal(`doesn't contain "content-type" header`) -// } -// if containsHeader(end2end, "te") { -// t.Fatal(`doesn't contain "te" header`) -// } -// -// headers = http.Header{} -// 
headers.Set("connection", "content-type") -// headers.Set("content-type", "text/csv") -// headers.Set("te", "deflate") -// end2end = getEndToEndHeaders(headers) -// if containsHeader(end2end, "connection") { -// t.Fatal(`doesn't contain "connection" header`) -// } -// if containsHeader(end2end, "content-type") { -// t.Fatal(`doesn't contain "content-type" header`) -// } -// if containsHeader(end2end, "te") { -// t.Fatal(`doesn't contain "te" header`) -// } -// -// headers = http.Header{} -// end2end = getEndToEndHeaders(headers) -// if len(end2end) != 0 { -// t.Fatal(`non-zero end2end headers`) -// } -// -// headers = http.Header{} -// headers.Set("connection", "content-type") -// end2end = getEndToEndHeaders(headers) -// if len(end2end) != 0 { -// t.Fatal(`non-zero end2end headers`) -// } -//} -// -//type transportMock struct { -// response *http.Response -// err error -//} -// -//func (t transportMock) RoundTrip(req *http.Request) (resp *http.Response, err error) { -// return t.response, t.err -//} - -// -//func TestStaleIfErrorRequest(t *testing.T) { -// resetTest(t) -// now := time.Now() -// tmock := transportMock{ -// response: &http.Response{ -// Status: http.StatusText(http.StatusOK), -// StatusCode: http.StatusOK, -// Header: http.Header{ -// "Date": []string{now.Format(time.RFC1123)}, -// "Cache-Control": []string{"no-cache"}, -// }, -// Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), -// }, -// err: nil, -// } -// tp := newTestTransport(t.TempDir()) -// tp.Download = &tmock -// -// // First time, response is cached on success -// r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) -// r.Header.Set("Cache-Control", "stale-if-error") -// resp, err := tp.RoundTrip(r) -// if err != nil { -// t.Fatal(err) -// } -// if resp == nil { -// t.Fatal("resp is nil") -// } -// _, err = ioutil.ReadAll(resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// -// // On failure, response is returned from the cache -// tmock.response = nil -// 
tmock.err = errors.New("some error") -// resp, err = tp.RoundTrip(r) -// if err != nil { -// t.Fatal(err) -// } -// if resp == nil { -// t.Fatal("resp is nil") -// } -//} -// -//func TestStaleIfErrorRequestLifetime(t *testing.T) { -// resetTest(t) -// now := time.Now() -// tmock := transportMock{ -// response: &http.Response{ -// Status: http.StatusText(http.StatusOK), -// StatusCode: http.StatusOK, -// Header: http.Header{ -// "Date": []string{now.Format(time.RFC1123)}, -// "Cache-Control": []string{"no-cache"}, -// }, -// Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), -// }, -// err: nil, -// } -// tp := newTestTransport(t.TempDir()) -// tp.Download = &tmock -// -// // First time, response is cached on success -// r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) -// r.Header.Set("Cache-Control", "stale-if-error=100") -// resp, err := tp.RoundTrip(r) -// if err != nil { -// t.Fatal(err) -// } -// if resp == nil { -// t.Fatal("resp is nil") -// } -// _, err = ioutil.ReadAll(resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// -// // On failure, response is returned from the cache -// tmock.response = nil -// tmock.err = errors.New("some error") -// resp, err = tp.RoundTrip(r) -// if err != nil { -// t.Fatal(err) -// } -// if resp == nil { -// t.Fatal("resp is nil") -// } -// -// // Same for http errors -// tmock.response = &http.Response{StatusCode: http.StatusInternalServerError} -// tmock.err = nil -// resp, err = tp.RoundTrip(r) -// if err != nil { -// t.Fatal(err) -// } -// if resp == nil { -// t.Fatal("resp is nil") -// } -// -// // If failure last more than max stale, error is returned -// clock = &fakeClock{elapsed: 200 * time.Second} -// _, err = tp.RoundTrip(r) -// if err != tmock.err { -// t.Fatalf("got err %v, want %v", err, tmock.err) -// } -//} -// -//func TestStaleIfErrorResponse(t *testing.T) { -// resetTest(t) -// now := time.Now() -// tmock := transportMock{ -// response: &http.Response{ -// Status: 
http.StatusText(http.StatusOK), -// StatusCode: http.StatusOK, -// Header: http.Header{ -// "Date": []string{now.Format(time.RFC1123)}, -// "Cache-Control": []string{"no-cache, stale-if-error"}, -// }, -// Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), -// }, -// err: nil, -// } -// tp := newTestTransport(t.TempDir()) -// tp.Download = &tmock -// -// // First time, response is cached on success -// r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) -// resp, err := tp.RoundTrip(r) -// if err != nil { -// t.Fatal(err) -// } -// if resp == nil { -// t.Fatal("resp is nil") -// } -// _, err = ioutil.ReadAll(resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// -// // On failure, response is returned from the cache -// tmock.response = nil -// tmock.err = errors.New("some error") -// resp, err = tp.RoundTrip(r) -// if err != nil { -// t.Fatal(err) -// } -// if resp == nil { -// t.Fatal("resp is nil") -// } -//} -// -//func TestStaleIfErrorResponseLifetime(t *testing.T) { -// resetTest(t) -// now := time.Now() -// tmock := transportMock{ -// response: &http.Response{ -// Status: http.StatusText(http.StatusOK), -// StatusCode: http.StatusOK, -// Header: http.Header{ -// "Date": []string{now.Format(time.RFC1123)}, -// "Cache-Control": []string{"no-cache, stale-if-error=100"}, -// }, -// Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), -// }, -// err: nil, -// } -// tp := newTestTransport(t.TempDir()) -// tp.Download = &tmock -// -// // First time, response is cached on success -// r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) -// resp, err := tp.RoundTrip(r) -// if err != nil { -// t.Fatal(err) -// } -// if resp == nil { -// t.Fatal("resp is nil") -// } -// _, err = ioutil.ReadAll(resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// -// // On failure, response is returned from the cache -// tmock.response = nil -// tmock.err = errors.New("some error") -// resp, err = tp.RoundTrip(r) -// if err != nil { -// 
t.Fatal(err) -// } -// if resp == nil { -// t.Fatal("resp is nil") -// } -// -// // If failure last more than max stale, error is returned -// clock = &fakeClock{elapsed: 200 * time.Second} -// _, err = tp.RoundTrip(r) -// if err != tmock.err { -// t.Fatalf("got err %v, want %v", err, tmock.err) -// } -//} -// -//// This tests the fix for https://github.com/gregjones/httpcache/issues/74. -//// Previously, after a stale response was used after encountering an error, -//// its StatusCode was being incorrectly replaced. -//func TestStaleIfErrorKeepsStatus(t *testing.T) { -// resetTest(t) -// now := time.Now() -// tmock := transportMock{ -// response: &http.Response{ -// Status: http.StatusText(http.StatusNotFound), -// StatusCode: http.StatusNotFound, -// Header: http.Header{ -// "Date": []string{now.Format(time.RFC1123)}, -// "Cache-Control": []string{"no-cache"}, -// }, -// Body: ioutil.NopCloser(bytes.NewBuffer([]byte("some data"))), -// }, -// err: nil, -// } -// tp := newTestTransport(t.TempDir()) -// tp.Download = &tmock -// -// // First time, response is cached on success -// r, _ := http.NewRequest("GET", "http://somewhere.com/", nil) -// r.Header.Set("Cache-Control", "stale-if-error") -// resp, err := tp.RoundTrip(r) -// if err != nil { -// t.Fatal(err) -// } -// if resp == nil { -// t.Fatal("resp is nil") -// } -// _, err = ioutil.ReadAll(resp.Body) -// if err != nil { -// t.Fatal(err) -// } -// -// // On failure, response is returned from the cache -// tmock.response = nil -// tmock.err = errors.New("some error") -// resp, err = tp.RoundTrip(r) -// if err != nil { -// t.Fatal(err) -// } -// if resp == nil { -// t.Fatal("resp is nil") -// } -// if resp.StatusCode != http.StatusNotFound { -// t.Fatalf("Status wasn't 404: %d", resp.StatusCode) -// } -//} -// -//// Test that http.Client.Timeout is respected when cache transport is used. -//// That is so as long as request cancellation is propagated correctly. 
-//// In the past, that required CancelRequest to be implemented correctly, -//// but modern http.Client uses Request.Cancel (or request context) instead, -//// so we don't have to do anything. -//func TestClientTimeout(t *testing.T) { -// if testing.Short() { -// t.Skip("skipping timeout test in short mode") // Because it takes at least 3 seconds to run. -// } -// resetTest(t) -// -// client := &http.Client{ -// Download: newTestTransport(t.TempDir()), -// Timeout: time.Second, -// } -// started := time.Now() -// resp, err := client.Get(s.server.URL + "/3seconds") -// taken := time.Since(started) -// if err == nil { -// t.Error("got nil error, want timeout error") -// } -// if resp != nil { -// t.Error("got non-nil resp, want nil resp") -// } -// if taken >= 2*time.Second { -// t.Error("client.Do took 2+ seconds, want < 2 seconds") -// } -//} diff --git a/libsq/core/lg/devlog/devlog_test.go b/libsq/core/lg/devlog/devlog_test.go deleted file mode 100644 index b47c1a668..000000000 --- a/libsq/core/lg/devlog/devlog_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package devlog_test - -import ( - "log/slog" - "os" - "testing" - - "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/lg/lga" - "github.com/neilotoole/sq/libsq/core/lg/lgt" -) - -func TestDevlog(t *testing.T) { - log := lgt.New(t) - err := errz.New("oh noes") - log.Error("bah", lga.Err, err) -} - -func TestDevlogTextHandler(t *testing.T) { - o := &slog.HandlerOptions{} - - h := slog.NewTextHandler(os.Stdout, o) - log := slog.New(h) - err := errz.New("oh noes") - log.Error("bah", lga.Err, err) -} diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index f0d8f602f..3d24d7956 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -259,11 +259,6 @@ func (gs *Grips) openIngestCache(ctx context.Context, src *source.Source, log.Debug("Using cache dir", lga.Path, cacheDir) - ingestFilePath, err := gs.files.Filepath(ctx, src) - if err != nil { - return nil, err - } - 
var impl Grip var foundCached bool if impl, foundCached, err = gs.openCachedFor(ctx, src); err != nil { @@ -298,6 +293,11 @@ func (gs *Grips) openIngestCache(ctx context.Context, src *source.Source, log.Info("Ingest completed", lga.Src, src, lga.Dest, impl.Source(), lga.Elapsed, elapsed) + ingestFilePath, err := gs.files.Filepath(ctx, src) + if err != nil { + return nil, err + } + // Write the checksums file. var sum checksum.Checksum if sum, err = checksum.ForFile(ingestFilePath); err != nil { diff --git a/libsq/source/cache.go b/libsq/source/cache.go index 2be8254f5..c4245ca1c 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -53,6 +53,13 @@ func (fs *Files) CacheDirFor(src *Source) (dir string, err error) { return dir, nil } +// downloadCacheDirFor gets the download cache dir for loc. It is not guaranteed +// that the returned dir exists or is accessible. +func (fs *Files) downloadCacheDirFor(loc string) (dir string) { + fp := filepath.Join(fs.cacheDir, "downloads", checksum.Sum([]byte(loc))) + return fp +} + // sourceHash generates a hash for src. The hash is based on the // member fields of src, with special handling for src.Options. // Only the opts that affect data ingestion (options.TagIngestMutate) diff --git a/libsq/source/detect.go b/libsq/source/detect.go index b2fd53f9a..1d8da488f 100644 --- a/libsq/source/detect.go +++ b/libsq/source/detect.go @@ -39,7 +39,7 @@ func (fs *Files) AddDriverDetectors(detectFns ...DriverDetectFunc) { // DriverType returns the driver type of loc. // This may result in loading files into the cache. 
-func (fs *Files) DriverType(ctx context.Context, loc string) (drivertype.Type, error) { +func (fs *Files) DriverType(ctx context.Context, handle string, loc string) (drivertype.Type, error) { log := lg.FromContext(ctx).With(lga.Loc, loc) ploc, err := parseLoc(loc) if err != nil { @@ -67,7 +67,7 @@ func (fs *Files) DriverType(ctx context.Context, loc string) (drivertype.Type, e fs.mu.Lock() defer fs.mu.Unlock() // Fall back to the byte detectors - typ, ok, err := fs.detectType(ctx, loc) + typ, ok, err := fs.detectType(ctx, handle, loc) if err != nil { return drivertype.None, err } @@ -79,7 +79,7 @@ func (fs *Files) DriverType(ctx context.Context, loc string) (drivertype.Type, e return typ, nil } -func (fs *Files) detectType(ctx context.Context, loc string) (typ drivertype.Type, ok bool, err error) { +func (fs *Files) detectType(ctx context.Context, handle string, loc string) (typ drivertype.Type, ok bool, err error) { if len(fs.detectFns) == 0 { return drivertype.None, false, nil } @@ -87,7 +87,7 @@ func (fs *Files) detectType(ctx context.Context, loc string) (typ drivertype.Typ start := time.Now() openFn := func(ctx context.Context) (io.ReadCloser, error) { - return fs.newReader(ctx, loc) + return fs.newReader(ctx, handle, loc) } // We do the magic number first, because it's so fast. 
@@ -220,7 +220,7 @@ func (fs *Files) DetectStdinType(ctx context.Context) (drivertype.Type, error) { return drivertype.None, errz.New("must invoke Files.AddStdin before invoking DetectStdinType") } - typ, ok, err := fs.detectType(ctx, StdinHandle) + typ, ok, err := fs.detectType(ctx, "", StdinHandle) if err != nil { return drivertype.None, err } diff --git a/libsq/source/download.go b/libsq/source/download.go index c99e6cfa0..fe89f1502 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -28,7 +28,8 @@ import ( ) var OptHTTPPingTimeout = options.NewDuration( - "http.ping.timeout", + // FIXME: apply OptHTTPPingTimeout to httpz.NewClient invocations + "https.ping.timeout", "", 0, time.Second*10, @@ -39,8 +40,9 @@ not affected by this option. Example: 500ms or 3s.`, options.TagSource, ) -var OptHTTPSkipVerify = options.NewBool( - "http.skip-verify", +var OptHTTPSInsecureSkipVerify = options.NewBool( + // FIXME: apply OptHTTPSkipVerify to httpz.NewClient invocations + "https.skip-verify", "", false, 0, @@ -379,10 +381,6 @@ func fetchHTTPResponse(ctx context.Context, c *http.Client, u string) (resp *htt return resp, nil } -func getRemoteChecksum(ctx context.Context, u string) (string, error) { - return "", errz.New("not implemented") -} - // fetch ensures that loc exists locally as a file. This may // entail downloading the file via HTTPS etc. 
func (fs *Files) fetch(ctx context.Context, loc string) (fpath string, err error) { diff --git a/libsq/source/files.go b/libsq/source/files.go index 96897dd05..bd60d51fa 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -2,6 +2,8 @@ package source import ( "context" + "github.com/neilotoole/sq/libsq/core/ioz/download" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" "io" "log/slog" "net/url" @@ -47,6 +49,7 @@ type Files struct { tempDir string clnup *cleanup.Cleanup optRegistry *options.Registry + coll *Collection // fscache is used to cache files, providing convenient access // to multiple readers via Files.newReader. @@ -64,9 +67,7 @@ type Files struct { // NewFiles returns a new Files instance. If cleanFscache is true, the fscache // is cleaned on Files.Close. -func NewFiles(ctx context.Context, optReg *options.Registry, - tmpDir, cacheDir string, cleanFscache bool, -) (*Files, error) { +func NewFiles(ctx context.Context, coll *Collection, optReg *options.Registry, tmpDir, cacheDir string, cleanFscache bool) (*Files, error) { log := lg.FromContext(ctx) log.Debug("Creating new Files instance", "tmp_dir", tmpDir, "cache_dir", cacheDir) if tmpDir == "" { @@ -81,6 +82,7 @@ func NewFiles(ctx context.Context, optReg *options.Registry, } fs := &Files{ + coll: coll, optRegistry: optReg, cacheDir: cacheDir, fscacheEntryMetas: make(map[string]*fscacheEntryMeta), @@ -263,7 +265,7 @@ func (fs *Files) addRegularFile(ctx context.Context, f *os.File, key string) (fs log.Debug("Adding regular file", lga.Key, key, lga.Path, f.Name()) if strings.Contains(f.Name(), "cached.db") { - log.Error("oh no, shouldn't be happening") + log.Error("oh no, shouldn't be happening") // FIXME: delete this } defer lg.WarnIfCloseError(log, lgm.CloseFileReader, f) @@ -286,6 +288,67 @@ func (fs *Files) addRegularFile(ctx context.Context, f *os.File, key string) (fs return r, errz.Err(err) } +func (fs *Files) addRemoteFile(ctx context.Context, handle, loc string) (io.ReadCloser, 
error) { + dlDir := fs.downloadCacheDirFor(loc) + if err := ioz.RequireDir(dlDir); err != nil { + return nil, err + } + + errCh := make(chan error, 1) + rdrCh := make(chan io.ReadCloser, 1) + + h := download.Handler{ + Cached: func(fp string) { + if err := fs.fscache.MapFile(fp); err != nil { + errCh <- errz.Wrapf(err, "failed to map file into fscache: %s", fp) + return + } + + r, _, err := fs.fscache.Get(fp) + if err != nil { + errCh <- errz.Err(err) + return + } + rdrCh <- r + }, + Uncached: func() (dest ioz.WriteErrorCloser) { + r, w, err := fs.fscache.Get(loc) + if err != nil { + errCh <- errz.Err(err) + return nil + } + + wec := ioz.NewFuncWriteErrorCloser(w, func(err error) { + log := lg.FromContext(ctx) + lg.WarnIfError(log, "Remove damaged cache entry", fs.fscache.Remove(loc)) + }) + + rdrCh <- r + return wec + }, + Error: func(err error) { + errCh <- err + }, + } + + c := httpz.NewDefaultClient() + dl, err := download.New(handle, c, loc, dlDir) + if err != nil { + return nil, err + } + + go dl.Get(ctx, h) + + select { + case <-ctx.Done(): + return nil, errz.Err(ctx.Err()) + case err = <-errCh: + return nil, err + case rdr := <-rdrCh: + return rdr, nil + } +} + // Filepath returns the file path of src.Location. // An error is returned the source's driver type // is not a file type (i.e. it is a SQL driver). @@ -319,7 +382,7 @@ func (fs *Files) Open(ctx context.Context, src *Source) (io.ReadCloser, error) { defer fs.mu.Unlock() lg.FromContext(ctx).Debug("Files.Open", lga.Src, src) - return fs.newReader(ctx, src.Location) + return fs.newReader(ctx, src.Handle, src.Location) } // CacheLockFor returns the lock file for src's cache. 
@@ -344,8 +407,8 @@ func (fs *Files) OpenFunc(src *Source) FileOpenFunc { } } -func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, error) { - log := lg.FromContext(ctx).With(lga.Loc, loc) +func (fs *Files) newReader(ctx context.Context, handle string, loc string) (io.ReadCloser, error) { + //log := lg.FromContext(ctx).With(lga.Loc, loc) locTyp := getLocType(loc) switch locTyp { @@ -390,42 +453,43 @@ func (fs *Files) newReader(ctx context.Context, loc string) (io.ReadCloser, erro return r, nil } - // It's an uncached remote file. - - if loc == StdinHandle { - r, w, err := fs.fscache.Get(StdinHandle) - log.Debug("Returned from fs.fcache.Get", lga.Err, err) - if err != nil { - return nil, errz.Err(err) - } - if w != nil { - return nil, errz.New("@stdin not cached: has AddStdin been invoked yet?") - } - - return r, nil - } - - if !fs.fscache.Exists(loc) { - r, _, err := fs.fscache.Get(loc) - if err != nil { - return nil, err - } - - return r, nil - } - - // cache miss - f, err := fs.openLocation(ctx, loc) - if err != nil { - return nil, err - } - - // Note that addRegularFile closes f - r, err := fs.addRegularFile(ctx, f, loc) - if err != nil { - return nil, err - } - return r, nil + //if loc == StdinHandle { + // r, w, err := fs.fscache.Get(StdinHandle) + // log.Debug("Returned from fs.fcache.Get", lga.Err, err) + // if err != nil { + // return nil, errz.Err(err) + // } + // if w != nil { + // return nil, errz.New("@stdin not cached: has AddStdin been invoked yet?") + // } + // + // return r, nil + //} + + //// It's an uncached remote file. 
+ //if !fs.fscache.Exists(loc) { + // r, _, err := fs.fscache.Get(loc) + // if err != nil { + // return nil, err + // } + // + // return r, nil + //} + // + //// cache miss + //f, err := fs.openLocation(ctx, loc) + //if err != nil { + // return nil, err + //} + // + //// Note that addRegularFile closes f + //r, err := fs.addRegularFile(ctx, f, loc) + //if err != nil { + // return nil, err + //} + //return r, nil + + return fs.addRemoteFile(ctx, handle, loc) } // openLocation returns a file for loc. It is the caller's diff --git a/libsq/source/files_test.go b/libsq/source/files_test.go index e5b22c2b0..12d65d62e 100644 --- a/libsq/source/files_test.go +++ b/libsq/source/files_test.go @@ -56,11 +56,18 @@ func TestFiles_Type(t *testing.T) { t.Run(tu.Name(source.RedactLocation(tc.loc)), func(t *testing.T) { ctx := lg.NewContext(context.Background(), lgt.New(t)) - fs, err := source.NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) + fs, err := source.NewFiles( + ctx, + &source.Collection{}, + nil, + tu.TempDir(t), + tu.CacheDir(t), + true, + ) require.NoError(t, err) fs.AddDriverDetectors(testh.DriverDetectors()...) - gotType, gotErr := fs.DriverType(context.Background(), tc.loc) + gotType, gotErr := fs.DriverType(context.Background(), "@test_"+stringz.Uniq8(), tc.loc) if tc.wantErr { require.Error(t, gotErr) return @@ -98,7 +105,14 @@ func TestFiles_DetectType(t *testing.T) { t.Run(filepath.Base(tc.loc), func(t *testing.T) { ctx := lg.NewContext(context.Background(), lgt.New(t)) - fs, err := source.NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) + fs, err := source.NewFiles( + ctx, + &source.Collection{}, + nil, + tu.TempDir(t), + tu.CacheDir(t), + true, + ) require.NoError(t, err) fs.AddDriverDetectors(testh.DriverDetectors()...) 
@@ -158,7 +172,14 @@ func TestFiles_NewReader(t *testing.T) { Location: proj.Abs(fpath), } - fs, err := source.NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) + fs, err := source.NewFiles( + ctx, + &source.Collection{}, + nil, + tu.TempDir(t), + tu.CacheDir(t), + true, + ) require.NoError(t, err) g := &errgroup.Group{} diff --git a/libsq/source/internal_test.go b/libsq/source/internal_test.go index d28a84aef..71ffc70c0 100644 --- a/libsq/source/internal_test.go +++ b/libsq/source/internal_test.go @@ -21,7 +21,7 @@ import ( // Export for testing. var ( FilesDetectTypeFn = func(fs *Files, ctx context.Context, loc string) (typ drivertype.Type, ok bool, err error) { - return fs.detectType(ctx, loc) + return fs.detectType(ctx, "", loc) } GroupsFilterOnlyDirectChildren = groupsFilterOnlyDirectChildren ) @@ -29,7 +29,14 @@ var ( func TestFiles_Open(t *testing.T) { ctx := lg.NewContext(context.Background(), lgt.New(t)) - fs, err := NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) + fs, err := NewFiles( + ctx, + &Collection{}, + nil, + tu.TempDir(t), + tu.CacheDir(t), + true, + ) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, fs.Close()) }) diff --git a/testh/testh.go b/testh/testh.go index 54fbab14c..ea441f906 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -155,9 +155,16 @@ func (h *Helper) init() { cli.RegisterDefaultOpts(optRegistry) h.registry = driver.NewRegistry(log) + cfg := config.New() var err error - h.files, err = source.NewFiles(h.Context, optRegistry, - tu.TempDir(h.T), tu.CacheDir(h.T), true) + h.files, err = source.NewFiles( + h.Context, + cfg.Collection, + optRegistry, + tu.TempDir(h.T), + tu.CacheDir(h.T), + true, + ) require.NoError(h.T, err) h.Cleanup.Add(func() { @@ -198,7 +205,7 @@ func (h *Helper) init() { Stdin: os.Stdin, Out: os.Stdout, ErrOut: os.Stdin, - Config: config.New(), + Config: cfg, ConfigStore: config.DiscardStore{}, DriverRegistry: h.registry, diff --git a/testh/testh_test.go b/testh/testh_test.go index 
ff7f7d9a6..0746a025a 100644 --- a/testh/testh_test.go +++ b/testh/testh_test.go @@ -137,7 +137,7 @@ func TestHelper_Files(t *testing.T) { th := testh.New(t) fs := th.Files() - typ, err := fs.DriverType(th.Context, src.Location) + typ, err := fs.DriverType(th.Context, src.Handle, src.Location) require.NoError(t, err) require.Equal(t, src.Type, typ) From fd4829f250a065eef094228c250cf19cd78ec887 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 7 Jan 2024 22:31:42 -0700 Subject: [PATCH 149/195] download seems to be working --- cli/cmd_add.go | 3 ++- drivers/csv/ingest.go | 9 +++------ libsq/core/progress/progress.go | 3 ++- libsq/source/detect.go | 4 ++-- libsq/source/files.go | 9 +++++---- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/cli/cmd_add.go b/cli/cmd_add.go index e5e5e0e68..8a5939b31 100644 --- a/cli/cmd_add.go +++ b/cli/cmd_add.go @@ -4,11 +4,12 @@ import ( "bytes" "context" "fmt" - "github.com/neilotoole/sq/libsq/core/stringz" "io" "os" "strings" + "github.com/neilotoole/sq/libsq/core/stringz" + "github.com/spf13/cobra" "golang.org/x/term" diff --git a/drivers/csv/ingest.go b/drivers/csv/ingest.go index 3b8441603..fd2dcedfc 100644 --- a/drivers/csv/ingest.go +++ b/drivers/csv/ingest.go @@ -57,21 +57,18 @@ func ingestCSV(ctx context.Context, src *source.Source, openFn source.FileOpenFu log := lg.FromContext(ctx) startUTC := time.Now().UTC() - var err error - var r io.ReadCloser - - r, err = openFn(ctx) + rc, err := openFn(ctx) if err != nil { return err } - defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) + defer lg.WarnIfCloseError(log, lgm.CloseFileReader, rc) delim, err := getDelimiter(src) if err != nil { return err } - cr := newCSVReader(r, delim) + cr := newCSVReader(rc, delim) recs, err := readRecords(cr, driver.OptIngestSampleSize.Get(src.Options)) if err != nil { return err diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 17dbb85ca..bf1a2dcc2 100644 --- 
a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -34,7 +34,8 @@ import ( // release. // // Deprecated: This is a temporary hack for testing. -const DebugDelay = time.Millisecond * 20 +// const DebugDelay = time.Millisecond * 20 +const DebugDelay = 0 type ctxKey struct{} diff --git a/libsq/source/detect.go b/libsq/source/detect.go index 1d8da488f..e2981269e 100644 --- a/libsq/source/detect.go +++ b/libsq/source/detect.go @@ -39,7 +39,7 @@ func (fs *Files) AddDriverDetectors(detectFns ...DriverDetectFunc) { // DriverType returns the driver type of loc. // This may result in loading files into the cache. -func (fs *Files) DriverType(ctx context.Context, handle string, loc string) (drivertype.Type, error) { +func (fs *Files) DriverType(ctx context.Context, handle, loc string) (drivertype.Type, error) { log := lg.FromContext(ctx).With(lga.Loc, loc) ploc, err := parseLoc(loc) if err != nil { @@ -79,7 +79,7 @@ func (fs *Files) DriverType(ctx context.Context, handle string, loc string) (dri return typ, nil } -func (fs *Files) detectType(ctx context.Context, handle string, loc string) (typ drivertype.Type, ok bool, err error) { +func (fs *Files) detectType(ctx context.Context, handle, loc string) (typ drivertype.Type, ok bool, err error) { if len(fs.detectFns) == 0 { return drivertype.None, false, nil } diff --git a/libsq/source/files.go b/libsq/source/files.go index bd60d51fa..f4c080608 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -2,8 +2,6 @@ package source import ( "context" - "github.com/neilotoole/sq/libsq/core/ioz/download" - "github.com/neilotoole/sq/libsq/core/ioz/httpz" "io" "log/slog" "net/url" @@ -14,6 +12,9 @@ import ( "sync" "time" + "github.com/neilotoole/sq/libsq/core/ioz/download" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" + "github.com/neilotoole/fscache" "github.com/neilotoole/sq/libsq/core/cleanup" @@ -407,8 +408,8 @@ func (fs *Files) OpenFunc(src *Source) FileOpenFunc { } } -func (fs 
*Files) newReader(ctx context.Context, handle string, loc string) (io.ReadCloser, error) { - //log := lg.FromContext(ctx).With(lga.Loc, loc) +func (fs *Files) newReader(ctx context.Context, handle, loc string) (io.ReadCloser, error) { + // log := lg.FromContext(ctx).With(lga.Loc, loc) locTyp := getLocType(loc) switch locTyp { From fa0a1c1d1c0f0da48e353345292b5135c2890d28 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Tue, 9 Jan 2024 09:39:45 -0700 Subject: [PATCH 150/195] wip: mostly working --- cli/run.go | 9 +- cli/run/run.go | 2 +- libsq/core/ioz/download/download.go | 4 - libsq/core/progress/progress.go | 3 +- libsq/driver/grips.go | 108 +------- libsq/libsq.go | 4 +- libsq/pipeline.go | 14 +- libsq/query_no_src_test.go | 2 +- libsq/query_test.go | 2 +- libsq/source/cache.go | 136 ++++++++++ libsq/source/detect.go | 2 +- libsq/source/download.go | 397 +--------------------------- libsq/source/download_test.go | 196 ++++++-------- libsq/source/files.go | 145 ++++------ libsq/source/files_test.go | 27 +- libsq/source/internal_test.go | 45 +--- testh/testh.go | 11 +- 17 files changed, 310 insertions(+), 797 deletions(-) diff --git a/cli/run.go b/cli/run.go index 0dce98c49..d5ce9eee0 100644 --- a/cli/run.go +++ b/cli/run.go @@ -143,14 +143,7 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { var err error if ru.Files == nil { - ru.Files, err = source.NewFiles( - ctx, - ru.Config.Collection, - ru.OptionsRegistry, - source.DefaultTempDir(), - source.DefaultCacheDir(), - true, - ) + ru.Files, err = source.NewFiles(ctx, ru.OptionsRegistry, source.DefaultTempDir(), source.DefaultCacheDir(), true) if err != nil { lg.WarnIfFuncError(log, lga.Cleanup, ru.Cleanup.Run) return err diff --git a/cli/run/run.go b/cli/run/run.go index c2a68545b..76a861a3f 100644 --- a/cli/run/run.go +++ b/cli/run/run.go @@ -112,7 +112,7 @@ func (ru *Run) Close() error { func NewQueryContext(ru *Run, args map[string]string) *libsq.QueryContext { return &libsq.QueryContext{ 
Collection: ru.Config.Collection, - Sources: ru.Grips, + Grips: ru.Grips, Args: args, } } diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index 4b015922d..6227f9d26 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -103,15 +103,11 @@ type Download struct { // New returns a new Download for url that writes to cacheDir. // Name is a user-friendly name, such as a source handle like @data. // The name may show up in logs, or progress indicators etc. -// If c is nil, httpz.NewDefaultClient is used. func New(name string, c *http.Client, dlURL, cacheDir string, opts ...Opt) (*Download, error) { _, err := url.ParseRequestURI(dlURL) if err != nil { return nil, errz.Wrap(err, "invalid download URL") } - if c == nil { - c = httpz.NewDefaultClient() - } if cacheDir, err = filepath.Abs(cacheDir); err != nil { return nil, errz.Err(err) diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index bf1a2dcc2..15d58a4e9 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -34,8 +34,7 @@ import ( // release. // // Deprecated: This is a temporary hack for testing. 
-// const DebugDelay = time.Millisecond * 20 -const DebugDelay = 0 +const DebugDelay = time.Millisecond * 0 type ctxKey struct{} diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 3d24d7956..c8806981d 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -3,7 +3,6 @@ package driver import ( "context" "log/slog" - "path/filepath" "strings" "sync" "time" @@ -11,7 +10,6 @@ import ( "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" @@ -130,11 +128,11 @@ func (gs *Grips) doOpen(ctx context.Context, src *source.Source) (Grip, error) { // OpenScratch returns a scratch database instance. It is not // necessary for the caller to close the returned Grip as -// its Close method will be invoked by d.Close. +// its Close method will be invoked by Grips.Close. 
func (gs *Grips) OpenScratch(ctx context.Context, src *source.Source) (Grip, error) { const msgCloseScratch = "Close scratch db" - cacheDir, srcCacheDBFilepath, _, err := gs.getCachePaths(src) + cacheDir, srcCacheDBFilepath, _, err := gs.files.CachePaths(src) if err != nil { return nil, err } @@ -225,8 +223,7 @@ func (gs *Grips) openIngestNoCache(ctx context.Context, src *source.Source, func (gs *Grips) openIngestCache(ctx context.Context, src *source.Source, ingestFn func(ctx context.Context, destGrip Grip) error, ) (Grip, error) { - log := lg.FromContext(ctx) - log = log.With(lga.Handle, src.Handle) + log := lg.FromContext(ctx).With(lga.Handle, src.Handle) ctx = lg.NewContext(ctx, log) lock, err := gs.files.CacheLockFor(src) @@ -252,13 +249,6 @@ func (gs *Grips) openIngestCache(ctx context.Context, src *source.Source, } }() - cacheDir, _, checksumsPath, err := gs.getCachePaths(src) - if err != nil { - return nil, err - } - - log.Debug("Using cache dir", lga.Path, cacheDir) - var impl Grip var foundCached bool if impl, foundCached, err = gs.openCachedFor(ctx, src); err != nil { @@ -293,99 +283,27 @@ func (gs *Grips) openIngestCache(ctx context.Context, src *source.Source, log.Info("Ingest completed", lga.Src, src, lga.Dest, impl.Source(), lga.Elapsed, elapsed) - ingestFilePath, err := gs.files.Filepath(ctx, src) - if err != nil { - return nil, err - } - - // Write the checksums file. 
- var sum checksum.Checksum - if sum, err = checksum.ForFile(ingestFilePath); err != nil { - log.Warn("Failed to compute checksum for source file; caching not in effect", - lga.Src, src, lga.Dest, impl.Source(), lga.Path, ingestFilePath, lga.Err, err) - return impl, nil //nolint:nilerr - } - - if err = checksum.WriteFile(checksumsPath, sum, ingestFilePath); err != nil { - log.Warn("Failed to write checksum; file caching not in effect", - lga.Src, src, lga.Dest, impl.Source(), lga.Path, ingestFilePath, lga.Err, err) + if err = gs.files.WriteIngestChecksum(ctx, src, impl.Source()); err != nil { + log.Warn("Failed to write checksum for source file; caching not in effect", + lga.Src, src, lga.Dest, impl.Source(), lga.Err, err) } return impl, nil } -// getCachePaths returns the paths to the cache files for src. -// There is no guarantee that these files exist, or are accessible. -// It's just the paths. -func (gs *Grips) getCachePaths(src *source.Source) (srcCacheDir, cacheDB, checksums string, err error) { - if srcCacheDir, err = gs.files.CacheDirFor(src); err != nil { - return "", "", "", err - } - - checksums = filepath.Join(srcCacheDir, "checksums.txt") - cacheDB = filepath.Join(srcCacheDir, "cached.db") - return srcCacheDir, cacheDB, checksums, nil -} - -func (gs *Grips) openCachedFor(ctx context.Context, src *source.Source) (Grip, bool, error) { - _, cacheDBPath, checksumsPath, err := gs.getCachePaths(src) - if err != nil { - return nil, false, err - } - - if !ioz.FileAccessible(checksumsPath) { - return nil, false, nil - } - - mChecksums, err := checksum.ReadFile(checksumsPath) - if err != nil { - return nil, false, err - } - - drvr, err := gs.drvrs.DriverFor(src.Type) - if err != nil { - return nil, false, err - } - - if drvr.DriverMetadata().IsSQL { - return nil, false, errz.Errorf("open file cache for source %s: driver {%s} is SQL, not document", - src.Handle, src.Type) - } - - // FIXME: Not too sure invoking files.Filepath here is the right approach? 
- srcFilepath, err := gs.files.Filepath(ctx, src) - if err != nil { - return nil, false, err - } - - cachedChecksum, ok := mChecksums[srcFilepath] - if !ok { - return nil, false, nil - } - - srcChecksum, err := checksum.ForFile(srcFilepath) +// openCachedFor returns the cached backing grip for src. +// If not cached, exists returns false. +func (gs *Grips) openCachedFor(ctx context.Context, src *source.Source) (backingGrip Grip, exists bool, err error) { + var backingSrc *source.Source + backingSrc, exists, err = gs.files.CachedBackingSourceFor(ctx, src) if err != nil { return nil, false, err } - - if srcChecksum != cachedChecksum { - return nil, false, nil - } - - // The checksums match, so we can use the cached DB, - // if it exists. - if !ioz.FileAccessible(cacheDBPath) { + if !exists { return nil, false, nil } - backingSrc := &source.Source{ - Handle: src.Handle + "_cached", - Location: "sqlite3://" + cacheDBPath, - Type: drivertype.Type("sqlite3"), - } - - backingGrip, err := gs.doOpen(ctx, backingSrc) - if err != nil { + if backingGrip, err = gs.doOpen(ctx, backingSrc); err != nil { return nil, false, errz.Wrapf(err, "open cached DB for source %s", src.Handle) } diff --git a/libsq/libsq.go b/libsq/libsq.go index 7f6477af8..cf38a44f2 100644 --- a/libsq/libsq.go +++ b/libsq/libsq.go @@ -27,8 +27,8 @@ type QueryContext struct { // Collection is the set of sources. Collection *source.Collection - // Sources bridges between source.Source and databases. - Sources *driver.Grips + // Grips mediates access to driver.Grip instances. + Grips *driver.Grips // Args defines variables that are substituted into the query. // May be nil or empty. 
diff --git a/libsq/pipeline.go b/libsq/pipeline.go index cecf01e9a..7be25a929 100644 --- a/libsq/pipeline.go +++ b/libsq/pipeline.go @@ -185,7 +185,7 @@ func (p *pipeline) prepareNoTable(ctx context.Context, qm *queryModel) error { if handle == "" { src = p.qc.Collection.Active() - if src == nil || !p.qc.Sources.IsSQLSource(src) { + if src == nil || !p.qc.Grips.IsSQLSource(src) { log.Debug("No active SQL source, will use scratchdb.") // REVISIT: Grips.OpenScratch needs a source, so we just make one up. ephemeralSrc := &source.Source{ @@ -195,7 +195,7 @@ func (p *pipeline) prepareNoTable(ctx context.Context, qm *queryModel) error { // FIXME: We really want to change the signature of OpenScratch to // just need a name, not a source. - p.targetGrip, err = p.qc.Sources.OpenScratch(ctx, ephemeralSrc) + p.targetGrip, err = p.qc.Grips.OpenScratch(ctx, ephemeralSrc) if err != nil { return err } @@ -214,7 +214,7 @@ func (p *pipeline) prepareNoTable(ctx context.Context, qm *queryModel) error { } // At this point, src is non-nil. - if p.targetGrip, err = p.qc.Sources.Open(ctx, src); err != nil { + if p.targetGrip, err = p.qc.Grips.Open(ctx, src); err != nil { return err } @@ -246,7 +246,7 @@ func (p *pipeline) prepareFromTable(ctx context.Context, tblSel *ast.TblSelector return "", nil, err } - fromGrip, err = p.qc.Sources.Open(ctx, src) + fromGrip, err = p.qc.Grips.Open(ctx, src) if err != nil { return "", nil, err } @@ -339,7 +339,7 @@ func (p *pipeline) joinSingleSource(ctx context.Context, jc *joinClause) (fromCl return "", nil, err } - fromGrip, err = p.qc.Sources.Open(ctx, src) + fromGrip, err = p.qc.Grips.Open(ctx, src) if err != nil { return "", nil, err } @@ -377,7 +377,7 @@ func (p *pipeline) joinCrossSource(ctx context.Context, jc *joinClause) (fromCla } // Open the join db - joinGrip, err := p.qc.Sources.OpenJoin(ctx, srcs...) + joinGrip, err := p.qc.Grips.OpenJoin(ctx, srcs...) 
if err != nil { return "", nil, err } @@ -404,7 +404,7 @@ func (p *pipeline) joinCrossSource(ctx context.Context, jc *joinClause) (fromCla return "", nil, err } var db driver.Grip - if db, err = p.qc.Sources.Open(ctx, src); err != nil { + if db, err = p.qc.Grips.Open(ctx, src); err != nil { return "", nil, err } diff --git a/libsq/query_no_src_test.go b/libsq/query_no_src_test.go index d41783b52..38ca9b255 100644 --- a/libsq/query_no_src_test.go +++ b/libsq/query_no_src_test.go @@ -34,7 +34,7 @@ func TestQuery_no_source(t *testing.T) { qc := &libsq.QueryContext{ Collection: coll, - Sources: sources, + Grips: sources, } gotSQL, gotErr := libsq.SLQ2SQL(th.Context, qc, tc.in) diff --git a/libsq/query_test.go b/libsq/query_test.go index 593844ad0..ea1bbeb0c 100644 --- a/libsq/query_test.go +++ b/libsq/query_test.go @@ -167,7 +167,7 @@ func doExecQueryTestCase(t *testing.T, tc queryTestCase) { qc := &libsq.QueryContext{ Collection: coll, - Sources: sources, + Grips: sources, Args: tc.args, } diff --git a/libsq/source/cache.go b/libsq/source/cache.go index c4245ca1c..e7031ec9b 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -2,12 +2,20 @@ package source import ( "bytes" + "context" "fmt" "os" "path/filepath" "strings" "time" + "github.com/neilotoole/sq/libsq/core/lg/lgm" + + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/source/drivertype" + "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/options" @@ -60,6 +68,134 @@ func (fs *Files) downloadCacheDirFor(loc string) (dir string) { return fp } +func (fs *Files) WriteIngestChecksum(ctx context.Context, src, backingSrc *Source) (err error) { + log := lg.FromContext(ctx) + ingestFilePath, err := fs.filepath(src) + if err != nil { + return err + } + + // Write the checksums file. 
+ var sum checksum.Checksum + if sum, err = checksum.ForFile(ingestFilePath); err != nil { + log.Warn("Failed to compute checksum for source file; caching not in effect", + lga.Src, src, lga.Dest, backingSrc, lga.Path, ingestFilePath, lga.Err, err) + return err + } + + var checksumsPath string + if _, _, checksumsPath, err = fs.CachePaths(src); err != nil { + return err + } + + if err = checksum.WriteFile(checksumsPath, sum, ingestFilePath); err != nil { + log.Warn("Failed to write checksum; file caching not in effect", + lga.Src, src, lga.Dest, backingSrc, lga.Path, ingestFilePath, lga.Err, err) + } + return err +} + +// CachedBackingSourceFor returns the underlying backing source for src, if +// it exists. If it does not exist, ok returns false. +func (fs *Files) CachedBackingSourceFor(ctx context.Context, src *Source) (backingSrc *Source, ok bool, err error) { + fs.mu.Lock() + defer fs.mu.Unlock() + + switch getLocType(src.Location) { + case locTypeLocalFile: + return fs.cachedBackingSourceForLocalFile(ctx, src) + case locTypeRemoteFile: + return fs.cachedBackingSourceForRemoteFile(ctx, src) + default: + return nil, false, errz.Errorf("caching not applicable for source: %s", src.Handle) + } +} + +// cachedBackingSourceForLocalFile returns the underlying cached backing +// source for src, if it exists. 
+func (fs *Files) cachedBackingSourceForLocalFile(ctx context.Context, src *Source) (*Source, bool, error) { + _, cacheDBPath, checksumsPath, err := fs.CachePaths(src) + if err != nil { + return nil, false, err + } + + if !ioz.FileAccessible(checksumsPath) { + return nil, false, nil + } + + mChecksums, err := checksum.ReadFile(checksumsPath) + if err != nil { + return nil, false, err + } + + srcFilepath, err := fs.filepath(src) + if err != nil { + return nil, false, err + } + + cachedChecksum, ok := mChecksums[srcFilepath] + if !ok { + return nil, false, nil + } + + srcChecksum, err := checksum.ForFile(srcFilepath) + if err != nil { + return nil, false, err + } + + if srcChecksum != cachedChecksum { + return nil, false, nil + } + + // The checksums match, so we can use the cached DB, + // if it exists. + if !ioz.FileAccessible(cacheDBPath) { + return nil, false, nil + } + + backingSrc := &Source{ + Handle: src.Handle + "_cached", + Location: "sqlite3://" + cacheDBPath, + Type: drivertype.Type("sqlite3"), + } + + lg.FromContext(ctx).Debug("Found cached backing source DB src", lga.Src, src, "backing_src", backingSrc) + return backingSrc, true, nil +} + +// cachedBackingSourceForRemoteFile returns the underlying cached backing +// source for src, if it exists. +func (fs *Files) cachedBackingSourceForRemoteFile(ctx context.Context, src *Source) (*Source, bool, error) { + // src.Location is guaranteed to be a URL. + log := lg.FromContext(ctx) + + downloadedFile, r, err := fs.addRemoteFile(ctx, src.Handle, src.Location) + if err != nil { + return nil, false, err + } + lg.WarnIfCloseError(log, lgm.CloseFileReader, r) + if downloadedFile == "" { + log.Debug("No cached download file for src", lga.Src, src) + return nil, false, nil + } + + log.Debug("Found cached download file for src", lga.Src, src, lga.Path, downloadedFile) + return fs.cachedBackingSourceForLocalFile(ctx, src) +} + +// CachePaths returns the paths to the cache files for src. 
+// There is no guarantee that these files exist, or are accessible. +// It's just the paths. +func (fs *Files) CachePaths(src *Source) (srcCacheDir, cacheDB, checksums string, err error) { + if srcCacheDir, err = fs.CacheDirFor(src); err != nil { + return "", "", "", err + } + + checksums = filepath.Join(srcCacheDir, "checksums.txt") + cacheDB = filepath.Join(srcCacheDir, "cached.db") + return srcCacheDir, cacheDB, checksums, nil +} + // sourceHash generates a hash for src. The hash is based on the // member fields of src, with special handling for src.Options. // Only the opts that affect data ingestion (options.TagIngestMutate) diff --git a/libsq/source/detect.go b/libsq/source/detect.go index e2981269e..1a5f01140 100644 --- a/libsq/source/detect.go +++ b/libsq/source/detect.go @@ -220,7 +220,7 @@ func (fs *Files) DetectStdinType(ctx context.Context) (drivertype.Type, error) { return drivertype.None, errz.New("must invoke Files.AddStdin before invoking DetectStdinType") } - typ, ok, err := fs.detectType(ctx, "", StdinHandle) + typ, ok, err := fs.detectType(ctx, StdinHandle, StdinHandle) if err != nil { return drivertype.None, err } diff --git a/libsq/source/download.go b/libsq/source/download.go index fe89f1502..cbd215865 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -1,48 +1,27 @@ package source import ( - "bytes" - "context" - "io" - "log/slog" - "net/http" - "net/http/httputil" - "net/url" - "os" - "path/filepath" - "sync" "time" - "golang.org/x/exp/maps" - - "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/neilotoole/sq/libsq/core/ioz/checksum" - "github.com/neilotoole/sq/libsq/core/ioz/contextio" - "github.com/neilotoole/sq/libsq/core/ioz/httpz" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lga" - "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" - 
"github.com/neilotoole/sq/libsq/source/fetcher" ) var OptHTTPPingTimeout = options.NewDuration( // FIXME: apply OptHTTPPingTimeout to httpz.NewClient invocations - "https.ping.timeout", + "http.ping.timeout", "", 0, time.Second*10, - "HTTP ping timeout duration", - `How long to wait for initial response from HTTP endpoint before + "HTTP/S ping timeout duration", + `How long to wait for initial response from HTTP/S endpoint before timeout occurs. Long-running operations, such as HTTP file downloads, are not affected by this option. Example: 500ms or 3s.`, options.TagSource, ) var OptHTTPSInsecureSkipVerify = options.NewBool( - // FIXME: apply OptHTTPSkipVerify to httpz.NewClient invocations - "https.skip-verify", + // FIXME: apply OptHTTPSInsecureSkipVerify to httpz.NewClient invocations + "https.insecure-skip-verify", "", false, 0, @@ -50,369 +29,3 @@ var OptHTTPSInsecureSkipVerify = options.NewBool( "Skip HTTPS TLS verification", "Skip HTTPS TLS verification. Useful when downloading against self-signed certs.", ) - -// newDownloader creates a new downloader using cacheDir for the given url. -func newDownloader(c *http.Client, cacheDir, dlURL string) *downloader { - return &downloader{ - c: c, - cacheDir: cacheDir, - url: dlURL, - } -} - -// downloader is a helper for getting file contents from a URL, -// and caching the file locally. The structure of cacheDir -// is as follows: -// -// cacheDir/ -// pid.lock -// checksum.txt -// header.txt -// dl/ -// -// -// Let's take a closer look. -// -// - pid.lock is a lock file used to ensure that only one -// process is downloading the file at a time. -// FIXME: are we using pid.lock, or will we share the parent cache lock? -// -// - header.txt is a dump of the HTTP response header, included for -// debugging convenience. -// -// - checksum.txt contains a checksum:key pair, where the checksum is -// calculated using checksum.ForHTTPResponse, and the key is the path -// to the downloaded file, e.g. "dl/actor.csv". 
-// -// 67a47a0 dl/actor.csv -// -// - The file is downloaded to dl/ instead of into the root -// of cache dir, just to avoid the (remote) possibility of a name -// collision with the other files in cacheDir. The filename is based -// on the HTTP response, incorporating the Content-Disposition header -// if present, or the last path segment of the URL. The filename is -// sanitized. -// -// When downloader.Download is invoked, it appropriately clears the existing -// stored files before proceeding. Likewise, if the download fails, the stored -// files are wiped, to prevent a partial download from being used. -type downloader struct { - c *http.Client - mu sync.Mutex - cacheDir string - url string -} - -func (d *downloader) log(log *slog.Logger) *slog.Logger { - return log.With(lga.URL, d.url, lga.Dir, d.cacheDir) -} - -func (d *downloader) dlDir() string { - return filepath.Join(d.cacheDir, "dl") -} - -func (d *downloader) checksumFile() string { - return filepath.Join(d.cacheDir, "checksum.txt") -} - -func (d *downloader) headerFile() string { - return filepath.Join(d.cacheDir, "header.txt") -} - -// Download downloads the file at the URL to the download dir, creating the -// checksum file on completion, and also writes the file to dest, and returns -// the file path of the downloaded file. -// -// It is the caller's responsibility to close dest. If an appropriate file name -// cannot be determined from the HTTP response, the file is named "download". -// If the download fails at any stage, the download file is removed, but written -// always returns the number of bytes written to dest. -// Note that the download process is context-aware. -func (d *downloader) Download(ctx context.Context, dest io.Writer) (written int64, fp string, err error) { - d.mu.Lock() - defer d.mu.Unlock() - - log := d.log(lg.FromContext(ctx)) - - dlDir := d.dlDir() - // Clear the download dir. 
- if err = os.RemoveAll(dlDir); err != nil { - return written, "", errz.Wrapf(err, "could not clear download dir for: %s", d.url) - } - - if err = ioz.RequireDir(dlDir); err != nil { - return written, "", errz.Wrapf(err, "could not create download dir for: %s", d.url) - } - - // Make sure the header file is cleared. - if err = os.RemoveAll(d.headerFile()); err != nil { - return written, "", errz.Wrapf(err, "could not clear header file for: %s", d.url) - } - - var cancelFn context.CancelFunc - ctx, cancelFn = context.WithCancel(ctx) - defer cancelFn() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, d.url, nil) - if err != nil { - return written, "", errz.Wrapf(err, "download new request failed for: %s", d.url) - } - // setDefaultHTTPRequestHeaders(req) - - resp, err := d.c.Do(req) - if err != nil { - return written, "", errz.Wrapf(err, "download failed for: %s", d.url) - } - defer func() { - if resp != nil && resp.Body != nil { - lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) - } - }() - - if err = d.writeHeaderFile(resp); err != nil { - return written, "", err - } - - if resp.StatusCode != http.StatusOK { - return written, "", errz.Errorf("download failed with %s for %s", resp.Status, d.url) - } - - filename := httpz.Filename(resp) - if filename == "" { - filename = "download" - } - - fp = filepath.Join(dlDir, filename) - f, err := os.OpenFile(fp, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) - if err != nil { - return written, "", errz.Wrapf(err, "could not create download file for: %s", d.url) - } - - written, err = io.Copy( - contextio.NewWriter(ctx, io.MultiWriter(f, dest)), - contextio.NewReader(ctx, resp.Body), - ) - if err != nil { - log.Error("failed to write download file", lga.File, fp, lga.URL, d.url, lga.Err, err) - lg.WarnIfCloseError(log, lgm.CloseFileWriter, f) - lg.WarnIfFuncError(log, lgm.RemoveFile, func() error { return errz.Err(os.Remove(fp)) }) - return written, "", err - } - - if err = f.Close(); err != nil { - 
lg.WarnIfFuncError(log, lgm.RemoveFile, func() error { return errz.Err(os.Remove(fp)) }) - return written, "", errz.Wrapf(err, "failed to close download file: %s", fp) - } - - if resp.ContentLength == -1 { - // Sometimes the response won't have the content-length set, but we know - // it via the number of bytes read from the body. We explicitly set - // it here, because checksum.ForHTTPResponse uses it. - resp.ContentLength = written - } - - sum := checksum.ForHTTPResponse(resp) - if err = checksum.WriteFile(d.checksumFile(), sum, filepath.Join("dl", filename)); err != nil { - lg.WarnIfFuncError(log, lgm.RemoveFile, func() error { return errz.Err(os.Remove(fp)) }) - } - - log.Info("Wrote download file", lga.Written, written, lga.File, fp) - return written, fp, nil -} - -func setDefaultHTTPRequestHeaders(req *http.Request) { - req.Header.Set("User-Agent", "sq") // FIXME: this should be set on the http.Client - req.Header.Set("Accept-Encoding", "gzip") -} - -func (d *downloader) writeHeaderFile(resp *http.Response) error { - b, err := httputil.DumpResponse(resp, false) - if err != nil { - return errz.Wrapf(err, "failed to dump HTTP response for: %s", d.url) - } - - if len(b) == 0 { - return errz.Errorf("empty HTTP response for: %s", d.url) - } - - // Add a custom field just for human consumption convenience. - b = bytes.TrimSuffix(b, []byte("\r\n")) - b = append(b, "X-Sq-Downloaded-From: "+d.url+"\r\n"...) - - if err = os.WriteFile(d.headerFile(), b, os.ModePerm); err != nil { - return errz.Wrapf(err, "failed to store HTTP response header for: %s", d.url) - } - return nil -} - -// ClearCache clears the cache dir. 
-func (d *downloader) ClearCache(ctx context.Context) error { - d.mu.Lock() - defer d.mu.Unlock() - - log := d.log(lg.FromContext(ctx)) - if err := os.RemoveAll(d.cacheDir); err != nil { - log.Error("Failed to clear cache dir", lga.Dir, d.cacheDir, lga.Err, err) - return errz.Wrapf(err, "failed to clear cache dir: %s", d.cacheDir) - } - - log.Info("Cleared cache dir", lga.Dir, d.cacheDir) - return nil -} - -// Cached returns true if the file is cached locally, and if so, also returns -// the checksum and file path of the cached file. -func (d *downloader) Cached(ctx context.Context) (ok bool, sum checksum.Checksum, fp string) { - d.mu.Lock() - defer d.mu.Unlock() - - log := d.log(lg.FromContext(ctx)) - dlDir := d.dlDir() - fi, err := os.Stat(dlDir) - if err != nil { - log.Debug("not cached: can't stat download dir") - return false, "", "" - } - if !fi.IsDir() { - log.Error("not cached: download dir is not a dir") - return false, "", "" - } - - if _, err = os.Stat(d.checksumFile()); err != nil { - log.Debug("not cached: can't stat download checksum file", lga.File, d.checksumFile()) - return false, "", "" - } - - checksums, err := checksum.ReadFile(d.checksumFile()) - if err != nil { - log.Debug("not cached: can't read download checksum file") - return false, "", "" - } - - if len(checksums) != 1 { - log.Debug("not cached: download checksum file has unexpected number of entries") - return false, "", "" - } - - key := maps.Keys(checksums)[0] - sum = checksums[key] - if len(sum) == 0 { - log.Debug("not cached: checksum file has empty checksum", lga.File, key) - return false, "", "" - } - - downloadFile := filepath.Join(d.cacheDir, key) - if _, err = os.Stat(downloadFile); err != nil { - log.Debug("not cached: can't stat download file referenced in checksum file", lga.File, key) - return false, "", "" - } - - log.Info("found cached file", lga.File, key) - return true, sum, downloadFile -} - -// CachedIsCurrent returns true if the file is cached locally and if its -// 
stored checksum matches the checksum of the remote file. -func (d *downloader) CachedIsCurrent(ctx context.Context) (ok bool, err error) { - ok, sum, _ := d.Cached(ctx) - if !ok { - return false, errz.Errorf("not cached: %s", d.url) - } - - resp, err := fetchHTTPResponse(ctx, d.c, d.url) - if err != nil { - return false, errz.Wrap(err, "check remote header") - } - - remoteSum := checksum.ForHTTPResponse(resp) - if sum != remoteSum { - return false, nil - } - - return true, nil -} - -// fetchHTTPResponse fetches the HTTP response for u. First HEAD is tried, and -// if that's not allowed (http.StatusMethodNotAllowed), then GET is used. -func fetchHTTPResponse(ctx context.Context, c *http.Client, u string) (resp *http.Response, err error) { - req, err := http.NewRequestWithContext(ctx, http.MethodHead, u, nil) - if err != nil { - return nil, errz.Err(err) - } - setDefaultHTTPRequestHeaders(req) - - resp, err = c.Do(req) - if err != nil { - return nil, errz.Err(err) - } - if resp.Body != nil { - _ = resp.Body.Close() - } - - switch resp.StatusCode { - default: - return nil, errz.Errorf("unexpected HTTP status (%s) for HEAD: %s", resp.Status, u) - case http.StatusOK: - return resp, nil - case http.StatusMethodNotAllowed: - } - - // HEAD not allowed, try GET - var cancelFn context.CancelFunc - ctx, cancelFn = context.WithCancel(ctx) - defer cancelFn() - req, err = http.NewRequestWithContext(ctx, http.MethodGet, u, nil) - if err != nil { - return nil, errz.Err(err) - } - // setDefaultHTTPRequestHeaders(req) - - resp, err = http.DefaultClient.Do(req) - if err != nil { - return nil, errz.Err(err) - } - if resp.Body != nil { - _ = resp.Body.Close() - } - - if resp.StatusCode != http.StatusOK { - return nil, errz.Errorf("unexpected HTTP status (%s) for GET: %s", resp.Status, u) - } - - return resp, nil -} - -// fetch ensures that loc exists locally as a file. This may -// entail downloading the file via HTTPS etc. 
-func (fs *Files) fetch(ctx context.Context, loc string) (fpath string, err error) { - // This impl is a vestigial abomination from an early - // experiment. - - var ok bool - if fpath, ok = isFpath(loc); ok { - // loc is already a local file path - return fpath, nil - } - - var u *url.URL - if u, ok = httpURL(loc); !ok { - return "", errz.Errorf("not a valid file location: %s", loc) - } - - var dlFile *os.File - dlFile, err = os.CreateTemp("", "") - if err != nil { - return "", errz.Err(err) - } - - fetchr := &fetcher.Fetcher{} - // TOOD: ultimately should be passing a real context here - err = fetchr.Fetch(ctx, u.String(), dlFile) - if err != nil { - return "", errz.Err(err) - } - - // dlFile is kept open until fs is closed. - fs.clnup.AddC(dlFile) - - return dlFile.Name(), nil -} diff --git a/libsq/source/download_test.go b/libsq/source/download_test.go index 76cd49ff7..13a518ba3 100644 --- a/libsq/source/download_test.go +++ b/libsq/source/download_test.go @@ -1,35 +1,14 @@ package source -import ( - "bytes" - "context" - "net/http" - "net/http/httptest" - "net/url" - "path" - "path/filepath" - "strconv" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/neilotoole/sq/libsq/core/ioz/checksum" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lgt" - "github.com/neilotoole/sq/testh/proj" - "github.com/neilotoole/sq/testh/tu" -) - -func TestGetRemoteChecksum(t *testing.T) { - // sq add https://sq.io/testdata/actor.csv - // - // content-length: 7641 - // date: Thu, 07 Dec 2023 06:31:10 GMT - // etag: "069dbf690a12d5eb853feb8e04aeb49e-ssl" - - // TODO -} +//func TestGetRemoteChecksum(t *testing.T) { +// // sq add https://sq.io/testdata/actor.csv +// // +// // content-length: 7641 +// // date: Thu, 07 Dec 2023 06:31:10 GMT +// // etag: "069dbf690a12d5eb853feb8e04aeb49e-ssl" +// +// // TODO +//} const ( urlPaymentLargeCSV = 
"https://sqio-public.s3.amazonaws.com/testdata/payment-large.gen.csv" @@ -38,81 +17,82 @@ const ( sizeGzipActorCSV = int64(1968) ) -func TestFetchHTTPHeader_sqio(t *testing.T) { - header, err := fetchHTTPResponse(context.Background(), http.DefaultClient, urlActorCSV) - require.NoError(t, err) - require.NotNil(t, header) - - // TODO -} - -func TestDownloader_Download(t *testing.T) { - ctx := lg.NewContext(context.Background(), lgt.New(t)) - const dlURL = urlActorCSV - const wantContentLength = sizeActorCSV - u, err := url.Parse(dlURL) - require.NoError(t, err) - wantFilename := path.Base(u.Path) - require.Equal(t, "actor.csv", wantFilename) - - cacheDir, err := filepath.Abs(filepath.Join("testdata", "downloader", "cache-dir-1")) - require.NoError(t, err) - t.Logf("cacheDir: %s", cacheDir) - - dl := newDownloader(http.DefaultClient, cacheDir, dlURL) - require.NoError(t, dl.ClearCache(ctx)) - - buf := &bytes.Buffer{} - written, cachedFp, err := dl.Download(ctx, buf) - require.NoError(t, err) - require.FileExists(t, cachedFp) - require.Equal(t, wantContentLength, written) - require.Equal(t, wantContentLength, int64(buf.Len())) - - s := tu.ReadFileToString(t, dl.headerFile()) - t.Logf("header.txt\n\n" + s + "\n") - - s = tu.ReadFileToString(t, dl.checksumFile()) - t.Logf("checksum.txt\n\n" + s + "\n") - - gotSums, err := checksum.ReadFile(dl.checksumFile()) - require.NoError(t, err) - - isCached, cachedSum, cachedFp := dl.Cached(ctx) - require.True(t, isCached) - wantKey := filepath.Join("dl", wantFilename) - wantFp, err := filepath.Abs(filepath.Join(dl.cacheDir, wantKey)) - require.NoError(t, err) - require.Equal(t, wantFp, cachedFp) - fileSum, ok := gotSums[wantKey] - require.True(t, ok) - require.Equal(t, cachedSum, fileSum) - - isCurrent, err := dl.CachedIsCurrent(ctx) - require.NoError(t, err) - require.True(t, isCurrent) -} - -func TestFetchHTTPHeader_HEAD_fallback_GET(t *testing.T) { - b := proj.ReadFile("drivers/csv/testdata/sakila-csv/actor.csv") - srvr := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - w.WriteHeader(http.StatusMethodNotAllowed) - return - } - - w.Header().Set("Content-Length", strconv.Itoa(len(b))) - w.WriteHeader(http.StatusOK) - _, err := w.Write(b) - require.NoError(t, err) - })) - t.Cleanup(srvr.Close) - - u := srvr.URL - - resp, err := fetchHTTPResponse(context.Background(), http.DefaultClient, u) - assert.NoError(t, err) - assert.NotNil(t, resp) - require.Equal(t, http.StatusOK, resp.StatusCode) - require.Equal(t, len(b), int(resp.ContentLength)) -} +// +//func TestFetchHTTPHeader_sqio(t *testing.T) { +// header, err := fetchHTTPResponse(context.Background(), http.DefaultClient, urlActorCSV) +// require.NoError(t, err) +// require.NotNil(t, header) +// +// // TODO +//} +// +//func TestDownloader_Download(t *testing.T) { +// ctx := lg.NewContext(context.Background(), lgt.New(t)) +// const dlURL = urlActorCSV +// const wantContentLength = sizeActorCSV +// u, err := url.Parse(dlURL) +// require.NoError(t, err) +// wantFilename := path.Base(u.Path) +// require.Equal(t, "actor.csv", wantFilename) +// +// cacheDir, err := filepath.Abs(filepath.Join("testdata", "downloader", "cache-dir-1")) +// require.NoError(t, err) +// t.Logf("cacheDir: %s", cacheDir) +// +// dl := newDownloader(http.DefaultClient, cacheDir, dlURL) +// require.NoError(t, dl.ClearCache(ctx)) +// +// buf := &bytes.Buffer{} +// written, cachedFp, err := dl.Download(ctx, buf) +// require.NoError(t, err) +// require.FileExists(t, cachedFp) +// require.Equal(t, wantContentLength, written) +// require.Equal(t, wantContentLength, int64(buf.Len())) +// +// s := tu.ReadFileToString(t, dl.headerFile()) +// t.Logf("header.txt\n\n" + s + "\n") +// +// s = tu.ReadFileToString(t, dl.checksumFile()) +// t.Logf("checksum.txt\n\n" + s + "\n") +// +// gotSums, err := checksum.ReadFile(dl.checksumFile()) +// require.NoError(t, err) +// +// isCached, cachedSum, cachedFp := 
dl.Cached(ctx) +// require.True(t, isCached) +// wantKey := filepath.Join("dl", wantFilename) +// wantFp, err := filepath.Abs(filepath.Join(dl.cacheDir, wantKey)) +// require.NoError(t, err) +// require.Equal(t, wantFp, cachedFp) +// fileSum, ok := gotSums[wantKey] +// require.True(t, ok) +// require.Equal(t, cachedSum, fileSum) +// +// isCurrent, err := dl.CachedIsCurrent(ctx) +// require.NoError(t, err) +// require.True(t, isCurrent) +//} +// +//func TestFetchHTTPHeader_HEAD_fallback_GET(t *testing.T) { +// b := proj.ReadFile("drivers/csv/testdata/sakila-csv/actor.csv") +// srvr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// if r.Method != http.MethodGet { +// w.WriteHeader(http.StatusMethodNotAllowed) +// return +// } +// +// w.Header().Set("Content-Length", strconv.Itoa(len(b))) +// w.WriteHeader(http.StatusOK) +// _, err := w.Write(b) +// require.NoError(t, err) +// })) +// t.Cleanup(srvr.Close) +// +// u := srvr.URL +// +// resp, err := fetchHTTPResponse(context.Background(), http.DefaultClient, u) +// assert.NoError(t, err) +// assert.NotNil(t, resp) +// require.Equal(t, http.StatusOK, resp.StatusCode) +// require.Equal(t, len(b), int(resp.ContentLength)) +//} diff --git a/libsq/source/files.go b/libsq/source/files.go index f4c080608..f23faf828 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -50,7 +50,6 @@ type Files struct { tempDir string clnup *cleanup.Cleanup optRegistry *options.Registry - coll *Collection // fscache is used to cache files, providing convenient access // to multiple readers via Files.newReader. @@ -68,7 +67,9 @@ type Files struct { // NewFiles returns a new Files instance. If cleanFscache is true, the fscache // is cleaned on Files.Close. 
-func NewFiles(ctx context.Context, coll *Collection, optReg *options.Registry, tmpDir, cacheDir string, cleanFscache bool) (*Files, error) { +func NewFiles(ctx context.Context, optReg *options.Registry, + tmpDir, cacheDir string, cleanFscache bool, +) (*Files, error) { log := lg.FromContext(ctx) log.Debug("Creating new Files instance", "tmp_dir", tmpDir, "cache_dir", cacheDir) if tmpDir == "" { @@ -83,7 +84,6 @@ func NewFiles(ctx context.Context, coll *Collection, optReg *options.Registry, t } fs := &Files{ - coll: coll, optRegistry: optReg, cacheDir: cacheDir, fscacheEntryMetas: make(map[string]*fscacheEntryMeta), @@ -289,10 +289,22 @@ func (fs *Files) addRegularFile(ctx context.Context, f *os.File, key string) (fs return r, errz.Err(err) } -func (fs *Files) addRemoteFile(ctx context.Context, handle, loc string) (io.ReadCloser, error) { +// addRemoteFile adds a remote file to fs's cache, returning a reader. +// If the remote file is already cached, the path to that cached download +// file is returned in cachedDownload; otherwise cachedDownload is empty. +func (fs *Files) addRemoteFile(ctx context.Context, handle, loc string) (cachedDownload string, + rdr io.ReadCloser, err error, +) { + // FIXME: addRemoteFile should take a source, because we need to look + // at the src's options to create the correctly configured http client. 
+ + if getLocType(loc) != locTypeRemoteFile { + return "", nil, errz.Errorf("not a remote file: %s", loc) + } + dlDir := fs.downloadCacheDirFor(loc) - if err := ioz.RequireDir(dlDir); err != nil { - return nil, err + if err = ioz.RequireDir(dlDir); err != nil { + return "", nil, err } errCh := make(chan error, 1) @@ -300,9 +312,11 @@ func (fs *Files) addRemoteFile(ctx context.Context, handle, loc string) (io.Read h := download.Handler{ Cached: func(fp string) { - if err := fs.fscache.MapFile(fp); err != nil { - errCh <- errz.Wrapf(err, "failed to map file into fscache: %s", fp) - return + if !fs.fscache.Exists(fp) { + if err := fs.fscache.MapFile(fp); err != nil { + errCh <- errz.Wrapf(err, "failed to map file into fscache: %s", fp) + return + } } r, _, err := fs.fscache.Get(fp) @@ -310,18 +324,20 @@ func (fs *Files) addRemoteFile(ctx context.Context, handle, loc string) (io.Read errCh <- errz.Err(err) return } + cachedDownload = fp rdrCh <- r }, Uncached: func() (dest ioz.WriteErrorCloser) { - r, w, err := fs.fscache.Get(loc) + r, w, wErrFn, err := fs.fscache.GetWithErr(loc) if err != nil { errCh <- errz.Err(err) return nil } wec := ioz.NewFuncWriteErrorCloser(w, func(err error) { - log := lg.FromContext(ctx) - lg.WarnIfError(log, "Remove damaged cache entry", fs.fscache.Remove(loc)) + lg.FromContext(ctx).Error("Error writing to fscache", + lga.Handle, handle, lga.URL, loc, lga.Err, err) + wErrFn(err) }) rdrCh <- r @@ -335,37 +351,37 @@ func (fs *Files) addRemoteFile(ctx context.Context, handle, loc string) (io.Read c := httpz.NewDefaultClient() dl, err := download.New(handle, c, loc, dlDir) if err != nil { - return nil, err + return "", nil, err } go dl.Get(ctx, h) select { case <-ctx.Done(): - return nil, errz.Err(ctx.Err()) + return "", nil, errz.Err(ctx.Err()) case err = <-errCh: - return nil, err - case rdr := <-rdrCh: - return rdr, nil + return "", nil, err + case rdr = <-rdrCh: + return cachedDownload, rdr, nil } } -// Filepath returns the file path of 
src.Location. -// An error is returned the source's driver type -// is not a file type (i.e. it is a SQL driver). -// FIXME: Implement Files.Filepath fully. -func (fs *Files) Filepath(_ context.Context, src *Source) (string, error) { - // fs.mu.Lock() - // defer fs.mu.Unlock() - +// filepath returns the file path of src.Location. An error is returned +// if the source's driver type is not a document type (e.g. it is a +// SQL driver). If src is a remote (http) location, the returned filepath +// is that of the cached download file. If that file is not present, an +// error is returned. +func (fs *Files) filepath(src *Source) (string, error) { switch getLocType(src.Location) { case locTypeLocalFile: return src.Location, nil case locTypeRemoteFile: - // FIXME: implement remote file location. - // It's a remote file. We really should download it here. - // FIXME: implement downloading. - return "", errz.Errorf("not implemented for remote source: %s", src.Handle) + dlDir := fs.downloadCacheDirFor(src.Location) + dlFile := filepath.Join(dlDir, "body") + if !ioz.FileAccessible(dlFile) { + return "", errz.Errorf("remote file for %s not downloaded at: %s", src.Handle, dlFile) + } + return dlFile, nil case locTypeSQL: return "", errz.Errorf("cannot get filepath of SQL source: %s", src.Handle) case locTypeStdin: @@ -409,8 +425,6 @@ func (fs *Files) OpenFunc(src *Source) FileOpenFunc { } func (fs *Files) newReader(ctx context.Context, handle, loc string) (io.ReadCloser, error) { - // log := lg.FromContext(ctx).With(lga.Loc, loc) - locTyp := getLocType(loc) switch locTyp { case locTypeUnknown: @@ -454,73 +468,8 @@ func (fs *Files) newReader(ctx context.Context, handle, loc string) (io.ReadClos return r, nil } - //if loc == StdinHandle { - // r, w, err := fs.fscache.Get(StdinHandle) - // log.Debug("Returned from fs.fcache.Get", lga.Err, err) - // if err != nil { - // return nil, errz.Err(err) - // } - // if w != nil { - // return nil, errz.New("@stdin not cached: has AddStdin 
been invoked yet?") - // } - // - // return r, nil - //} - - //// It's an uncached remote file. - //if !fs.fscache.Exists(loc) { - // r, _, err := fs.fscache.Get(loc) - // if err != nil { - // return nil, err - // } - // - // return r, nil - //} - // - //// cache miss - //f, err := fs.openLocation(ctx, loc) - //if err != nil { - // return nil, err - //} - // - //// Note that addRegularFile closes f - //r, err := fs.addRegularFile(ctx, f, loc) - //if err != nil { - // return nil, err - //} - //return r, nil - - return fs.addRemoteFile(ctx, handle, loc) -} - -// openLocation returns a file for loc. It is the caller's -// responsibility to close the returned file. -func (fs *Files) openLocation(ctx context.Context, loc string) (*os.File, error) { - var fpath string - var ok bool - var err error - - fpath, ok = isFpath(loc) - if ok { - // we have a legitimate fpath - return errz.Return(os.Open(fpath)) - } - // It's not a local file path, maybe it's remote (http) - var u *url.URL - u, ok = httpURL(loc) - if !ok { - // We're out of luck, it's not a valid file location - return nil, errz.Errorf("invalid src location: %s", loc) - } - - // It's a remote file - fpath, err = fs.fetch(ctx, u.String()) - if err != nil { - return nil, err - } - - f, err := os.Open(fpath) - return f, errz.Err(err) + _, r, err := fs.addRemoteFile(ctx, handle, loc) + return r, err } // Close closes any open resources. 
diff --git a/libsq/source/files_test.go b/libsq/source/files_test.go index 12d65d62e..b6d54c064 100644 --- a/libsq/source/files_test.go +++ b/libsq/source/files_test.go @@ -56,14 +56,7 @@ func TestFiles_Type(t *testing.T) { t.Run(tu.Name(source.RedactLocation(tc.loc)), func(t *testing.T) { ctx := lg.NewContext(context.Background(), lgt.New(t)) - fs, err := source.NewFiles( - ctx, - &source.Collection{}, - nil, - tu.TempDir(t), - tu.CacheDir(t), - true, - ) + fs, err := source.NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) require.NoError(t, err) fs.AddDriverDetectors(testh.DriverDetectors()...) @@ -105,14 +98,7 @@ func TestFiles_DetectType(t *testing.T) { t.Run(filepath.Base(tc.loc), func(t *testing.T) { ctx := lg.NewContext(context.Background(), lgt.New(t)) - fs, err := source.NewFiles( - ctx, - &source.Collection{}, - nil, - tu.TempDir(t), - tu.CacheDir(t), - true, - ) + fs, err := source.NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) require.NoError(t, err) fs.AddDriverDetectors(testh.DriverDetectors()...) 
@@ -172,14 +158,7 @@ func TestFiles_NewReader(t *testing.T) { Location: proj.Abs(fpath), } - fs, err := source.NewFiles( - ctx, - &source.Collection{}, - nil, - tu.TempDir(t), - tu.CacheDir(t), - true, - ) + fs, err := source.NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) require.NoError(t, err) g := &errgroup.Group{} diff --git a/libsq/source/internal_test.go b/libsq/source/internal_test.go index 71ffc70c0..b2524a4b9 100644 --- a/libsq/source/internal_test.go +++ b/libsq/source/internal_test.go @@ -2,66 +2,23 @@ package source import ( "context" - "io" "runtime" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/source/drivertype" - "github.com/neilotoole/sq/testh/proj" - "github.com/neilotoole/sq/testh/sakila" - "github.com/neilotoole/sq/testh/testsrc" "github.com/neilotoole/sq/testh/tu" ) // Export for testing. 
var ( FilesDetectTypeFn = func(fs *Files, ctx context.Context, loc string) (typ drivertype.Type, ok bool, err error) { - return fs.detectType(ctx, "", loc) + return fs.detectType(ctx, "@test", loc) } GroupsFilterOnlyDirectChildren = groupsFilterOnlyDirectChildren ) -func TestFiles_Open(t *testing.T) { - ctx := lg.NewContext(context.Background(), lgt.New(t)) - - fs, err := NewFiles( - ctx, - &Collection{}, - nil, - tu.TempDir(t), - tu.CacheDir(t), - true, - ) - require.NoError(t, err) - t.Cleanup(func() { assert.NoError(t, fs.Close()) }) - - src1 := &Source{ - Location: proj.Abs(testsrc.PathXLSXTestHeader), - } - - f, err := fs.openLocation(ctx, src1.Location) - require.NoError(t, err) - t.Cleanup(func() { assert.NoError(t, f.Close()) }) - require.Equal(t, src1.Location, f.Name()) - - src2 := &Source{ - Location: sakila.URLActorCSV, - } - - f2, err := fs.openLocation(ctx, src2.Location) - require.NoError(t, err) - t.Cleanup(func() { assert.NoError(t, f2.Close()) }) - - b, err := io.ReadAll(f2) - require.NoError(t, err) - require.Equal(t, proj.ReadFile(sakila.PathCSVActor), b) -} - func TestParseLoc(t *testing.T) { const ( dbuser = "sakila" diff --git a/testh/testh.go b/testh/testh.go index ea441f906..801819d93 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -157,14 +157,7 @@ func (h *Helper) init() { cfg := config.New() var err error - h.files, err = source.NewFiles( - h.Context, - cfg.Collection, - optRegistry, - tu.TempDir(h.T), - tu.CacheDir(h.T), - true, - ) + h.files, err = source.NewFiles(h.Context, optRegistry, tu.TempDir(h.T), tu.CacheDir(h.T), true) require.NoError(h.T, err) h.Cleanup.Add(func() { @@ -627,7 +620,7 @@ func (h *Helper) QuerySLQ(query string, args map[string]string) (*RecordSink, er qc := &libsq.QueryContext{ Collection: h.coll, - Sources: h.grips, + Grips: h.grips, Args: args, } From 184106acc2c877533a9e838057396f2775464a80 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Wed, 10 Jan 2024 07:15:22 -0700 Subject: [PATCH 151/195] Downloads 
largely working --- cli/cmd_x.go | 4 +- cli/options.go | 3 +- drivers/xlsx/ingest.go | 5 +- libsq/core/ioz/download/cache.go | 12 +- libsq/core/ioz/download/download.go | 23 ++- libsq/core/ioz/download/download_test.go | 2 +- libsq/core/ioz/httpz/httpz.go | 14 ++ libsq/core/ioz/httpz/httpz_test.go | 10 +- libsq/core/ioz/httpz/opts.go | 14 +- libsq/core/lg/lga/lga.go | 1 + libsq/driver/record.go | 8 +- libsq/source/cache.go | 127 ++++++++++-- libsq/source/detect.go | 6 +- libsq/source/download.go | 160 ++++++++++++++- libsq/source/download_test.go | 90 --------- libsq/source/fetcher/fetcher.go | 110 ----------- libsq/source/fetcher/fetcher_test.go | 67 ------- libsq/source/files.go | 236 +++-------------------- libsq/source/location.go | 12 ++ 19 files changed, 378 insertions(+), 526 deletions(-) delete mode 100644 libsq/source/fetcher/fetcher.go delete mode 100644 libsq/source/fetcher/fetcher_test.go diff --git a/cli/cmd_x.go b/cli/cmd_x.go index 93a4864f7..53ec33dfd 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -176,8 +176,8 @@ func execXDownloadCmd(cmd *cobra.Command, args []string) error { c := httpz.NewClient( httpz.DefaultUserAgent, - httpz.OptRequestTimeout(time.Second*15), - // httpz.OptHeaderTimeout(time.Second*2), + httpz.OptResponseTimeout(time.Second*15), + // httpz.OptRequestTimeout(time.Second*2), httpz.OptRequestDelay(time.Second*5), ) dl, err := download.New(fakeSrc.Handle, c, u.String(), cacheDir) diff --git a/cli/options.go b/cli/options.go index 161780df6..5c32dfd3f 100644 --- a/cli/options.go +++ b/cli/options.go @@ -175,7 +175,8 @@ func RegisterDefaultOpts(reg *options.Registry) { OptLogDevMode, OptDiffNumLines, OptDiffDataFormat, - source.OptHTTPPingTimeout, + source.OptHTTPRequestTimeout, + source.OptHTTPResponseTimeout, source.OptHTTPSInsecureSkipVerify, driver.OptConnMaxOpen, driver.OptConnMaxIdle, diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index c68fa42ce..f4ca914c5 100644 --- a/drivers/xlsx/ingest.go +++ 
b/drivers/xlsx/ingest.go @@ -134,7 +134,10 @@ func ingestXLSX(ctx context.Context, src *source.Source, destGrip driver.Grip, x var ingestCount, skipped int for i := range sheetTbls { - time.Sleep(progress.DebugDelay) + if progress.DebugDelay > 0 { + time.Sleep(progress.DebugDelay) + } + if sheetTbls[i] == nil { // tblDef can be nil if its sheet is empty (has no data). skipped++ diff --git a/libsq/core/ioz/download/cache.go b/libsq/core/ioz/download/cache.go index 945353dad..396d4c1c2 100644 --- a/libsq/core/ioz/download/cache.go +++ b/libsq/core/ioz/download/cache.go @@ -26,10 +26,10 @@ const ( msgCloseCacheBodyFile = "Close cached response body file" ) -// cache is a cache for a individual download. The cached response is +// cache is a cache for an individual download. The cached response is // stored in two files, one for the header and one for the body, with // a checksum (of the body file) stored in a third file. -// Use cache.paths to access the cache files. +// Use cache.paths to get the cache file paths. type cache struct { // FIXME: move the mutex to the Download struct? mu sync.Mutex @@ -44,7 +44,7 @@ func (c *cache) paths(req *http.Request) (header, body, checksum string) { if req == nil || req.Method == http.MethodGet { return filepath.Join(c.dir, "header"), filepath.Join(c.dir, "body"), - filepath.Join(c.dir, "checksum.txt") + filepath.Join(c.dir, "checksums.txt") } // This is probably not strictly necessary because we're always @@ -52,7 +52,7 @@ func (c *cache) paths(req *http.Request) (header, body, checksum string) { // Can probably delete. return filepath.Join(c.dir, req.Method+"_header"), filepath.Join(c.dir, req.Method+"_body"), - filepath.Join(c.dir, req.Method+"_checksum.txt") + filepath.Join(c.dir, req.Method+"_checksums.txt") } // exists returns true if the cache exists and is consistent. 
@@ -291,7 +291,7 @@ func (c *cache) doWrite(ctx context.Context, resp *http.Response, return err } - log.Debug("Writing HTTP response to cache", lga.Dir, c.dir, "resp", httpz.ResponseLogValue(resp)) + log.Debug("Writing HTTP response to cache", lga.Dir, c.dir, lga.Resp, httpz.ResponseLogValue(resp)) fpHeader, fpBody, _ := c.paths(resp.Request) headerBytes, err := httputil.DumpResponse(resp, false) @@ -344,7 +344,7 @@ func (c *cache) doWrite(ctx context.Context, resp *http.Response, return errz.Wrap(err, "failed to compute checksum for cache body file") } - if err = checksum.WriteFile(filepath.Join(c.dir, "checksum.txt"), sum, "body"); err != nil { + if err = checksum.WriteFile(filepath.Join(c.dir, "checksums.txt"), sum, "body"); err != nil { return errz.Wrap(err, "failed to write checksum file for cache body") } diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index 6227f9d26..a20cedf80 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -134,6 +134,7 @@ func New(name string, c *http.Client, dlURL, cacheDir string, opts ...Opt) (*Dow // Get gets the download, invoking Handler as appropriate. 
func (dl *Download) Get(ctx context.Context, h Handler) { req := dl.mustRequest(ctx) + lg.FromContext(ctx).Debug("Get download", lga.URL, dl.url) dl.get(req, h) } @@ -142,7 +143,6 @@ func (dl *Download) Get(ctx context.Context, h Handler) { func (dl *Download) get(req *http.Request, h Handler) { //nolint:funlen,gocognit ctx := req.Context() log := lg.FromContext(ctx) - log.Debug("Get download", lga.URL, dl.url) _, fpBody, _ := dl.cache.paths(req) state := dl.state(req) @@ -306,6 +306,7 @@ func (dl *Download) get(req *http.Request, h Handler) { //nolint:funlen,gocognit func (dl *Download) do(req *http.Request) (*http.Response, error) { bar := progress.FromContext(req.Context()).NewWaiter(dl.name+": start download", true) resp, err := dl.c.Do(req) + httpz.Log(req, resp, err) bar.Stop() if err != nil { // Download timeout errors are typically wrapped in an url.Error, resulting @@ -384,6 +385,26 @@ func (dl *Download) state(req *http.Request) State { return getFreshness(cachedResp.Header, req.Header) } +// CacheFile returns the path to the cached file and its size, if it exists +// and has been fully downloaded. +func (dl *Download) CacheFile(ctx context.Context) (fp string, size int64, err error) { + if dl.cache == nil { + return "", 0, errz.Errorf("cache doesn't exist for: %s", dl.url) + } + + req := dl.mustRequest(ctx) + if !dl.cache.exists(req) { + return "", 0, errz.Errorf("no cache for: %s", dl.url) + } + _, fp, _ = dl.cache.paths(req) + + fi, err := os.Stat(fp) + if err != nil { + return "", 0, errz.Err(err) + } + return fp, fi.Size(), nil +} + // Checksum returns the checksum of the cached download, if available. 
func (dl *Download) Checksum(ctx context.Context) (sum checksum.Checksum, ok bool) { if dl.cache == nil { diff --git a/libsq/core/ioz/download/download_test.go b/libsq/core/ioz/download/download_test.go index a0d7625df..36d28fc20 100644 --- a/libsq/core/ioz/download/download_test.go +++ b/libsq/core/ioz/download/download_test.go @@ -46,7 +46,7 @@ func TestSlowHeaderServer(t *testing.T) { t.Cleanup(srvr.Close) clientHeaderTimeout := time.Second * 2 - c := httpz.NewClient(httpz.OptHeaderTimeout(clientHeaderTimeout)) + c := httpz.NewClient(httpz.OptRequestTimeout(clientHeaderTimeout)) req, err := http.NewRequest(http.MethodGet, srvr.URL, nil) require.NoError(t, err) resp, err := c.Do(req) diff --git a/libsq/core/ioz/httpz/httpz.go b/libsq/core/ioz/httpz/httpz.go index cece6f3bd..630c9ab1c 100644 --- a/libsq/core/ioz/httpz/httpz.go +++ b/libsq/core/ioz/httpz/httpz.go @@ -22,6 +22,9 @@ import ( "strconv" "strings" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/stringz" ) @@ -119,6 +122,17 @@ func ResponseLogValue(resp *http.Response) slog.Value { return slog.GroupValue(attrs...) } +// Log logs req, resp, and err via the logger on req.Context(). +func Log(req *http.Request, resp *http.Response, err error) { + log := lg.FromContext(req.Context()).With(lga.Method, req.Method, lga.URL, req.URL) + if err != nil { + log.Warn("HTTP request error", lga.Err, err) + return + } + + log.Debug("HTTP request completed", lga.Resp, ResponseLogValue(resp)) +} + // RequestLogValue implements slog.LogValuer for req. 
func RequestLogValue(req *http.Request) slog.Value { if req == nil { diff --git a/libsq/core/ioz/httpz/httpz_test.go b/libsq/core/ioz/httpz/httpz_test.go index 146e000a1..6d030bfc8 100644 --- a/libsq/core/ioz/httpz/httpz_test.go +++ b/libsq/core/ioz/httpz/httpz_test.go @@ -40,7 +40,7 @@ func TestOptRequestTimeout(t *testing.T) { require.NoError(t, err) clientRequestTimeout := time.Millisecond * 100 - c := httpz.NewClient(httpz.OptRequestTimeout(clientRequestTimeout)) + c := httpz.NewClient(httpz.OptResponseTimeout(clientRequestTimeout)) resp, err := c.Do(req) require.Error(t, err) require.Nil(t, resp) @@ -48,7 +48,7 @@ func TestOptRequestTimeout(t *testing.T) { } // TestOptHeaderTimeout_correct_error verifies that an HTTP request -// that fails via OptHeaderTimeout returns the correct error. +// that fails via OptRequestTimeout returns the correct error. func TestOptHeaderTimeout_correct_error(t *testing.T) { t.Parallel() @@ -69,7 +69,7 @@ func TestOptHeaderTimeout_correct_error(t *testing.T) { t.Cleanup(srvr.Close) clientHeaderTimeout := time.Second * 1 - c := httpz.NewClient(httpz.OptHeaderTimeout(clientHeaderTimeout)) + c := httpz.NewClient(httpz.OptRequestTimeout(clientHeaderTimeout)) req, err := http.NewRequestWithContext(ctx, http.MethodGet, srvr.URL, nil) require.NoError(t, err) @@ -90,7 +90,7 @@ func TestOptHeaderTimeout_correct_error(t *testing.T) { require.Equal(t, srvrBody, got) } -// TestOptHeaderTimeout_vs_stdlib verifies that OptHeaderTimeout +// TestOptHeaderTimeout_vs_stdlib verifies that OptRequestTimeout // works as expected when compared to stdlib. 
func TestOptHeaderTimeout_vs_stdlib(t *testing.T) { t.Parallel() @@ -121,7 +121,7 @@ func TestOptHeaderTimeout_vs_stdlib(t *testing.T) { ctxFn: func(t *testing.T) context.Context { return context.Background() }, - c: httpz.NewClient(httpz.OptHeaderTimeout(headerTimeout)), + c: httpz.NewClient(httpz.OptRequestTimeout(headerTimeout)), wantErr: false, }, } diff --git a/libsq/core/ioz/httpz/opts.go b/libsq/core/ioz/httpz/opts.go index 9f042f8e8..53e85e820 100644 --- a/libsq/core/ioz/httpz/opts.go +++ b/libsq/core/ioz/httpz/opts.go @@ -77,12 +77,12 @@ func contextCause() TripFunc { // as used by [NewDefaultClient]. var DefaultUserAgent = OptUserAgent(buildinfo.Get().UserAgent()) -// OptRequestTimeout is passed to [NewClient] to set the total request timeout, +// OptResponseTimeout is passed to [NewClient] to set the total request timeout, // including reading the body. This is basically the same as a traditional // request timeout via context.WithTimeout. If timeout is zero, this is no-op. // -// Contrast with [OptHeaderTimeout]. -func OptRequestTimeout(timeout time.Duration) TripFunc { +// Contrast with [OptRequestTimeout]. +func OptResponseTimeout(timeout time.Duration) TripFunc { if timeout <= 0 { return NopTripFunc } @@ -127,13 +127,13 @@ func OptRequestTimeout(timeout time.Duration) TripFunc { } } -// OptHeaderTimeout is passed to [NewClient] to set a timeout for just +// OptRequestTimeout is passed to [NewClient] to set a timeout for just // getting the initial response headers. This is useful if you expect // a response within, say, 2 seconds, but you expect the body to take longer // to read. // -// Contrast with [OptRequestTimeout]. -func OptHeaderTimeout(timeout time.Duration) TripFunc { +// Contrast with [OptResponseTimeout]. 
+func OptRequestTimeout(timeout time.Duration) TripFunc { if timeout <= 0 { return NopTripFunc } @@ -196,7 +196,7 @@ func OptHeaderTimeout(timeout time.Duration) TripFunc { // DefaultHeaderTimeout is the default header timeout as used // by [NewDefaultClient]. -var DefaultHeaderTimeout = OptHeaderTimeout(time.Second * 5) +var DefaultHeaderTimeout = OptRequestTimeout(time.Second * 5) // OptRequestDelay is passed to [NewClient] to delay the request by the // specified duration. This is useful for testing. diff --git a/libsq/core/lg/lga/lga.go b/libsq/core/lg/lga/lga.go index 33a9a2544..610ac1e87 100644 --- a/libsq/core/lg/lga/lga.go +++ b/libsq/core/lg/lga/lga.go @@ -43,6 +43,7 @@ const ( Path = "path" Pid = "pid" Query = "query" + Resp = "resp" Score = "score" Size = "size" SLQ = "slq" diff --git a/libsq/driver/record.go b/libsq/driver/record.go index 67da2d2d8..2989065f5 100644 --- a/libsq/driver/record.go +++ b/libsq/driver/record.go @@ -468,7 +468,9 @@ func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, bi.written.Add(affected) pbar.IncrBy(int(affected)) - time.Sleep(progress.DebugDelay) + if progress.DebugDelay > 0 { + time.Sleep(progress.DebugDelay) + } if rec == nil { // recCh is closed (coincidentally exactly on the @@ -512,7 +514,9 @@ func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, bi.written.Add(affected) pbar.IncrBy(int(affected)) - time.Sleep(progress.DebugDelay) + if progress.DebugDelay > 0 { + time.Sleep(progress.DebugDelay) + } // We're done return diff --git a/libsq/source/cache.go b/libsq/source/cache.go index e7031ec9b..751f2c70d 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -9,6 +9,8 @@ import ( "strings" "time" + "github.com/neilotoole/sq/libsq/core/ioz/lockfile" + "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/ioz" @@ -61,13 +63,6 @@ func (fs *Files) CacheDirFor(src *Source) (dir string, err error) { return dir, nil } -// 
downloadCacheDirFor gets the download cache dir for loc. It is not guaranteed -// that the returned dir exists or is accessible. -func (fs *Files) downloadCacheDirFor(loc string) (dir string) { - fp := filepath.Join(fs.cacheDir, "downloads", checksum.Sum([]byte(loc))) - return fp -} - func (fs *Files) WriteIngestChecksum(ctx context.Context, src, backingSrc *Source) (err error) { log := lg.FromContext(ctx) ingestFilePath, err := fs.filepath(src) @@ -103,7 +98,7 @@ func (fs *Files) CachedBackingSourceFor(ctx context.Context, src *Source) (backi switch getLocType(src.Location) { case locTypeLocalFile: - return fs.cachedBackingSourceForLocalFile(ctx, src) + return fs.cachedBackingSourceForFile(ctx, src) case locTypeRemoteFile: return fs.cachedBackingSourceForRemoteFile(ctx, src) default: @@ -111,9 +106,9 @@ func (fs *Files) CachedBackingSourceFor(ctx context.Context, src *Source) (backi } } -// cachedBackingSourceForLocalFile returns the underlying cached backing +// cachedBackingSourceForFile returns the underlying cached backing // source for src, if it exists. -func (fs *Files) cachedBackingSourceForLocalFile(ctx context.Context, src *Source) (*Source, bool, error) { +func (fs *Files) cachedBackingSourceForFile(ctx context.Context, src *Source) (*Source, bool, error) { _, cacheDBPath, checksumsPath, err := fs.CachePaths(src) if err != nil { return nil, false, err @@ -166,13 +161,14 @@ func (fs *Files) cachedBackingSourceForLocalFile(ctx context.Context, src *Sourc // cachedBackingSourceForRemoteFile returns the underlying cached backing // source for src, if it exists. func (fs *Files) cachedBackingSourceForRemoteFile(ctx context.Context, src *Source) (*Source, bool, error) { - // src.Location is guaranteed to be a URL. 
log := lg.FromContext(ctx) - downloadedFile, r, err := fs.addRemoteFile(ctx, src.Handle, src.Location) + downloadedFile, r, err := fs.openRemoteFile(ctx, src, true) if err != nil { return nil, false, err } + + // We don't care about the reader, but we do need to close it. lg.WarnIfCloseError(log, lgm.CloseFileReader, r) if downloadedFile == "" { log.Debug("No cached download file for src", lga.Src, src) @@ -180,7 +176,7 @@ func (fs *Files) cachedBackingSourceForRemoteFile(ctx context.Context, src *Sour } log.Debug("Found cached download file for src", lga.Src, src, lga.Path, downloadedFile) - return fs.cachedBackingSourceForLocalFile(ctx, src) + return fs.cachedBackingSourceForFile(ctx, src) } // CachePaths returns the paths to the cache files for src. @@ -238,6 +234,111 @@ func (fs *Files) sourceHash(src *Source) string { return sum } +// CacheLockFor returns the lock file for src's cache. +func (fs *Files) CacheLockFor(src *Source) (lockfile.Lockfile, error) { + cacheDir, err := fs.CacheDirFor(src) + if err != nil { + return "", errz.Wrapf(err, "cache lock for %s", src.Handle) + } + + lf, err := lockfile.New(filepath.Join(cacheDir, "pid.lock")) + if err != nil { + return "", errz.Wrapf(err, "cache lock for %s", src.Handle) + } + + return lf, nil +} + +// CacheClear clears the cache dir. This wipes the entire contents +// of the cache dir, so it should be used with caution. Note that +// this operation is distinct from [Files.CacheSweep]. +func (fs *Files) CacheClear(ctx context.Context) error { + fs.mu.Lock() + defer fs.mu.Unlock() + + log := lg.FromContext(ctx).With(lga.Dir, fs.cacheDir) + log.Debug("Clearing cache dir") + if !ioz.DirExists(fs.cacheDir) { + log.Debug("Cache dir does not exist") + return nil + } + + // Instead of directly deleting the existing cache dir, we first + // move it to /tmp, and then try to delete it. This should probably + // help with the situation where another sq instance has an open pid + // lock in the cache dir. 
+ + tmpDir := DefaultTempDir() + if err := ioz.RequireDir(tmpDir); err != nil { + return errz.Wrap(err, "cache clear") + } + relocateDir := filepath.Join(tmpDir, "dead_cache_"+stringz.Uniq8()) + if err := os.Rename(fs.cacheDir, relocateDir); err != nil { + return errz.Wrap(err, "cache clear: relocate") + } + + if err := os.RemoveAll(relocateDir); err != nil { + log.Warn("Could not delete relocated cache dir", lga.Path, relocateDir, lga.Err, err) + } + + // Recreate the cache dir. + if err := ioz.RequireDir(fs.cacheDir); err != nil { + return errz.Wrap(err, "cache clear") + } + + return nil +} + +// CacheSweep sweeps the cache dir, making a best-effort attempt +// to remove any empty directories. Note that this operation is +// distinct from [Files.CacheClear]. +func (fs *Files) CacheSweep(ctx context.Context) { + fs.mu.Lock() + defer fs.mu.Unlock() + + dir := fs.cacheDir + log := lg.FromContext(ctx).With(lga.Dir, dir) + log.Debug("Sweeping cache dir") + var count int + err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + if err != nil { + log.Warn("Problem sweeping cache dir", lga.Path, path, lga.Err, err) + return nil + } + + if !info.IsDir() { + return nil + } + + files, err := os.ReadDir(path) + if err != nil { + log.Warn("Problem reading dir", lga.Dir, path, lga.Err, err) + return nil + } + + if len(files) != 0 { + return nil + } + + err = os.Remove(path) + if err != nil { + log.Warn("Problem removing empty dir", lga.Dir, path, lga.Err, err) + } + count++ + + return nil + }) + if err != nil { + log.Warn("Problem sweeping cache dir", lga.Dir, dir, lga.Err, err) + } + log.Info("Swept cache dir", lga.Dir, dir, lga.Count, count) +} + // DefaultCacheDir returns the sq cache dir. This is generally // in USER_CACHE_DIR/*/sq, but could also be in TEMP_DIR/*/sq/cache // or similar. 
It is not guaranteed that the returned dir exists diff --git a/libsq/source/detect.go b/libsq/source/detect.go index 1a5f01140..1c6422529 100644 --- a/libsq/source/detect.go +++ b/libsq/source/detect.go @@ -79,6 +79,9 @@ func (fs *Files) DriverType(ctx context.Context, handle, loc string) (drivertype return typ, nil } +// detectType detects the type of src's location. The value of Source.Type +// is ignored. If the type cannot be detected, drivertype.None and false are +// returned. func (fs *Files) detectType(ctx context.Context, handle, loc string) (typ drivertype.Type, ok bool, err error) { if len(fs.detectFns) == 0 { return drivertype.None, false, nil @@ -87,7 +90,8 @@ func (fs *Files) detectType(ctx context.Context, handle, loc string) (typ driver start := time.Now() openFn := func(ctx context.Context) (io.ReadCloser, error) { - return fs.newReader(ctx, handle, loc) + src := &Source{Handle: handle, Location: loc} + return fs.newReader(ctx, src) } // We do the magic number first, because it's so fast. 
diff --git a/libsq/source/download.go b/libsq/source/download.go index cbd215865..99b78d365 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -1,26 +1,48 @@ package source import ( + "context" + "io" + "path/filepath" "time" + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" + "github.com/neilotoole/sq/libsq/core/ioz/download" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/options" ) -var OptHTTPPingTimeout = options.NewDuration( - // FIXME: apply OptHTTPPingTimeout to httpz.NewClient invocations - "http.ping.timeout", +var OptHTTPRequestTimeout = options.NewDuration( + "http.request.timeout", "", 0, time.Second*10, - "HTTP/S ping timeout duration", + "HTTP/S request initial response timeout duration", `How long to wait for initial response from HTTP/S endpoint before -timeout occurs. Long-running operations, such as HTTP file downloads, are -not affected by this option. Example: 500ms or 3s.`, +timeout occurs. Reading the body of the response, such as large HTTP file +downloads, is not affected by this option. Example: 500ms or 3s. +Contrast with http.response.timeout.`, + options.TagSource, +) + +var OptHTTPResponseTimeout = options.NewDuration( + "http.response.timeout", + "", + 0, + 0, + "HTTP/S response completion timeout duration", + `How long to wait for the entire HTTP transaction to complete. This includes +reading the body of the response, such as large HTTP file downloads. Typically +this is set to 0, indicating no timeout. 
Contrast with http.request.timeout.`, options.TagSource, ) var OptHTTPSInsecureSkipVerify = options.NewBool( - // FIXME: apply OptHTTPSInsecureSkipVerify to httpz.NewClient invocations "https.insecure-skip-verify", "", false, @@ -29,3 +51,127 @@ var OptHTTPSInsecureSkipVerify = options.NewBool( "Skip HTTPS TLS verification", "Skip HTTPS TLS verification. Useful when downloading against self-signed certs.", ) + +func (fs *Files) downloadFor(ctx context.Context, src *Source) (*download.Download, error) { + // REVISIT: should downloadFor return a cached instance of download.Download? + + dlDir, err := fs.downloadDirFor(src) + if err != nil { + return nil, err + } + if err = ioz.RequireDir(dlDir); err != nil { + return nil, err + } + + o := options.Merge(options.FromContext(ctx), src.Options) + c := httpz.NewClient(httpz.DefaultUserAgent, + httpz.OptRequestTimeout(OptHTTPRequestTimeout.Get(o)), + httpz.OptResponseTimeout(OptHTTPResponseTimeout.Get(o)), + httpz.OptInsecureSkipVerify(OptHTTPSInsecureSkipVerify.Get(o)), + ) + + dl, err := download.New(src.Handle, c, src.Location, dlDir) + if err != nil { + return nil, err + } + + return dl, nil +} + +// downloadDirFor gets the download cache dir for src. It is not +// guaranteed that the returned dir exists or is accessible. +func (fs *Files) downloadDirFor(src *Source) (string, error) { + cacheDir, err := fs.CacheDirFor(src) + if err != nil { + return "", err + } + + fp := filepath.Join(cacheDir, "download", checksum.Sum([]byte(src.Location))) + return fp, nil +} + +// openRemoteFile adds a remote file to fs's cache, returning a reader. +// If the remote file is already cached, the path to that cached download +// file is returned in cachedDownload; otherwise cachedDownload is empty. +// If checkFresh is false and the file is already fully downloaded, its +// freshness is not checked against the remote server. 
+func (fs *Files) openRemoteFile(ctx context.Context, src *Source, checkFresh bool) (cachedDownload string, + rdr io.ReadCloser, err error, +) { + loc := src.Location + if getLocType(loc) != locTypeRemoteFile { + return "", nil, errz.Errorf("not a remote file: %s", loc) + } + + dl, err := fs.downloadFor(ctx, src) + if err != nil { + return "", nil, err + } + + if !checkFresh && fs.fscache.Exists(loc) { + // If the download has completed, dl.CacheFile will return the + // path to the cached file. + cachedDownload, _, err = dl.CacheFile(ctx) + if err != nil { + return "", nil, err + } + // The file is already cached, and we're not checking freshness. + // So, we can just return the cached reader. + rdr, _, err = fs.fscache.Get(loc) + if err != nil { + return "", nil, errz.Err(err) + } + return cachedDownload, rdr, nil + } + + errCh := make(chan error, 1) + rdrCh := make(chan io.ReadCloser, 1) + + h := download.Handler{ + Cached: func(fp string) { + if !fs.fscache.Exists(fp) { + if hErr := fs.fscache.MapFile(fp); hErr != nil { + errCh <- errz.Wrapf(hErr, "failed to map file into fscache: %s", fp) + return + } + } + + r, _, hErr := fs.fscache.Get(fp) + if hErr != nil { + errCh <- errz.Err(hErr) + return + } + cachedDownload = fp + rdrCh <- r + }, + Uncached: func() (dest ioz.WriteErrorCloser) { + r, w, wErrFn, hErr := fs.fscache.GetWithErr(loc) + if hErr != nil { + errCh <- errz.Err(hErr) + return nil + } + + wec := ioz.NewFuncWriteErrorCloser(w, func(err error) { + lg.FromContext(ctx).Error("Error writing to fscache", lga.Src, src, lga.Err, err) + wErrFn(err) + }) + + rdrCh <- r + return wec + }, + Error: func(hErr error) { + errCh <- hErr + }, + } + + go dl.Get(ctx, h) + + select { + case <-ctx.Done(): + return "", nil, errz.Err(ctx.Err()) + case err = <-errCh: + return "", nil, err + case rdr = <-rdrCh: + return cachedDownload, rdr, nil + } +} diff --git a/libsq/source/download_test.go b/libsq/source/download_test.go index 13a518ba3..1b62a5d77 100644 --- 
a/libsq/source/download_test.go +++ b/libsq/source/download_test.go @@ -1,98 +1,8 @@ package source -//func TestGetRemoteChecksum(t *testing.T) { -// // sq add https://sq.io/testdata/actor.csv -// // -// // content-length: 7641 -// // date: Thu, 07 Dec 2023 06:31:10 GMT -// // etag: "069dbf690a12d5eb853feb8e04aeb49e-ssl" -// -// // TODO -//} - const ( urlPaymentLargeCSV = "https://sqio-public.s3.amazonaws.com/testdata/payment-large.gen.csv" urlActorCSV = "https://sq.io/testdata/actor.csv" sizeActorCSV = int64(7641) sizeGzipActorCSV = int64(1968) ) - -// -//func TestFetchHTTPHeader_sqio(t *testing.T) { -// header, err := fetchHTTPResponse(context.Background(), http.DefaultClient, urlActorCSV) -// require.NoError(t, err) -// require.NotNil(t, header) -// -// // TODO -//} -// -//func TestDownloader_Download(t *testing.T) { -// ctx := lg.NewContext(context.Background(), lgt.New(t)) -// const dlURL = urlActorCSV -// const wantContentLength = sizeActorCSV -// u, err := url.Parse(dlURL) -// require.NoError(t, err) -// wantFilename := path.Base(u.Path) -// require.Equal(t, "actor.csv", wantFilename) -// -// cacheDir, err := filepath.Abs(filepath.Join("testdata", "downloader", "cache-dir-1")) -// require.NoError(t, err) -// t.Logf("cacheDir: %s", cacheDir) -// -// dl := newDownloader(http.DefaultClient, cacheDir, dlURL) -// require.NoError(t, dl.ClearCache(ctx)) -// -// buf := &bytes.Buffer{} -// written, cachedFp, err := dl.Download(ctx, buf) -// require.NoError(t, err) -// require.FileExists(t, cachedFp) -// require.Equal(t, wantContentLength, written) -// require.Equal(t, wantContentLength, int64(buf.Len())) -// -// s := tu.ReadFileToString(t, dl.headerFile()) -// t.Logf("header.txt\n\n" + s + "\n") -// -// s = tu.ReadFileToString(t, dl.checksumFile()) -// t.Logf("checksum.txt\n\n" + s + "\n") -// -// gotSums, err := checksum.ReadFile(dl.checksumFile()) -// require.NoError(t, err) -// -// isCached, cachedSum, cachedFp := dl.Cached(ctx) -// require.True(t, isCached) -// 
wantKey := filepath.Join("dl", wantFilename) -// wantFp, err := filepath.Abs(filepath.Join(dl.cacheDir, wantKey)) -// require.NoError(t, err) -// require.Equal(t, wantFp, cachedFp) -// fileSum, ok := gotSums[wantKey] -// require.True(t, ok) -// require.Equal(t, cachedSum, fileSum) -// -// isCurrent, err := dl.CachedIsCurrent(ctx) -// require.NoError(t, err) -// require.True(t, isCurrent) -//} -// -//func TestFetchHTTPHeader_HEAD_fallback_GET(t *testing.T) { -// b := proj.ReadFile("drivers/csv/testdata/sakila-csv/actor.csv") -// srvr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// if r.Method != http.MethodGet { -// w.WriteHeader(http.StatusMethodNotAllowed) -// return -// } -// -// w.Header().Set("Content-Length", strconv.Itoa(len(b))) -// w.WriteHeader(http.StatusOK) -// _, err := w.Write(b) -// require.NoError(t, err) -// })) -// t.Cleanup(srvr.Close) -// -// u := srvr.URL -// -// resp, err := fetchHTTPResponse(context.Background(), http.DefaultClient, u) -// assert.NoError(t, err) -// assert.NotNil(t, resp) -// require.Equal(t, http.StatusOK, resp.StatusCode) -// require.Equal(t, len(b), int(resp.ContentLength)) -//} diff --git a/libsq/source/fetcher/fetcher.go b/libsq/source/fetcher/fetcher.go deleted file mode 100644 index 58f1a1266..000000000 --- a/libsq/source/fetcher/fetcher.go +++ /dev/null @@ -1,110 +0,0 @@ -// Package fetcher provides a mechanism for fetching files -// from URLs. -package fetcher - -import ( - "context" - "crypto/tls" - "io" - "net/http" - "net/url" - "time" - - "golang.org/x/net/context/ctxhttp" - - "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/stringz" -) - -// Config parameterizes Fetcher behavior. -type Config struct { - // Timeout is the request timeout. - Timeout time.Duration - - // Skip verification of insecure transports. - InsecureSkipVerify bool -} - -// Fetcher can fetch files from URLs. If field Config is nil, -// defaults are used. 
At this time, only HTTP/HTTPS is supported, -// but it's possible other schemes (such as FTP) will be -// supported in future. -type Fetcher struct { - Config *Config -} - -// Fetch writes the body of the document at fileURL to w. -func (f *Fetcher) Fetch(ctx context.Context, fileURL string, w io.Writer) error { - return fetchHTTP(ctx, f.Config, fileURL, w) -} - -func httpClient(cfg *Config) *http.Client { - client := *http.DefaultClient - - var tr *http.Transport - if client.Transport == nil { - tr = (http.DefaultTransport.(*http.Transport)).Clone() - } else { - tr = (client.Transport.(*http.Transport)).Clone() - } - - if tr.TLSClientConfig == nil { - tr.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS12} - } else { - tr.TLSClientConfig = tr.TLSClientConfig.Clone() - } - - if cfg != nil { - tr.TLSClientConfig.InsecureSkipVerify = cfg.InsecureSkipVerify - client.Timeout = cfg.Timeout - } - - client.Transport = tr - - return &client -} - -func fetchHTTP(ctx context.Context, cfg *Config, fileURL string, w io.Writer) error { - c := httpClient(cfg) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, fileURL, nil) - if err != nil { - return err - } - - resp, err := ctxhttp.Do(ctx, c, req) - if err != nil { - return err - } - - if resp.StatusCode != http.StatusOK { - _ = resp.Body.Close() - return errz.Errorf("http: returned non-200 status code (%s) from: %s", resp.Status, fileURL) - } - - _, err = io.Copy(w, resp.Body) - if err != nil { - _ = resp.Body.Close() - return errz.Wrapf(err, "http: failed to read body from: %s", fileURL) - } - - return errz.Err(resp.Body.Close()) -} - -// Schemes is the set of supported schemes. -func (f *Fetcher) Schemes() []string { - return []string{"http", "https"} -} - -// Supported returns true if loc is a supported URL. 
-func (f *Fetcher) Supported(loc string) bool { - u, err := url.ParseRequestURI(loc) - if err != nil { - return false - } - - if stringz.InSlice(f.Schemes(), u.Scheme) { - return true - } - - return false -} diff --git a/libsq/source/fetcher/fetcher_test.go b/libsq/source/fetcher/fetcher_test.go deleted file mode 100644 index b0e863767..000000000 --- a/libsq/source/fetcher/fetcher_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package fetcher_test - -import ( - "bytes" - "context" - "io" - "log" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "github.com/neilotoole/sq/libsq/source/fetcher" - "github.com/neilotoole/sq/testh/proj" - "github.com/neilotoole/sq/testh/sakila" -) - -func TestFetcherHTTP(t *testing.T) { - wantData := proj.ReadFile(sakila.PathCSVActor) - buf := &bytes.Buffer{} - - f := &fetcher.Fetcher{} - err := f.Fetch(context.Background(), sakila.URLActorCSV, buf) - require.NoError(t, err) - - require.Equal(t, wantData, buf.Bytes()) -} - -func TestFetcherConfig(t *testing.T) { - ctx := context.Background() - serverSleepy := new(time.Duration) - - server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(*serverSleepy) - })) - server.Config.ErrorLog = log.New(io.Discard, "", 0) // hush the server logging - server.StartTLS() - defer server.Close() - - fetchr := &fetcher.Fetcher{} - // No config, expect error because of bad cert - err := fetchr.Fetch(ctx, server.URL, io.Discard) - require.Error(t, err, "expect untrusted cert error") - - cfg := &fetcher.Config{InsecureSkipVerify: true} - - // Config as field of Fetcher - fetchr = &fetcher.Fetcher{Config: cfg} - err = fetchr.Fetch(ctx, server.URL, io.Discard) - require.NoError(t, err) - - // Test timeout - cfg.Timeout = time.Millisecond * 100 - - // Have the server sleep for longer than that - *serverSleepy = time.Millisecond * 200 - fetchr = &fetcher.Fetcher{Config: cfg} - err = 
fetchr.Fetch(ctx, server.URL, io.Discard) - require.Error(t, err, "should have seen a client timeout") - - // Make the client timeout larger than server sleep time - cfg.Timeout = time.Millisecond * 500 - err = fetchr.Fetch(ctx, server.URL, io.Discard) - require.NoError(t, err) -} diff --git a/libsq/source/files.go b/libsq/source/files.go index f23faf828..fac3b4eac 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -4,31 +4,25 @@ import ( "context" "io" "log/slog" - "net/url" "os" "path/filepath" "strconv" - "strings" "sync" "time" "github.com/neilotoole/sq/libsq/core/ioz/download" - "github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/fscache" - "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/ioz/contextio" - "github.com/neilotoole/sq/libsq/core/ioz/lockfile" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/core/progress" - "github.com/neilotoole/sq/libsq/core/stringz" ) // Files is the centralized API for interacting with files. @@ -127,6 +121,8 @@ func NewFiles(ctx context.Context, optReg *options.Registry, // Filesize returns the file size of src.Location. If the source is being // loaded asynchronously, this function may block until loading completes. // An error is returned if src is not a document/file source. +// For remote files, this method should only be invoked after the file +// has completed downloading, or an error will be returned. 
func (fs *Files) Filesize(ctx context.Context, src *Source) (size int64, err error) { locTyp := getLocType(src.Location) switch locTyp { @@ -139,11 +135,18 @@ func (fs *Files) Filesize(ctx context.Context, src *Source) (size int64, err err return fi.Size(), nil case locTypeRemoteFile: - // FIXME: implement remote file size. - return 0, errz.Errorf("remote file size not implemented: %s", src.Location) + fs.mu.Lock() + defer fs.mu.Unlock() + var dl *download.Download + if dl, err = fs.downloadFor(ctx, src); err != nil { + return 0, err + } + + _, size, err = dl.CacheFile(ctx) + return size, err case locTypeSQL: - return 0, errz.Errorf("cannot get size of SQL source: %s", src.Handle) + return 0, errz.Errorf("invalid to get size of SQL source: %s", src.Handle) case locTypeStdin: fs.mu.Lock() @@ -265,10 +268,6 @@ func (fs *Files) addRegularFile(ctx context.Context, f *os.File, key string) (fs log := lg.FromContext(ctx) log.Debug("Adding regular file", lga.Key, key, lga.Path, f.Name()) - if strings.Contains(f.Name(), "cached.db") { - log.Error("oh no, shouldn't be happening") // FIXME: delete this - } - defer lg.WarnIfCloseError(log, lgm.CloseFileReader, f) if key == StdinHandle { @@ -289,83 +288,6 @@ func (fs *Files) addRegularFile(ctx context.Context, f *os.File, key string) (fs return r, errz.Err(err) } -// addRemoteFile adds a remote file to fs's cache, returning a reader. -// If the remote file is already cached, the path to that cached download -// file is returned in cachedDownload; otherwise cachedDownload is empty. -func (fs *Files) addRemoteFile(ctx context.Context, handle, loc string) (cachedDownload string, - rdr io.ReadCloser, err error, -) { - // FIXME: addRemoteFile should take a source, because we need to look - // at the src's options to create the correctly configured http client. 
- - if getLocType(loc) != locTypeRemoteFile { - return "", nil, errz.Errorf("not a remote file: %s", loc) - } - - dlDir := fs.downloadCacheDirFor(loc) - if err = ioz.RequireDir(dlDir); err != nil { - return "", nil, err - } - - errCh := make(chan error, 1) - rdrCh := make(chan io.ReadCloser, 1) - - h := download.Handler{ - Cached: func(fp string) { - if !fs.fscache.Exists(fp) { - if err := fs.fscache.MapFile(fp); err != nil { - errCh <- errz.Wrapf(err, "failed to map file into fscache: %s", fp) - return - } - } - - r, _, err := fs.fscache.Get(fp) - if err != nil { - errCh <- errz.Err(err) - return - } - cachedDownload = fp - rdrCh <- r - }, - Uncached: func() (dest ioz.WriteErrorCloser) { - r, w, wErrFn, err := fs.fscache.GetWithErr(loc) - if err != nil { - errCh <- errz.Err(err) - return nil - } - - wec := ioz.NewFuncWriteErrorCloser(w, func(err error) { - lg.FromContext(ctx).Error("Error writing to fscache", - lga.Handle, handle, lga.URL, loc, lga.Err, err) - wErrFn(err) - }) - - rdrCh <- r - return wec - }, - Error: func(err error) { - errCh <- err - }, - } - - c := httpz.NewDefaultClient() - dl, err := download.New(handle, c, loc, dlDir) - if err != nil { - return "", nil, err - } - - go dl.Get(ctx, h) - - select { - case <-ctx.Done(): - return "", nil, errz.Err(ctx.Err()) - case err = <-errCh: - return "", nil, err - case rdr = <-rdrCh: - return cachedDownload, rdr, nil - } -} - // filepath returns the file path of src.Location. An error is returned // if the source's driver type is not a document type (e.g. it is a // SQL driver). 
If src is a remote (http) location, the returned filepath @@ -376,7 +298,10 @@ func (fs *Files) filepath(src *Source) (string, error) { case locTypeLocalFile: return src.Location, nil case locTypeRemoteFile: - dlDir := fs.downloadCacheDirFor(src.Location) + dlDir, err := fs.downloadDirFor(src) + if err != nil { + return "", err + } dlFile := filepath.Join(dlDir, "body") if !ioz.FileAccessible(dlFile) { return "", errz.Errorf("remote file for %s not downloaded at: %s", src.Handle, dlFile) @@ -399,22 +324,7 @@ func (fs *Files) Open(ctx context.Context, src *Source) (io.ReadCloser, error) { defer fs.mu.Unlock() lg.FromContext(ctx).Debug("Files.Open", lga.Src, src) - return fs.newReader(ctx, src.Handle, src.Location) -} - -// CacheLockFor returns the lock file for src's cache. -func (fs *Files) CacheLockFor(src *Source) (lockfile.Lockfile, error) { - cacheDir, err := fs.CacheDirFor(src) - if err != nil { - return "", errz.Wrapf(err, "cache lock for %s", src.Handle) - } - - lf, err := lockfile.New(filepath.Join(cacheDir, "pid.lock")) - if err != nil { - return "", errz.Wrapf(err, "cache lock for %s", src.Handle) - } - - return lf, nil + return fs.newReader(ctx, src) } // OpenFunc returns a func that invokes fs.Open for src.Location. 
@@ -424,13 +334,14 @@ func (fs *Files) OpenFunc(src *Source) FileOpenFunc { } } -func (fs *Files) newReader(ctx context.Context, handle, loc string) (io.ReadCloser, error) { +func (fs *Files) newReader(ctx context.Context, src *Source) (io.ReadCloser, error) { + loc := src.Location locTyp := getLocType(loc) switch locTyp { case locTypeUnknown: return nil, errz.Errorf("unknown source location type: %s", loc) case locTypeSQL: - return nil, errz.Errorf("cannot read SQL source: %s", loc) + return nil, errz.Errorf("invalid to read SQL source: %s", loc) case locTypeStdin: r, w, err := fs.fscache.Get(StdinHandle) if err != nil { @@ -468,7 +379,7 @@ func (fs *Files) newReader(ctx context.Context, handle, loc string) (io.ReadClos return r, nil } - _, r, err := fs.addRemoteFile(ctx, handle, loc) + _, r, err := fs.openRemoteFile(ctx, src, false) return r, err } @@ -479,112 +390,13 @@ func (fs *Files) Close() error { } // CleanupE adds fn to the cleanup sequence invoked by fs.Close. +// +// REVISIT: This CleanupE method really is an odd fish. It's only used +// by the test helper. Probably it can we removed? func (fs *Files) CleanupE(fn func() error) { fs.clnup.AddE(fn) } -// CacheClear clears the cache dir. This wipes the entire contents -// of the cache dir, so it should be used with caution. Note that -// this operation is distinct from [Files.CacheSweep]. -func (fs *Files) CacheClear(ctx context.Context) error { - fs.mu.Lock() - defer fs.mu.Unlock() - - log := lg.FromContext(ctx).With(lga.Dir, fs.cacheDir) - log.Debug("Clearing cache dir") - if !ioz.DirExists(fs.cacheDir) { - log.Debug("Cache dir does not exist") - return nil - } - - // Instead of directly deleting the existing cache dir, we first - // move it to /tmp, and then try to delete it. This should probably - // help with the situation where another sq instance has an open pid - // lock in the cache dir. 
- - tmpDir := DefaultTempDir() - if err := ioz.RequireDir(tmpDir); err != nil { - return errz.Wrap(err, "cache clear") - } - relocateDir := filepath.Join(tmpDir, "dead_cache_"+stringz.Uniq8()) - if err := os.Rename(fs.cacheDir, relocateDir); err != nil { - return errz.Wrap(err, "cache clear: relocate") - } - - if err := os.RemoveAll(relocateDir); err != nil { - log.Warn("Could not delete relocated cache dir", lga.Path, relocateDir, lga.Err, err) - } - - // Recreate the cache dir. - if err := ioz.RequireDir(fs.cacheDir); err != nil { - return errz.Wrap(err, "cache clear") - } - - return nil -} - -// CacheSweep sweeps the cache dir, making a best-effort attempt -// to remove any empty directories. Note that this operation is -// distinct from [Files.CacheClear]. -func (fs *Files) CacheSweep(ctx context.Context) { - fs.mu.Lock() - defer fs.mu.Unlock() - - dir := fs.cacheDir - log := lg.FromContext(ctx).With(lga.Dir, dir) - log.Debug("Sweeping cache dir") - var count int - err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - if err != nil { - log.Warn("Problem sweeping cache dir", lga.Path, path, lga.Err, err) - return nil - } - - if !info.IsDir() { - return nil - } - - files, err := os.ReadDir(path) - if err != nil { - log.Warn("Problem reading dir", lga.Dir, path, lga.Err, err) - return nil - } - - if len(files) != 0 { - return nil - } - - err = os.Remove(path) - if err != nil { - log.Warn("Problem removing empty dir", lga.Dir, path, lga.Err, err) - } - count++ - - return nil - }) - if err != nil { - log.Warn("Problem sweeping cache dir", lga.Dir, dir, lga.Err, err) - } - log.Info("Swept cache dir", lga.Dir, dir, lga.Count, count) -} - // FileOpenFunc returns a func that opens a ReadCloser. The caller // is responsible for closing the returned ReadCloser. 
type FileOpenFunc func(ctx context.Context) (io.ReadCloser, error) - -// httpURL tests if s is a well-structured HTTP or HTTPS url, and -// if so, returns the url and true. -func httpURL(s string) (u *url.URL, ok bool) { - var err error - u, err = url.Parse(s) - if err != nil || u.Host == "" || !(u.Scheme == "http" || u.Scheme == "https") { - return nil, false - } - - return u, true -} diff --git a/libsq/source/location.go b/libsq/source/location.go index f5eff4cfb..5802a76e9 100644 --- a/libsq/source/location.go +++ b/libsq/source/location.go @@ -353,3 +353,15 @@ func getLocType(loc string) locType { } return locTypeLocalFile } + +// httpURL tests if s is a well-structured HTTP or HTTPS url, and +// if so, returns the url and true. +func httpURL(s string) (u *url.URL, ok bool) { + var err error + u, err = url.Parse(s) + if err != nil || u.Host == "" || !(u.Scheme == "http" || u.Scheme == "https") { + return nil, false + } + + return u, true +} From 9f493ebb337bba14a73b62acd2d3b17fa8fa321a Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 11 Jan 2024 07:27:14 -0700 Subject: [PATCH 152/195] linting --- .golangci.yml | 289 ++++++++++++++++++ cli/cmd_add.go | 2 +- cli/cmd_mv.go | 6 + cli/cmd_x.go | 75 +++-- cli/config/store.go | 37 +++ .../upgrades/v0.34.0/upgrade_test.go | 2 +- cli/config/yamlstore/yamlstore.go | 73 ++++- cli/diff/table.go | 2 +- cli/logging.go | 19 +- cli/options.go | 3 + cli/output/format/opt.go | 4 +- cli/output/jsonw/errorwriter.go | 2 +- cli/output/jsonw/internal/benchmark_test.go | 2 +- cli/output/jsonw/internal/internal_test.go | 2 +- cli/output/jsonw/jsonw.go | 2 +- cli/output/jsonw/pingwriter.go | 2 +- cli/run.go | 65 +++- drivers/csv/csv_test.go | 10 +- drivers/sqlite3/metadata.go | 4 +- drivers/sqlserver/sqlserver.go | 6 +- drivers/xlsx/ingest.go | 7 +- drivers/xlsx/xlsx.go | 5 +- libsq/core/errz/multi.go | 28 +- libsq/core/errz/stack.go | 4 +- libsq/core/ioz/download/cache.go | 87 ++---- libsq/core/ioz/download/download.go | 116 
+++++-- libsq/core/ioz/download/http.go | 22 +- libsq/core/ioz/lockfile/lockfile.go | 7 +- libsq/core/options/opt.go | 10 +- libsq/core/progress/progress.go | 12 +- libsq/core/sqlz/nullbool_test.go | 2 +- libsq/core/stringz/stringz.go | 2 +- libsq/driver/driver.go | 4 +- libsq/driver/grips.go | 91 +++--- libsq/driver/ingest.go | 9 +- libsq/driver/record.go | 10 +- libsq/source/cache.go | 6 +- libsq/source/download.go | 19 +- libsq/source/download_test.go | 8 - libsq/source/files.go | 8 +- testh/testh.go | 13 +- 41 files changed, 777 insertions(+), 300 deletions(-) delete mode 100644 libsq/source/download_test.go diff --git a/.golangci.yml b/.golangci.yml index c321098ff..addbcd38b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -190,6 +190,295 @@ linters-settings: # Default: false require-specific: true + revive: + # https://golangci-lint.run/usage/linters/#revive + rules: + + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#add-constant + - name: add-constant + disabled: true + arguments: + - maxLitCount: "3" + allowStrs: '""' + allowInts: "0,1,2" + allowFloats: "0.0,0.,1.0,1.,2.0,2." 
+ # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#argument-limit + - name: argument-limit + disabled: false + arguments: [ 6 ] + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#atomic + - name: atomic + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#banned-characters + - name: banned-characters + disabled: true + arguments: [ "Ω", "Σ", "σ", "7" ] + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#bare-return + - name: bare-return + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#blank-imports + - name: blank-imports + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#bool-literal-in-expr + - name: bool-literal-in-expr + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#call-to-gc + - name: call-to-gc + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#cognitive-complexity + - name: cognitive-complexity + disabled: true + arguments: [ 7 ] + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#comment-spacings + - name: comment-spacings + disabled: false + arguments: + - mypragma + - otherpragma + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#confusing-naming + - name: confusing-naming + disabled: true + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#confusing-results + - name: confusing-results + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#constant-logical-expr + - name: constant-logical-expr + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#context-as-argument + - name: context-as-argument + disabled: false + arguments: + - allowTypesBefore: "*testing.T,*github.com/user/repo/testing.Harness" + # 
https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#context-keys-type + - name: context-keys-type + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#cyclomatic + - name: cyclomatic + disabled: true + arguments: [ 3 ] + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#datarace + - name: datarace + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#deep-exit + - name: deep-exit + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#defer + - name: defer + disabled: false + arguments: + - [ "call-chain", "loop" ] + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#dot-imports + - name: dot-imports + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#duplicated-imports + - name: duplicated-imports + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#early-return + - name: early-return + disabled: false + arguments: + - "preserveScope" + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#empty-block + - name: empty-block + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#empty-lines + - name: empty-lines + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#enforce-map-style + - name: enforce-map-style + disabled: true + arguments: + - "make" + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#error-naming + - name: error-naming + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#error-return + - name: error-return + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#error-strings + - name: error-strings + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#errorf + - 
name: errorf + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#file-header + - name: file-header + disabled: true +# arguments: +# - This is the text that must appear at the top of source files. + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#flag-parameter + - name: flag-parameter + disabled: true + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#function-result-limit + - name: function-result-limit + disabled: false + arguments: [ 4 ] + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#function-length + - name: function-length + disabled: true + arguments: [ 10, 0 ] + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#get-return + - name: get-return + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#identical-branches + - name: identical-branches + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#if-return + - name: if-return + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#increment-decrement + - name: increment-decrement + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#indent-error-flow + - name: indent-error-flow + disabled: false + arguments: + - "preserveScope" + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#import-alias-naming + - name: import-alias-naming + disabled: false + arguments: + - "^[a-z][a-z0-9]{0,}$" + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#imports-blacklist + - name: imports-blacklist + disabled: false + arguments: + - "crypto/md5" + - "crypto/sha1" + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#import-shadowing + - name: import-shadowing + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#line-length-limit + - 
name: line-length-limit + disabled: true + arguments: [ 80 ] + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#max-public-structs + - name: max-public-structs + disabled: true + arguments: [ 3 ] + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#modifies-parameter + - name: modifies-parameter + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#modifies-value-receiver + - name: modifies-value-receiver + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#nested-structs + - name: nested-structs + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#optimize-operands-order + - name: optimize-operands-order + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#package-comments + - name: package-comments + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#range + - name: range + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#range-val-in-closure + - name: range-val-in-closure + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#range-val-address + - name: range-val-address + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#receiver-naming + - name: receiver-naming + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#redundant-import-alias + - name: redundant-import-alias + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#redefines-builtin-id + - name: redefines-builtin-id + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#string-of-int + - name: string-of-int + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#string-format + - name: 
string-format + disabled: false + arguments: + - - 'core.WriteError[1].Message' + - '/^([^A-Z]|$)/' + - must not start with a capital letter + - - 'fmt.Errorf[0]' + - '/(^|[^\.!?])$/' + - must not end in punctuation + - - panic + - '/^[^\n]*$/' + - must not contain line breaks + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#struct-tag + - name: struct-tag + arguments: + - "json,inline" + - "bson,outline,gnu" + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#superfluous-else + - name: superfluous-else + disabled: false + arguments: + - "preserveScope" + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#time-equal + - name: time-equal + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#time-naming + - name: time-naming + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#var-naming + - name: var-naming + disabled: false + arguments: + - [ "ID" ] # AllowList + - [ "VM" ] # DenyList + - - upperCaseConst: true + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#var-declaration + - name: var-declaration + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#unconditional-recursion + - name: unconditional-recursion + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#unexported-naming + - name: unexported-naming + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#unexported-return + - name: unexported-return + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#unhandled-error + - name: unhandled-error + disabled: false + arguments: + - "fmt.Printf" + - "fmt.Fprintf" + - "fmt.Fprint" + - "fmt.Fprintln" + - "bytes.Buffer.Write" + - "bytes.Buffer.WriteByte" + - "bytes.Buffer.WriteString" + - "bytes.Buffer.WriteRune" + - 
"strings.Builder.WriteString" + - "strings.Builder.WriteRune" + - "strings.Builder.Write" + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#unnecessary-stmt + - name: unnecessary-stmt + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#unreachable-code + - name: unreachable-code + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#unused-parameter + - name: unused-parameter + disabled: false + arguments: + - allowRegex: "^_" + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#unused-receiver + - name: unused-receiver + disabled: true + arguments: + - allowRegex: "^_" + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#useless-break + - name: useless-break + disabled: false + # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#waitgroup-by-value + - name: waitgroup-by-value + disabled: false + rowserrcheck: # database/sql is always checked # Default: [] diff --git a/cli/cmd_add.go b/cli/cmd_add.go index 8a5939b31..a2c50e007 100644 --- a/cli/cmd_add.go +++ b/cli/cmd_add.go @@ -335,7 +335,7 @@ func readPassword(ctx context.Context, stdin *os.File, stdout io.Writer, pr *out fmt.Fprint(buf, "Password: ") pr.Faint.Fprint(buf, "[ENTER]") fmt.Fprint(buf, " ") - stdout.Write(buf.Bytes()) + _, _ = stdout.Write(buf.Bytes()) b, err := term.ReadPassword(int(stdin.Fd())) // Regardless of whether there's an error, we print diff --git a/cli/cmd_mv.go b/cli/cmd_mv.go index 1ee4fbd57..b6ff6fbdf 100644 --- a/cli/cmd_mv.go +++ b/cli/cmd_mv.go @@ -51,6 +51,12 @@ source handles are files, and groups are directories.`, } func execMove(cmd *cobra.Command, args []string) error { + if unlock, err := lockReloadConfig(cmd); err != nil { + return err + } else { + defer unlock() + } + switch { case source.IsValidHandle(args[0]) && source.IsValidHandle(args[1]): // Effectively a handle rename diff --git a/cli/cmd_x.go 
b/cli/cmd_x.go index 53ec33dfd..226346572 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -49,42 +49,53 @@ func newXLockSrcCmd() *cobra.Command { } func execXLockSrcCmd(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() - ru := run.FromContext(ctx) - src, err := ru.Config.Collection.Get(args[0]) - if err != nil { - return err - } - - timeout := time.Minute * 20 - lock, err := ru.Files.CacheLockFor(src) - if err != nil { - return err - } - fmt.Fprintf(ru.Out, "Locking cache for source %s with timeout %s for %q [%d]\n\n %s\n\n", - src.Handle, timeout, os.Args[0], os.Getpid(), lock) - - err = lock.Lock(ctx, timeout) - if err != nil { - return err - } - - fmt.Fprintf(ru.Out, "Cache lock acquired for %s\n", src.Handle) - - select { - case <-pressEnter(): - fmt.Fprintln(ru.Out, "\nENTER received, releasing lock") - case <-ctx.Done(): - fmt.Fprintln(ru.Out, "\nContext done, releasing lock") - } - - fmt.Fprintf(ru.Out, "Releasing cache lock for %s\n", src.Handle) - if err = lock.Unlock(); err != nil { + if unlock, err := lockReloadConfig(cmd); err != nil { return err + } else { + defer unlock() } - fmt.Fprintf(ru.Out, "Cache lock released for %s\n", src.Handle) + sleep := time.Second * 10 + fmt.Fprintf(os.Stdout, "huzzah, will sleep for %s\n", sleep) + time.Sleep(sleep) return nil + + //ctx := cmd.Context() + //ru := run.FromContext(ctx) + //src, err := ru.Config.Collection.Get(args[0]) + //if err != nil { + // return err + //} + // + //timeout := time.Minute * 20 + //lock, err := ru.Files.CacheLockFor(src) + //if err != nil { + // return err + //} + //fmt.Fprintf(ru.Out, "Locking cache for source %s with timeout %s for %q [%d]\n\n %s\n\n", + // src.Handle, timeout, os.Args[0], os.Getpid(), lock) + // + //err = lock.Lock(ctx, timeout) + //if err != nil { + // return err + //} + // + //fmt.Fprintf(ru.Out, "Cache lock acquired for %s\n", src.Handle) + // + //select { + //case <-pressEnter(): + // fmt.Fprintln(ru.Out, "\nENTER received, releasing lock") + 
//case <-ctx.Done(): + // fmt.Fprintln(ru.Out, "\nContext done, releasing lock") + //} + // + //fmt.Fprintf(ru.Out, "Releasing cache lock for %s\n", src.Handle) + //if err = lock.Unlock(); err != nil { + // return err + //} + // + //fmt.Fprintf(ru.Out, "Cache lock released for %s\n", src.Handle) + //return nil } func newXProgressCmd() *cobra.Command { diff --git a/cli/config/store.go b/cli/config/store.go index ae7caa731..1ab29965e 100644 --- a/cli/config/store.go +++ b/cli/config/store.go @@ -2,6 +2,13 @@ package config import ( "context" + "fmt" + "os" + "time" + + "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz/lockfile" + "github.com/neilotoole/sq/libsq/core/options" ) // Store saves and loads config. @@ -15,12 +22,31 @@ type Store interface { // Location returns the location of the store, typically // a file path. Location() string + + // Lockfile returns the lockfile used by the store, but does not acquire + // the lock, which is the caller's responsibility. The lock should always + // be acquired before mutating config. It is also the caller's responsibility + // to release the acquired lock when done. + Lockfile() (lockfile.Lockfile, error) } // DiscardStore implements Store but its Save method is no-op // and Load always returns a new empty Config. Useful for testing. type DiscardStore struct{} +// Lockfile implements Store.Lockfile. +func (DiscardStore) Lockfile() (lockfile.Lockfile, error) { + f, err := os.CreateTemp("", fmt.Sprintf("sq-%d.lock", os.Getpid())) + if err != nil { + return "", errz.Err(err) + } + fname := f.Name() + if err = f.Close(); err != nil { + return "", errz.Err(err) + } + return lockfile.Lockfile(fname), nil +} + var _ Store = (*DiscardStore)(nil) // Load returns a new empty Config. 
@@ -37,3 +63,14 @@ func (DiscardStore) Save(context.Context, *Config) error { func (DiscardStore) Location() string { return "/dev/null" } + +// OptConfigLockTimeout is the time allowed to acquire the config lock. +var OptConfigLockTimeout = options.NewDuration( + "config.lock.timeout", + "", + 0, + time.Second*5, + "Wait timeout to acquire config lock", + `Wait timeout to acquire the config lock. During this period, retry will occur +if the lock is already held by another process. If zero, no retry occurs.`, +) diff --git a/cli/config/yamlstore/upgrades/v0.34.0/upgrade_test.go b/cli/config/yamlstore/upgrades/v0.34.0/upgrade_test.go index e90ebee64..e37853dd9 100644 --- a/cli/config/yamlstore/upgrades/v0.34.0/upgrade_test.go +++ b/cli/config/yamlstore/upgrades/v0.34.0/upgrade_test.go @@ -13,7 +13,7 @@ import ( "github.com/neilotoole/sq/cli" "github.com/neilotoole/sq/cli/config" "github.com/neilotoole/sq/cli/config/yamlstore" - v0_34_0 "github.com/neilotoole/sq/cli/config/yamlstore/upgrades/v0.34.0" + v0_34_0 "github.com/neilotoole/sq/cli/config/yamlstore/upgrades/v0.34.0" //nolint:revive "github.com/neilotoole/sq/cli/output/format" "github.com/neilotoole/sq/drivers/csv" "github.com/neilotoole/sq/drivers/postgres" diff --git a/cli/config/yamlstore/yamlstore.go b/cli/config/yamlstore/yamlstore.go index e078df2a8..b46fa190f 100644 --- a/cli/config/yamlstore/yamlstore.go +++ b/cli/config/yamlstore/yamlstore.go @@ -8,6 +8,8 @@ import ( "os" "path/filepath" + "github.com/neilotoole/sq/libsq/core/ioz/lockfile" + "github.com/neilotoole/sq/cli/buildinfo" "github.com/neilotoole/sq/cli/config" "github.com/neilotoole/sq/libsq/core/errz" @@ -26,6 +28,8 @@ const ( originDefault = "default" ) +var _ config.Store = (*Store)(nil) + // Store provides persistence of config via YAML file. // It implements config.Store. 
type Store struct { @@ -50,6 +54,16 @@ type Store struct { OptionsRegistry *options.Registry } +// Lockfile implements Store.Lockfile. +func (fs *Store) Lockfile() (lockfile.Lockfile, error) { + fp := filepath.Join(filepath.Dir(fs.Path), "config.lock.pid") + fp, err := filepath.Abs(fp) + if err != nil { + return "", errz.Wrap(err, "failed to get abs path for lockfile") + } + return lockfile.Lockfile(fp), nil +} + // String returns a log/debug-friendly representation. func (fs *Store) String() string { return fmt.Sprintf("config via %s: %v", fs.PathOrigin, fs.Path) @@ -66,27 +80,46 @@ func (fs *Store) Load(ctx context.Context) (*config.Config, error) { log.Debug("Loading config from file", lga.Path, fs.Path) if fs.UpgradeRegistry != nil { - mightNeedUpgrade, foundVers, err := checkNeedsUpgrade(ctx, fs.Path) + mightNeedUpgrade, _, err := checkNeedsUpgrade(ctx, fs.Path) if err != nil { return nil, errz.Wrapf(err, "config: %s", fs.Path) } if mightNeedUpgrade { - log.Info("Upgrade config?", lga.From, foundVers, lga.To, buildinfo.Version) - if _, err = fs.doUpgrade(ctx, foundVers, buildinfo.Version); err != nil { + // The config might need to be upgraded. But, there's an edge case + // where another process might upgrade the config file before we + // get a chance to do so. So, we acquire the config lock, and + // then check again if it still needs upgrade. + unlock, err := fs.acquireLock(ctx) + if err != nil { return nil, err } + defer unlock() - // We do a cycle of loading and saving the config after the upgrade, - // because the upgrade may have written YAML via a map, which - // doesn't preserve order. Loading and saving should fix that. - cfg, err := fs.doLoad(ctx) + // Lock is acquired; check again if config needs upgrade. 
+ var foundVers string + mightNeedUpgrade, foundVers, err = checkNeedsUpgrade(ctx, fs.Path) if err != nil { - return nil, errz.Wrapf(err, "config: %s: load failed after config upgrade", fs.Path) + return nil, errz.Wrapf(err, "config: %s", fs.Path) } - if err = fs.Save(ctx, cfg); err != nil { - return nil, errz.Wrapf(err, "config: %s: save failed after config upgrade", fs.Path) + if mightNeedUpgrade { + log.Info("Upgrade config?", lga.From, foundVers, lga.To, buildinfo.Version) + if _, err = fs.doUpgrade(ctx, foundVers, buildinfo.Version); err != nil { + return nil, err + } + + // We do a cycle of loading and saving the config after the upgrade, + // because the upgrade may have written YAML via a map, which + // doesn't preserve order. Loading and saving should fix that. + cfg, err := fs.doLoad(ctx) + if err != nil { + return nil, errz.Wrapf(err, "config: %s: load failed after config upgrade", fs.Path) + } + + if err = fs.Save(ctx, cfg); err != nil { + return nil, errz.Wrapf(err, "config: %s: save failed after config upgrade", fs.Path) + } } } } @@ -187,6 +220,26 @@ func (fs *Store) fileExists() bool { return err == nil } +// acquireLock acquires the config lock, and returns an unlock func. +// This is an internal convenience method. +func (fs *Store) acquireLock(ctx context.Context) (unlock func(), err error) { + lock, err := fs.Lockfile() + if err != nil { + return nil, errz.Wrap(err, "failed to get config lock") + } + + // We use the default timeout because config isn't loaded yet, + // so we don't know what the value is. + lockTimeout := config.OptConfigLockTimeout.Default() + if err = lock.Lock(ctx, lockTimeout); err != nil { + return nil, errz.Wrap(err, "acquire config lock") + } + + return func() { + lg.WarnIfFuncError(lg.FromContext(ctx), "Release config lock", lock.Unlock) + }, nil +} + // canonicalizeConfig checks cfg's validity, and patches cfg to the canonical // form,cfg's validity. 
For example, an unknown or nil value in an // options.Options is deleted. diff --git a/cli/diff/table.go b/cli/diff/table.go index d7d823893..2d9954a50 100644 --- a/cli/diff/table.go +++ b/cli/diff/table.go @@ -17,7 +17,7 @@ import ( ) // ExecTableDiff diffs handle1.table1 and handle2.table2. -func ExecTableDiff(ctx context.Context, ru *run.Run, cfg *Config, elems *Elements, +func ExecTableDiff(ctx context.Context, ru *run.Run, cfg *Config, elems *Elements, //nolint:revive handle1, table1, handle2, table2 string, ) error { td1, td2 := &tableData{tblName: table1}, &tableData{tblName: table2} diff --git a/cli/logging.go b/cli/logging.go index c59514c72..95764fbb9 100644 --- a/cli/logging.go +++ b/cli/logging.go @@ -251,31 +251,24 @@ func getLogLevel(ctx context.Context, osArgs []string, cfg *config.Config) slog. bootLog.Debug("Using log level specified via flag", lga.Flag, flag.LogLevel, lga.Val, val) lvl := new(slog.Level) - if err = lvl.UnmarshalText([]byte(val)); err != nil { - bootLog.Error("Invalid log level specified via flag", - lga.Flag, flag.LogLevel, - lga.Val, val, - lga.Err, err) - } else { + if err = lvl.UnmarshalText([]byte(val)); err == nil { return *lvl } + bootLog.Error("Invalid log level specified via flag", + lga.Flag, flag.LogLevel, lga.Val, val, lga.Err, err) } val, ok = os.LookupEnv(config.EnvarLogLevel) if ok { bootLog.Debug("Using log level specified via envar", - lga.Env, config.EnvarLogLevel, - lga.Val, val) + lga.Env, config.EnvarLogLevel, lga.Val, val) lvl := new(slog.Level) if err = lvl.UnmarshalText([]byte(val)); err != nil { - bootLog.Error("Invalid log level specified by envar", - lga.Env, config.EnvarLogLevel, - lga.Val, val, - lga.Err, err) - } else { return *lvl } + bootLog.Error("Invalid log level specified by envar", + lga.Env, config.EnvarLogLevel, lga.Val, val, lga.Err, err) } var o options.Options diff --git a/cli/options.go b/cli/options.go index 5c32dfd3f..a6f24c907 100644 --- a/cli/options.go +++ b/cli/options.go @@ -4,6 
+4,8 @@ import ( "fmt" "strings" + "github.com/neilotoole/sq/cli/config" + "github.com/samber/lo" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -169,6 +171,7 @@ func RegisterDefaultOpts(reg *options.Registry) { OptCompact, OptPingCmdTimeout, OptShellCompletionTimeout, + config.OptConfigLockTimeout, OptLogEnabled, OptLogFile, OptLogLevel, diff --git a/cli/output/format/opt.go b/cli/output/format/opt.go index 8865269ec..f2ff496dd 100644 --- a/cli/output/format/opt.go +++ b/cli/output/format/opt.go @@ -9,7 +9,9 @@ var _ options.Opt = Opt{} // NewOpt returns a new format.Opt instance. If validFn is non-nil, it // is executed against possible values. -func NewOpt(key, flag string, short rune, defaultVal Format, validFn func(Format) error, usage, help string) Opt { +func NewOpt(key, flag string, short rune, defaultVal Format, //nolint:revive + validFn func(Format) error, usage, help string, +) Opt { opt := options.NewBaseOpt(key, flag, short, usage, help, options.TagOutput) return Opt{BaseOpt: opt, defaultVal: defaultVal, validFn: validFn} } diff --git a/cli/output/jsonw/errorwriter.go b/cli/output/jsonw/errorwriter.go index b98b672a4..e2c18fdb6 100644 --- a/cli/output/jsonw/errorwriter.go +++ b/cli/output/jsonw/errorwriter.go @@ -23,7 +23,7 @@ func NewErrorWriter(log *slog.Logger, out io.Writer, pr *output.Printing) output } type errorDetail struct { - Error string `json:"error,"` + Error string `json:"error"` BaseError string `json:"base_error,omitempty"` Tree string `json:"tree,omitempty"` Stack []*stack `json:"stack,omitempty"` diff --git a/cli/output/jsonw/internal/benchmark_test.go b/cli/output/jsonw/internal/benchmark_test.go index 7eaaaf045..311dd6181 100644 --- a/cli/output/jsonw/internal/benchmark_test.go +++ b/cli/output/jsonw/internal/benchmark_test.go @@ -7,7 +7,7 @@ import ( segmentj "github.com/segmentio/encoding/json" - jcolorenc "github.com/neilotoole/sq/cli/output/jsonw/internal/jcolorenc" + jcolorenc 
"github.com/neilotoole/sq/cli/output/jsonw/internal/jcolorenc" //nolint:revive "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/sakila" ) diff --git a/cli/output/jsonw/internal/internal_test.go b/cli/output/jsonw/internal/internal_test.go index c42fa5363..d03979022 100644 --- a/cli/output/jsonw/internal/internal_test.go +++ b/cli/output/jsonw/internal/internal_test.go @@ -11,7 +11,7 @@ import ( "github.com/neilotoole/sq/cli/output" "github.com/neilotoole/sq/cli/output/jsonw/internal" - jcolorenc "github.com/neilotoole/sq/cli/output/jsonw/internal/jcolorenc" + jcolorenc "github.com/neilotoole/sq/cli/output/jsonw/internal/jcolorenc" //nolint:revive ) // Encoder encapsulates the methods of a JSON encoder. diff --git a/cli/output/jsonw/jsonw.go b/cli/output/jsonw/jsonw.go index 3bb1b073d..59e81154a 100644 --- a/cli/output/jsonw/jsonw.go +++ b/cli/output/jsonw/jsonw.go @@ -6,7 +6,7 @@ import ( "github.com/neilotoole/sq/cli/output" "github.com/neilotoole/sq/cli/output/jsonw/internal" - jcolorenc "github.com/neilotoole/sq/cli/output/jsonw/internal/jcolorenc" + jcolorenc "github.com/neilotoole/sq/cli/output/jsonw/internal/jcolorenc" //nolint:revive "github.com/neilotoole/sq/libsq/core/errz" ) diff --git a/cli/output/jsonw/pingwriter.go b/cli/output/jsonw/pingwriter.go index c07be28da..1da69068e 100644 --- a/cli/output/jsonw/pingwriter.go +++ b/cli/output/jsonw/pingwriter.go @@ -6,7 +6,7 @@ import ( "github.com/neilotoole/sq/cli/output" "github.com/neilotoole/sq/cli/output/jsonw/internal" - jcolorenc "github.com/neilotoole/sq/cli/output/jsonw/internal/jcolorenc" + jcolorenc "github.com/neilotoole/sq/cli/output/jsonw/internal/jcolorenc" //nolint:revive "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/source" ) diff --git a/cli/run.go b/cli/run.go index d5ce9eee0..ce2b04af9 100644 --- a/cli/run.go +++ b/cli/run.go @@ -6,12 +6,15 @@ import ( "log/slog" "os" "path/filepath" + "time" 
+ + "github.com/neilotoole/sq/libsq/core/progress" "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/config" "github.com/neilotoole/sq/cli/config/yamlstore" - v0_34_0 "github.com/neilotoole/sq/cli/config/yamlstore/upgrades/v0.34.0" + v0_34_0 "github.com/neilotoole/sq/cli/config/yamlstore/upgrades/v0.34.0" //nolint:revive "github.com/neilotoole/sq/cli/flag" "github.com/neilotoole/sq/cli/run" "github.com/neilotoole/sq/drivers/csv" @@ -159,7 +162,7 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { ru.DriverRegistry = driver.NewRegistry(log) dr := ru.DriverRegistry - ru.Grips = driver.NewGrips(log, dr, ru.Files, scratchSrcFunc) + ru.Grips = driver.NewGrips(dr, ru.Files, scratchSrcFunc) ru.Cleanup.AddC(ru.Grips) dr.AddProvider(sqlite3.Type, &sqlite3.Provider{Log: log}) @@ -276,3 +279,61 @@ func preRun(cmd *cobra.Command, ru *run.Run) error { return FinishRunInit(ctx, ru) } + +// lockReloadConfig acquires the lock for the config store, and updates the +// run (as found on cmd's context) with a fresh copy of the config, loaded +// after lock acquisition. +// +// if unlock, err := lockReloadConfig(cmd); err != nil { +// return err +// } else { +// defer unlock() +// } +// +// The config lock should be acquired before making any changes to config. +// Timeout and progress options from ctx are honored. +// The caller is responsible for invoking the returned unlock func. 
+func lockReloadConfig(cmd *cobra.Command) (unlock func(), err error) { + ctx := cmd.Context() + ru := run.FromContext(ctx) + if ru.ConfigStore == nil { + return nil, errz.New("config store is nil") + } + + lock, err := ru.ConfigStore.Lockfile() + if err != nil { + return nil, errz.Wrap(err, "failed to get config lock") + } + + lockTimeout := config.OptConfigLockTimeout.Get(options.FromContext(ctx)) + bar := progress.FromContext(ctx).NewTimeoutWaiter( + "Acquire config lock", + time.Now().Add(lockTimeout), + ) + + err = lock.Lock(ctx, lockTimeout) + bar.Stop() + if err != nil { + return nil, errz.Wrap(err, "acquire config lock") + } + + var cfg *config.Config + if cfg, err = ru.ConfigStore.Load(ctx); err != nil { + // An error occurred reloading config; release the lock before returning. + if unlockErr := lock.Unlock(); unlockErr != nil { + lg.FromContext(ctx).Warn("Failed to release config lock", + lga.Lock, lock, lga.Err, unlockErr) + } + return nil, err + } + + // Update the run with the fresh config. 
+ ru.Config = cfg + + return func() { + if unlockErr := lock.Unlock(); unlockErr != nil { + lg.FromContext(ctx).Warn("Failed to release config lock", + lga.Lock, lock, lga.Err, unlockErr) + } + }, nil +} diff --git a/drivers/csv/csv_test.go b/drivers/csv/csv_test.go index b16d5e8be..b486e727b 100644 --- a/drivers/csv/csv_test.go +++ b/drivers/csv/csv_test.go @@ -90,10 +90,10 @@ func TestSakila_query(t *testing.T) { }, } - for _, driver := range []drivertype.Type{csv.TypeCSV, csv.TypeTSV} { - driver := driver + for _, drvr := range []drivertype.Type{csv.TypeCSV, csv.TypeTSV} { + drvr := drvr - t.Run(driver.String(), func(t *testing.T) { + t.Run(drvr.String(), func(t *testing.T) { t.Parallel() for _, tc := range testCases { @@ -105,8 +105,8 @@ func TestSakila_query(t *testing.T) { th := testh.New(t, testh.OptLongOpen()) src := th.Add(&source.Source{ Handle: "@" + tc.file, - Type: driver, - Location: filepath.Join("testdata", "sakila-"+driver.String(), tc.file+"."+driver.String()), + Type: drvr, + Location: filepath.Join("testdata", "sakila-"+drvr.String(), tc.file+"."+drvr.String()), }) sink, err := th.QuerySLQ(src.Handle+".data", nil) diff --git a/drivers/sqlite3/metadata.go b/drivers/sqlite3/metadata.go index cbe6ec48e..a518f91e9 100644 --- a/drivers/sqlite3/metadata.go +++ b/drivers/sqlite3/metadata.go @@ -32,8 +32,8 @@ func recordMetaFromColumnTypes(ctx context.Context, colTypes []*sql.ColumnType, // happens for functions such as COUNT(*). dbTypeName := colType.DatabaseTypeName() - kind := kindFromDBTypeName(ctx, colType.Name(), dbTypeName, colType.ScanType()) - colTypeData := record.NewColumnTypeData(colType, kind) + knd := kindFromDBTypeName(ctx, colType.Name(), dbTypeName, colType.ScanType()) + colTypeData := record.NewColumnTypeData(colType, knd) // It's necessary to explicitly set the scan type because // the backing driver doesn't set it for whatever reason. 
diff --git a/drivers/sqlserver/sqlserver.go b/drivers/sqlserver/sqlserver.go index 52d29f0a5..f9336c945 100644 --- a/drivers/sqlserver/sqlserver.go +++ b/drivers/sqlserver/sqlserver.go @@ -316,9 +316,9 @@ func (d *driveri) RecordMeta(ctx context.Context, colTypes []*sql.ColumnType) ( sColTypeData := make([]*record.ColumnTypeData, len(colTypes)) ogColNames := make([]string, len(colTypes)) for i, colType := range colTypes { - kind := kindFromDBTypeName(d.log, colType.Name(), colType.DatabaseTypeName()) - colTypeData := record.NewColumnTypeData(colType, kind) - setScanType(colTypeData, kind) + knd := kindFromDBTypeName(d.log, colType.Name(), colType.DatabaseTypeName()) + colTypeData := record.NewColumnTypeData(colType, knd) + setScanType(colTypeData, knd) sColTypeData[i] = colTypeData ogColNames[i] = colTypeData.Name } diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index f4ca914c5..d8e195230 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -134,9 +134,7 @@ func ingestXLSX(ctx context.Context, src *source.Source, destGrip driver.Grip, x var ingestCount, skipped int for i := range sheetTbls { - if progress.DebugDelay > 0 { - time.Sleep(progress.DebugDelay) - } + progress.DebugDelay() if sheetTbls[i] == nil { // tblDef can be nil if its sheet is empty (has no data). 
@@ -212,6 +210,7 @@ func ingestSheetToTable(ctx context.Context, destGrip driver.Grip, sheetTbl *she var cells []string i := -1 +LOOP: for iter.Next() { i++ if hasHeader && i == 0 { @@ -245,7 +244,7 @@ func ingestSheetToTable(ctx context.Context, destGrip driver.Grip, sheetTbl *she } // The batch inserter successfully completed - break + break LOOP case bi.RecordCh <- rec: } } diff --git a/drivers/xlsx/xlsx.go b/drivers/xlsx/xlsx.go index fef18e316..95f427977 100644 --- a/drivers/xlsx/xlsx.go +++ b/drivers/xlsx/xlsx.go @@ -85,10 +85,7 @@ func (d *Driver) Open(ctx context.Context, src *source.Source) (driver.Grip, err defer lg.WarnIfCloseError(log, lgm.CloseFileReader, xfile) - if err = ingestXLSX(ctx, p.src, destGrip, xfile); err != nil { - return err - } - return nil + return ingestXLSX(ctx, p.src, destGrip, xfile) } var err error diff --git a/libsq/core/errz/multi.go b/libsq/core/errz/multi.go index b7a45819e..db128daa6 100644 --- a/libsq/core/errz/multi.go +++ b/libsq/core/errz/multi.go @@ -335,9 +335,9 @@ type inspectResult struct { // Inspects the given slice of errors so that we can efficiently allocate // space for it. -func inspect(errors []error) (res inspectResult) { +func inspect(errs []error) (res inspectResult) { first := true - for i, err := range errors { + for i, err := range errs { if err == nil { continue } @@ -359,35 +359,35 @@ func inspect(errors []error) (res inspectResult) { } // fromSlice converts the given list of errors into a single error. -func fromSlice(errors []error) error { +func fromSlice(errs []error) error { // Don't pay to inspect small slices. - switch len(errors) { + switch len(errs) { case 0: return nil case 1: - return errors[0] + return errs[0] } - res := inspect(errors) + res := inspect(errs) switch res.Count { case 0: return nil case 1: // only one non-nil entry - return errors[res.FirstErrorIdx] - case len(errors): + return errs[res.FirstErrorIdx] + case len(errs): if !res.ContainsMultiError { // Error list is flat. 
Make a copy of it // Otherwise "errors" escapes to the heap // unconditionally for all other cases. // This lets us optimize for the "no errors" case. - out := append(([]error)(nil), errors...) + out := append(([]error)(nil), errs...) return &multiErr{errors: out, stack: callers(1)} } } nonNilErrs := make([]error, 0, res.Capacity) - for _, err := range errors[res.FirstErrorIdx:] { + for _, err := range errs[res.FirstErrorIdx:] { if err == nil { continue } @@ -433,8 +433,8 @@ func fromSlice(errors []error) error { // formatted with %+v. // // fmt.Sprintf("%+v", multierr.Combine(err1, err2)) -func Combine(errors ...error) error { - return fromSlice(errors) +func Combine(errs ...error) error { + return fromSlice(errs) } // Append appends the given errors together. Either value may be nil. @@ -489,8 +489,8 @@ func Append(left, right error) error { // Either right or both, left and right, are multiErrors. Rely on usual // expensive logic. - errors := [2]error{left, right} - return fromSlice(errors[0:]) + errs := [2]error{left, right} + return fromSlice(errs[0:]) } // Unwrap returns a list of errors wrapped by this multierr. 
diff --git a/libsq/core/errz/stack.go b/libsq/core/errz/stack.go index 343c830af..42db6b478 100644 --- a/libsq/core/errz/stack.go +++ b/libsq/core/errz/stack.go @@ -187,9 +187,9 @@ func (s *stack) Format(st fmt.State, verb rune) { if s == nil { fmt.Fprint(st, "") } - switch verb { //nolint:gocritic + switch verb { //nolint:gocritic,revive case 'v': - switch { //nolint:gocritic + switch { //nolint:gocritic,revive case st.Flag('+'): for _, pc := range *s { f := Frame(pc) diff --git a/libsq/core/ioz/download/cache.go b/libsq/core/ioz/download/cache.go index 396d4c1c2..f7d7a38f4 100644 --- a/libsq/core/ioz/download/cache.go +++ b/libsq/core/ioz/download/cache.go @@ -9,7 +9,6 @@ import ( "net/http/httputil" "os" "path/filepath" - "sync" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" @@ -24,6 +23,7 @@ import ( const ( msgCloseCacheHeaderFile = "Close cached response header file" msgCloseCacheBodyFile = "Close cached response body file" + msgDeleteCache = "Delete HTTP response cache" ) // cache is a cache for an individual download. The cached response is @@ -31,16 +31,14 @@ const ( // a checksum (of the body file) stored in a third file. // Use cache.paths to get the cache file paths. type cache struct { - // FIXME: move the mutex to the Download struct? - mu sync.Mutex - // dir is the directory in which the cache files are stored. + // It is specific to a particular download. dir string } // paths returns the paths to the header, body, and checksum files for req. // It is not guaranteed that they exist. -func (c *cache) paths(req *http.Request) (header, body, checksum string) { +func (c *cache) paths(req *http.Request) (header, body, sum string) { if req == nil || req.Method == http.MethodGet { return filepath.Join(c.dir, "header"), filepath.Join(c.dir, "body"), @@ -59,9 +57,6 @@ func (c *cache) paths(req *http.Request) (header, body, checksum string) { // If it's inconsistent, it will be automatically cleared. 
// See also: clearIfInconsistent. func (c *cache) exists(req *http.Request) bool { - c.mu.Lock() - defer c.mu.Unlock() - if err := c.clearIfInconsistent(req); err != nil { lg.FromContext(req.Context()).Error("Failed to clear inconsistent cache", lga.Err, err, lga.Dir, c.dir) @@ -118,7 +113,7 @@ func (c *cache) clearIfInconsistent(req *http.Request) error { if inconsistent { lg.FromContext(req.Context()).Warn("Deleting inconsistent cache", lga.Dir, c.dir) - return c.doClear(req.Context()) + return c.clear(req.Context()) } return nil } @@ -126,8 +121,6 @@ func (c *cache) clearIfInconsistent(req *http.Request) error { // Get returns the cached http.Response for req if present, and nil // otherwise. The caller MUST close the returned response body. func (c *cache) get(ctx context.Context, req *http.Request) (*http.Response, error) { - c.mu.Lock() - defer c.mu.Unlock() log := lg.FromContext(ctx) fpHeader, fpBody, _ := c.paths(req) @@ -179,10 +172,6 @@ func (c *cache) get(ctx context.Context, req *http.Request) (*http.Response, err // checksum returns the contents of the cached checksum file, if available. func (c *cache) cachedChecksum(req *http.Request) (sum checksum.Checksum, ok bool) { - if c == nil || req == nil { - return "", false - } - _, _, fp := c.paths(req) if !ioz.FileAccessible(fp) { return "", false @@ -229,16 +218,6 @@ func (c *cache) checksumsMatch(req *http.Request) (sum checksum.Checksum, ok boo // clear deletes the cache entries from disk. 
func (c *cache) clear(ctx context.Context) error { - if c == nil { - return nil - } - c.mu.Lock() - defer c.mu.Unlock() - - return c.doClear(ctx) -} - -func (c *cache) doClear(ctx context.Context) error { deleteErr := errz.Wrap(os.RemoveAll(c.dir), "delete cache dir") recreateErr := ioz.RequireDir(c.dir) err := errz.Append(deleteErr, recreateErr) @@ -252,27 +231,22 @@ func (c *cache) doClear(ctx context.Context) error { return nil } -const msgDeleteCache = "Delete HTTP response cache" - -// write writes resp header and body to the cache. If headerOnly is true, only -// the header cache file is updated. If headerOnly is false and copyWrtr is -// non-nil, the response body bytes are copied to that destination, as well as -// being written to the cache. If writing to copyWrtr completes successfully, -// it is closed; if there's an error, copyWrtr.Error is invoked. -// A checksum file, computed from the body file, is also written to disk. The -// response body is always closed. +// write writes resp header and body to the cache, returning the number of +// bytes written. +// +// If headerOnly is true, only the header cache file is updated. If headerOnly +// is false and copyWrtr is non-nil, the response body bytes are copied to that +// destination, as well as being written to the cache. +// +// If writing to copyWrtr completes successfully, it is closed; if there's an error, +// copyWrtr.Error is invoked. +// +// A checksum file, computed from the body file, is also written to disk. +// +// The response body is always closed. 
func (c *cache) write(ctx context.Context, resp *http.Response, headerOnly bool, copyWrtr ioz.WriteErrorCloser, -) error { - c.mu.Lock() - defer c.mu.Unlock() - - return c.doWrite(ctx, resp, headerOnly, copyWrtr) -} - -func (c *cache) doWrite(ctx context.Context, resp *http.Response, - headerOnly bool, copyWrtr ioz.WriteErrorCloser, -) (err error) { +) (written int64, err error) { log := lg.FromContext(ctx) defer func() { @@ -283,12 +257,12 @@ func (c *cache) doWrite(ctx context.Context, resp *http.Response, if err != nil { log.Warn("Deleting cache because cache write failed", lga.Err, err, lga.Dir, c.dir) - lg.WarnIfError(log, msgDeleteCache, c.doClear(ctx)) + lg.WarnIfError(log, msgDeleteCache, c.clear(ctx)) } }() if err = ioz.RequireDir(c.dir); err != nil { - return err + return 0, err } log.Debug("Writing HTTP response to cache", lga.Dir, c.dir, lga.Resp, httpz.ResponseLogValue(resp)) @@ -296,20 +270,20 @@ func (c *cache) doWrite(ctx context.Context, resp *http.Response, headerBytes, err := httputil.DumpResponse(resp, false) if err != nil { - return errz.Err(err) + return 0, errz.Err(err) } if _, err = ioz.WriteToFile(ctx, fpHeader, bytes.NewReader(headerBytes)); err != nil { - return err + return written, err } if headerOnly { - return nil + return 0, nil } cacheFile, err := os.OpenFile(fpBody, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) if err != nil { - return err + return 0, err } var r io.Reader = resp.Body @@ -319,35 +293,34 @@ func (c *cache) doWrite(ctx context.Context, resp *http.Response, r = io.TeeReader(resp.Body, copyWrtr) } - var written int64 if written, err = errz.Return(io.Copy(cacheFile, r)); err != nil { log.Error("Cache write: io.Copy failed", lga.Err, err) lg.WarnIfCloseError(log, msgCloseCacheBodyFile, cacheFile) cacheFile = nil - return err + return 0, err } if err = errz.Err(cacheFile.Close()); err != nil { cacheFile = nil - return err + return 0, err } if copyWrtr != nil { if err = errz.Err(copyWrtr.Close()); err != nil { copyWrtr 
= nil - return err + return 0, err } } sum, err := checksum.ForFile(fpBody) if err != nil { - return errz.Wrap(err, "failed to compute checksum for cache body file") + return 0, errz.Wrap(err, "failed to compute checksum for cache body file") } if err = checksum.WriteFile(filepath.Join(c.dir, "checksums.txt"), sum, "body"); err != nil { - return errz.Wrap(err, "failed to write checksum file for cache body") + return 0, errz.Wrap(err, "failed to write checksum file for cache body") } log.Info("Wrote HTTP response body to cache", lga.Size, written, lga.File, fpBody) - return nil + return written, nil } diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index a20cedf80..91ceb2773 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -1,9 +1,6 @@ // Package download provides a mechanism for getting files from // HTTP URLs, making use of a mostly RFC-compliant cache. // -// FIXME: move download to internal/download, because its use -// is so specialized? -// // Acknowledgement: This package is a heavily customized fork // of https://github.com/gregjones/httpcache, via bitcomplete/download. package download @@ -17,6 +14,7 @@ import ( "net/url" "os" "path/filepath" + "sync" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" @@ -80,7 +78,7 @@ func OptDisableCaching(disable bool) Opt { // Download encapsulates downloading a file from a URL, using a local // disk cache if possible. type Download struct { - // FIXME: Does Download need a sync.Mutex? + mu sync.Mutex // name is a user-friendly name, such as a source handle like @data. name string @@ -91,13 +89,20 @@ type Download struct { c *http.Client - cache *cache + // disableCaching, if true, indicates that the cache should not be used. 
+ disableCaching bool + cache *cache // markCachedResponses, if true, indicates that responses returned from the // cache will be given an extra header, X-From-cache. markCachedResponses bool - disableCaching bool + // bodySize is the size of the downloaded file. It is set after + // the download has completed. A value of -1 indicates that it + // has not been set. The Download.Filesize method consults this value, + // but if not set (e.g. Download.Get) has not been invoked, + // Download.Filesize may use the size of the cached file on disk. + bodySize int64 } // New returns a new Download for url that writes to cacheDir. @@ -113,26 +118,30 @@ func New(name string, c *http.Client, dlURL, cacheDir string, opts ...Opt) (*Dow return nil, errz.Err(err) } - t := &Download{ + dl := &Download{ name: name, c: c, url: dlURL, markCachedResponses: true, disableCaching: false, + bodySize: -1, } for _, opt := range opts { - opt(t) + opt(dl) } - if !t.disableCaching { - t.cache = &cache{dir: cacheDir} + if !dl.disableCaching { + dl.cache = &cache{dir: cacheDir} } - return t, nil + return dl, nil } // Get gets the download, invoking Handler as appropriate. 
func (dl *Download) Get(ctx context.Context, h Handler) { + dl.mu.Lock() + defer dl.mu.Unlock() + req := dl.mustRequest(ctx) lg.FromContext(ctx).Debug("Get download", lga.URL, dl.url) dl.get(req, h) @@ -201,7 +210,8 @@ func (dl *Download) get(req *http.Request, h Handler) { //nolint:funlen,gocognit } resp, err = dl.do(req) //nolint:bodyclose - if err == nil && req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified { + switch { + case err == nil && req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified: // Replace the 304 response with the one from cache, but update with some new headers endToEndHeaders := getEndToEndHeaders(resp.Header) for _, header := range endToEndHeaders { @@ -209,17 +219,17 @@ func (dl *Download) get(req *http.Request, h Handler) { //nolint:funlen,gocognit } lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) resp = cachedResp - } else if (err != nil || - (cachedResp != nil && - resp.StatusCode >= 500)) && + + case (err != nil || (resp.StatusCode >= 500)) && req.Method == http.MethodGet && - canStaleOnError(cachedResp.Header, req.Header) { + canStaleOnError(cachedResp.Header, req.Header): // In case of transport failure and stale-if-error activated, returns cached content // when available log.Warn("Returning cached response due to transport failure", lga.Err, err) h.Cached(fpBody) return - } else { + + default: if err != nil || resp.StatusCode != http.StatusOK { lg.WarnIfError(log, msgDeleteCache, dl.cache.clear(req.Context())) } @@ -231,7 +241,7 @@ func (dl *Download) get(req *http.Request, h Handler) { //nolint:funlen,gocognit } else { reqCacheControl := parseCacheControl(req.Header) if _, ok := reqCacheControl["only-if-cached"]; ok { - resp = newGatewayTimeoutResponse(req) //nolint:bodyclose + resp = newGatewayTimeoutResponse(req) } else { resp, err = dl.do(req) //nolint:bodyclose if err != nil { @@ -253,7 +263,7 @@ func (dl *Download) get(req *http.Request, h Handler) { 
//nolint:funlen,gocognit if resp == cachedResp { lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) - if err = dl.cache.write(ctx, resp, true, nil); err != nil { + if dl.bodySize, err = dl.cache.write(ctx, resp, true, nil); err != nil { log.Error("Failed to update cache header", lga.Dir, dl.cache.dir, lga.Err, err) h.Error(err) return @@ -270,7 +280,7 @@ func (dl *Download) get(req *http.Request, h Handler) { //nolint:funlen,gocognit } defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, resp.Body) - if err = dl.cache.write(req.Context(), resp, false, destWrtr); err != nil { + if dl.bodySize, err = dl.cache.write(req.Context(), resp, false, destWrtr); err != nil { log.Error("Failed to write download cache", lga.Dir, dl.cache.dir, lga.Err, err) // We don't need to explicitly call Handler.Error here, because the caller is // informed via destWrtr.Error, which has already been invoked by cache.write. @@ -291,7 +301,7 @@ func (dl *Download) get(req *http.Request, h Handler) { //nolint:funlen,gocognit cr := contextio.NewReader(ctx, resp.Body) defer lg.WarnIfCloseError(log, lgm.CloseHTTPResponseBody, cr.(io.ReadCloser)) - _, err = io.Copy(destWrtr, cr) + dl.bodySize, err = io.Copy(destWrtr, cr) if err != nil { log.Error("Failed to copy download to dest writer", lga.Err, err) destWrtr.Error(err) @@ -344,14 +354,19 @@ func (dl *Download) mustRequest(ctx context.Context) *http.Request { // Clear deletes the cache. func (dl *Download) Clear(ctx context.Context) error { - if dl.cache != nil { - return dl.cache.clear(ctx) + dl.mu.Lock() + defer dl.mu.Unlock() + if dl.cache == nil { + return nil } - return nil + + return dl.cache.clear(ctx) } // State returns the Download's cache state. 
func (dl *Download) State(ctx context.Context) State { + dl.mu.Lock() + defer dl.mu.Unlock() return dl.state(dl.mustRequest(ctx)) } @@ -385,28 +400,63 @@ func (dl *Download) state(req *http.Request) State { return getFreshness(cachedResp.Header, req.Header) } -// CacheFile returns the path to the cached file and its size, if it exists -// and has been fully downloaded. -func (dl *Download) CacheFile(ctx context.Context) (fp string, size int64, err error) { +// Filesize returns the size of the downloaded file. This should +// be invoked after the download has completed. +func (dl *Download) Filesize(ctx context.Context) (int64, error) { + dl.mu.Lock() + defer dl.mu.Unlock() + if dl.cache == nil { - return "", 0, errz.Errorf("cache doesn't exist for: %s", dl.url) + // There's no cache, so we can only get the value via + // the bodySize field. + if dl.bodySize < 0 { + return 0, errz.New("download file size not available") + } + return dl.bodySize, nil } req := dl.mustRequest(ctx) if !dl.cache.exists(req) { - return "", 0, errz.Errorf("no cache for: %s", dl.url) + // It's not in the cache. + if dl.bodySize < 0 { + return 0, errz.New("download file size not available") + } + return dl.bodySize, nil } - _, fp, _ = dl.cache.paths(req) + // It's in the cache + _, fp, _ := dl.cache.paths(req) fi, err := os.Stat(fp) if err != nil { - return "", 0, errz.Err(err) + return 0, errz.Wrapf(err, "unable to stat cached download file: %s", fp) } - return fp, fi.Size(), nil + + return fi.Size(), nil +} + +// CacheFile returns the path to the cached file, if it exists, +// and has been fully downloaded. 
+func (dl *Download) CacheFile(ctx context.Context) (fp string, err error) { + dl.mu.Lock() + defer dl.mu.Unlock() + + if dl.cache == nil { + return "", errz.Errorf("cache doesn't exist for: %s", dl.url) + } + + req := dl.mustRequest(ctx) + if !dl.cache.exists(req) { + return "", errz.Errorf("no cache for: %s", dl.url) + } + _, fp, _ = dl.cache.paths(req) + return fp, nil } // Checksum returns the checksum of the cached download, if available. func (dl *Download) Checksum(ctx context.Context) (sum checksum.Checksum, ok bool) { + dl.mu.Lock() + defer dl.mu.Unlock() + if dl.cache == nil { return "", false } @@ -416,7 +466,7 @@ func (dl *Download) Checksum(ctx context.Context) (sum checksum.Checksum, ok boo } func (dl *Download) isCacheable(req *http.Request) bool { - if dl.disableCaching { + if dl.cache == nil || dl.disableCaching { return false } return (req.Method == http.MethodGet || req.Method == http.MethodHead) && req.Header.Get("range") == "" diff --git a/libsq/core/ioz/download/http.go b/libsq/core/ioz/download/http.go index 98da338fa..0e00649d3 100644 --- a/libsq/core/ioz/download/http.go +++ b/libsq/core/ioz/download/http.go @@ -137,24 +137,22 @@ func canStaleOnError(respHeaders, reqHeaders http.Header) bool { lifetime := time.Duration(-1) if staleMaxAge, ok := respCacheControl["stale-if-error"]; ok { - if staleMaxAge != "" { - lifetime, err = time.ParseDuration(staleMaxAge + "s") - if err != nil { - return false - } - } else { + if staleMaxAge == "" { return true } + lifetime, err = time.ParseDuration(staleMaxAge + "s") + if err != nil { + return false + } } if staleMaxAge, ok := reqCacheControl["stale-if-error"]; ok { - if staleMaxAge != "" { - lifetime, err = time.ParseDuration(staleMaxAge + "s") - if err != nil { - return false - } - } else { + if staleMaxAge == "" { return true } + lifetime, err = time.ParseDuration(staleMaxAge + "s") + if err != nil { + return false + } } if lifetime >= 0 { diff --git a/libsq/core/ioz/lockfile/lockfile.go 
b/libsq/core/ioz/lockfile/lockfile.go index 9997a9cc4..6d7fc9e50 100644 --- a/libsq/core/ioz/lockfile/lockfile.go +++ b/libsq/core/ioz/lockfile/lockfile.go @@ -20,7 +20,7 @@ import ( type Lockfile string // New returns a new Lockfile instance. Arg fp must be -// an absolute path (but the path may not exist). +// an absolute path (but it's legal for the path to not exist). func New(fp string) (Lockfile, error) { lf, err := lockfile.New(fp) if err != nil { @@ -38,12 +38,12 @@ func (l Lockfile) Lock(ctx context.Context, timeout time.Duration) error { dir := filepath.Dir(string(l)) if err := ioz.RequireDir(dir); err != nil { - return errz.Wrapf(err, "failed create parent dir of cache lock: %s", string(l)) + return errz.Wrapf(err, "failed to create parent dir of lockfile: %s", string(l)) } if timeout == 0 { if err := lockfile.Lockfile(l).TryLock(); err != nil { - log.Warn("Failed to acquire pid lock", lga.Err, err) + log.Warn("Failed to acquire pid lock", lga.Path, string(l), lga.Err, err) return errz.Wrapf(err, "failed to acquire pid lock: %s", l) } log.Debug("Acquired pid lock") @@ -61,7 +61,6 @@ func (l Lockfile) Lock(ctx context.Context, timeout time.Duration) error { return nil } - // log.Debug("Failed to acquire pid lock, may retry", lga.Attempts, attempts, lga.Err, err) return err }, errz.Has[lockfile.TemporaryError], diff --git a/libsq/core/options/opt.go b/libsq/core/options/opt.go index 0b13ed410..552ad67da 100644 --- a/libsq/core/options/opt.go +++ b/libsq/core/options/opt.go @@ -193,6 +193,8 @@ var _ Opt = String{} // NewString returns an options.String instance. If flag is empty, the // value of key is used. If valid Fn is non-nil, it is called from // the process function. +// +//nolint:revive func NewString(key, flag string, short rune, defaultVal string, validFn func(string) error, usage, help string, tags ...string, ) String { @@ -275,7 +277,7 @@ var _ Opt = Int{} // NewInt returns an options.Int instance. If flag is empty, the // value of key is used. 
-func NewInt(key, flag string, short rune, defaultVal int, usage, help string, tags ...string) Int { +func NewInt(key, flag string, short rune, defaultVal int, usage, help string, tags ...string) Int { //nolint:revive return Int{ BaseOpt: NewBaseOpt(key, flag, short, usage, help, tags...), defaultVal: defaultVal, @@ -405,7 +407,7 @@ var _ Opt = Bool{} // of key is used. If invertFlag is true, the flag's boolean value // is inverted to set the option. For example, if the Opt is "progress", // and the flag is "--no-progress", then invertFlag should be true. -func NewBool(key, flag string, invertFlag bool, short rune, +func NewBool(key, flag string, invertFlag bool, short rune, //nolint:revive defaultVal bool, usage, help string, tags ...string, ) Bool { return Bool{ @@ -519,7 +521,9 @@ var _ Opt = Duration{} // NewDuration returns an options.Duration instance. If flag is empty, the // value of key is used. -func NewDuration(key, flag string, short rune, defaultVal time.Duration, usage, help string, tags ...string) Duration { +func NewDuration(key, flag string, short rune, defaultVal time.Duration, //nolint:revive + usage, help string, tags ...string, +) Duration { return Duration{ BaseOpt: NewBaseOpt(key, flag, short, usage, help, tags...), defaultVal: defaultVal, diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 15d58a4e9..b22a655f9 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -29,12 +29,16 @@ import ( "github.com/neilotoole/sq/libsq/core/stringz" ) -// DebugDelay is a duration that parts of the codebase sleep for to -// facilitate testing the progress impl. It should be removed before -// release. +// DebugDelay sleeps for a period of time to facilitate testing the +// progress impl. It should be removed before release. // // Deprecated: This is a temporary hack for testing. 
-const DebugDelay = time.Millisecond * 0 +func DebugDelay() { + const delay = time.Millisecond * 0 + if delay > 0 { + time.Sleep(delay) + } +} type ctxKey struct{} diff --git a/libsq/core/sqlz/nullbool_test.go b/libsq/core/sqlz/nullbool_test.go index 1c6215f33..1724e4310 100644 --- a/libsq/core/sqlz/nullbool_test.go +++ b/libsq/core/sqlz/nullbool_test.go @@ -43,7 +43,7 @@ func TestNullBool_Scan(t *testing.T) { err := nb.Scan(tt.input) if err != nil { - if tt.expectValid == false { + if !tt.expectValid { continue } t.Errorf("[%d] {%s}: did not expect error: %v", i, tt.input, err) diff --git a/libsq/core/stringz/stringz.go b/libsq/core/stringz/stringz.go index 19954cd2c..f960e5023 100644 --- a/libsq/core/stringz/stringz.go +++ b/libsq/core/stringz/stringz.go @@ -246,7 +246,7 @@ func stringWithCharset(length int, charset string) string { b := make([]byte, length) for i := range b { - b[i] = charset[rand.Intn(len(charset))] //#nosec G404 // Doesn't need to be strongly random + b[i] = charset[rand.Intn(len(charset))] //nolint:gosec } return string(b) diff --git a/libsq/driver/driver.go b/libsq/driver/driver.go index fc6951ba5..a696ed0a9 100644 --- a/libsq/driver/driver.go +++ b/libsq/driver/driver.go @@ -38,9 +38,9 @@ type Driver interface { // ValidateSource verifies that the source is valid for this driver. It // may transform the source into a canonical form, which is returned in - // the "src" return value (the original source is not changed). An error + // the return value (the original source is not changed). An error // is returned if the source is invalid. - ValidateSource(source *source.Source) (src *source.Source, err error) + ValidateSource(src *source.Source) (*source.Source, error) // Ping verifies that the source is reachable, or returns an error if not. // The exact behavior of Ping() is driver-dependent. 
diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index c8806981d..a11355d43 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -2,11 +2,12 @@ package driver import ( "context" - "log/slog" "strings" "sync" "time" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" @@ -14,7 +15,6 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" - "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" ) @@ -29,7 +29,6 @@ type ScratchSrcFunc func(ctx context.Context, name string) (src *source.Source, // Note that at this time instances returned by Open are cached // and then closed by Close. This may be a bad approach. type Grips struct { - log *slog.Logger drvrs Provider mu sync.Mutex scratchSrcFn ScratchSrcFunc @@ -39,11 +38,8 @@ type Grips struct { } // NewGrips returns a Grips instances. -func NewGrips(log *slog.Logger, drvrs Provider, - files *source.Files, scratchSrcFn ScratchSrcFunc, -) *Grips { +func NewGrips(drvrs Provider, files *source.Files, scratchSrcFn ScratchSrcFunc) *Grips { return &Grips{ - log: log, drvrs: drvrs, mu: sync.Mutex{}, scratchSrcFn: scratchSrcFn, @@ -131,6 +127,7 @@ func (gs *Grips) doOpen(ctx context.Context, src *source.Source) (Grip, error) { // its Close method will be invoked by Grips.Close. 
func (gs *Grips) OpenScratch(ctx context.Context, src *source.Source) (Grip, error) { const msgCloseScratch = "Close scratch db" + log := lg.FromContext(ctx) cacheDir, srcCacheDBFilepath, _, err := gs.files.CachePaths(src) if err != nil { @@ -146,18 +143,18 @@ func (gs *Grips) OpenScratch(ctx context.Context, src *source.Source) (Grip, err // if err is non-nil, cleanup is guaranteed to be nil return nil, err } - gs.log.Debug("Opening scratch src", lga.Src, scratchSrc) + log.Debug("Opening scratch src", lga.Src, scratchSrc) backingDrvr, err := gs.drvrs.DriverFor(scratchSrc.Type) if err != nil { - lg.WarnIfFuncError(gs.log, msgCloseScratch, cleanFn) + lg.WarnIfFuncError(log, msgCloseScratch, cleanFn) return nil, err } var backingGrip Grip backingGrip, err = backingDrvr.Open(ctx, scratchSrc) if err != nil { - lg.WarnIfFuncError(gs.log, msgCloseScratch, cleanFn) + lg.WarnIfFuncError(log, msgCloseScratch, cleanFn) return nil, err } @@ -171,25 +168,48 @@ func (gs *Grips) OpenScratch(ctx context.Context, src *source.Source) (Grip, err return backingGrip, nil } -// OpenIngest implements driver.GripOpenIngester. +// OpenIngest implements driver.GripOpenIngester. It opens a Grip, ingesting +// the source into the Grip. If allowCache is true, the ingest cache DB +// is used if possible. If allowCache is false, any existing ingest cache DB +// is not utilized, and is overwritten by the ingestion process. func (gs *Grips) OpenIngest(ctx context.Context, src *source.Source, allowCache bool, ingestFn func(ctx context.Context, dest Grip) error, ) (Grip, error) { - var grip Grip - var err error - - if !allowCache || src.Handle == source.StdinHandle { - // We don't currently cache stdin. Probably we never will? - grip, err = gs.openIngestNoCache(ctx, src, ingestFn) - } else { - grip, err = gs.openIngestCache(ctx, src, ingestFn) + // Get the cache lock for src, no matter if we're making + // use of the ingest cache DB or not. 
We do this to prevent + // another process from overwriting the cache DB while it's + // being written to. + lock, err := gs.files.CacheLockFor(src) + if err != nil { + return nil, err } + lockTimeout := source.OptCacheLockTimeout.Get(options.FromContext(ctx)) + bar := progress.FromContext(ctx).NewTimeoutWaiter( + src.Handle+": acquire lock", + time.Now().Add(lockTimeout), + ) + + err = lock.Lock(ctx, lockTimeout) + bar.Stop() if err != nil { - return nil, err + return nil, errz.Wrap(err, src.Handle+": acquire cache lock") } - return grip, nil + defer func() { + if err = lock.Unlock(); err != nil { + lg.FromContext(ctx).Warn("Failed to release cache lock", + lga.Lock, lock, lga.Err, err) + } + }() + + if !allowCache || src.Handle == source.StdinHandle { + // Note that we can never cache stdin, because it's a stream + // that is effectively unique each time. + return gs.openIngestNoCache(ctx, src, ingestFn) + } + + return gs.openIngestCache(ctx, src, ingestFn) } func (gs *Grips) openIngestNoCache(ctx context.Context, src *source.Source, @@ -214,7 +234,7 @@ func (gs *Grips) openIngestNoCache(ctx context.Context, src *source.Source, return nil, err } - gs.log.Info("Ingest completed", + log.Info("Ingest completed", lga.Src, src, lga.Dest, impl.Source(), lga.Elapsed, elapsed) return impl, nil @@ -226,31 +246,9 @@ func (gs *Grips) openIngestCache(ctx context.Context, src *source.Source, log := lg.FromContext(ctx).With(lga.Handle, src.Handle) ctx = lg.NewContext(ctx, log) - lock, err := gs.files.CacheLockFor(src) - if err != nil { - return nil, err - } - - lockTimeout := source.OptCacheLockTimeout.Get(options.FromContext(ctx)) - bar := progress.FromContext(ctx).NewTimeoutWaiter( - src.Handle+": acquire lock", - time.Now().Add(lockTimeout), - ) - - err = lock.Lock(ctx, lockTimeout) - bar.Stop() - if err != nil { - return nil, errz.Wrap(err, src.Handle+": acquire cache lock") - } - - defer func() { - if err = lock.Unlock(); err != nil { - log.Warn("Failed to release cache 
lock", lga.Lock, lock, lga.Err, err) - } - }() - var impl Grip var foundCached bool + var err error if impl, foundCached, err = gs.openCachedFor(ctx, src); err != nil { return nil, err } @@ -325,12 +323,11 @@ func (gs *Grips) OpenJoin(ctx context.Context, srcs ...*source.Source) (Grip, er names = append(names, src.Handle[1:]) } - gs.log.Debug("OpenJoin", "sources", strings.Join(names, ",")) + lg.FromContext(ctx).Debug("OpenJoin", "sources", strings.Join(names, ",")) return gs.OpenScratch(ctx, srcs[0]) } // Close closes d, invoking Close on any instances opened via d.Open. func (gs *Grips) Close() error { - gs.log.Debug("Closing databases(s)...", lga.Count, gs.clnup.Len()) return gs.clnup.Run() } diff --git a/libsq/driver/ingest.go b/libsq/driver/ingest.go index 8785dd644..49c426052 100644 --- a/libsq/driver/ingest.go +++ b/libsq/driver/ingest.go @@ -26,8 +26,6 @@ to detect the header.`, ) // OptIngestCache specifies whether ingested data is cached or not. -// -// REVISIT: Maybe rename ingest.cache simply to "cache"? var OptIngestCache = options.NewBool( "ingest.cache", "", @@ -35,12 +33,15 @@ var OptIngestCache = options.NewBool( 0, true, "Ingest data is cached", - `Specifies whether ingested data is cached or not.`, + `Specifies whether ingested data is cached or not. When data is ingested +from a document source, it is stored in a cache DB. Subsequent uses of that same +source will use that cached DB instead of ingesting the data again, unless this +option is set to false, in which case, the data is ingested each time.`, options.TagSource, ) // OptIngestSampleSize specifies the number of samples that a detector -// should take to determine type. +// should take to determine ingest data type. 
var OptIngestSampleSize = options.NewInt( "ingest.sample-size", "", diff --git a/libsq/driver/record.go b/libsq/driver/record.go index 2989065f5..08c511b5b 100644 --- a/libsq/driver/record.go +++ b/libsq/driver/record.go @@ -382,7 +382,7 @@ func (bi *BatchInsert) Munge(rec []any) error { // it must be a sql.Conn or sql.Tx. // //nolint:gocognit -func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, +func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, //nolint:revive destTbl string, destColNames []string, batchSize int, ) (*BatchInsert, error) { log := lg.FromContext(ctx) @@ -468,9 +468,7 @@ func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, bi.written.Add(affected) pbar.IncrBy(int(affected)) - if progress.DebugDelay > 0 { - time.Sleep(progress.DebugDelay) - } + progress.DebugDelay() if rec == nil { // recCh is closed (coincidentally exactly on the @@ -514,9 +512,7 @@ func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, bi.written.Add(affected) pbar.IncrBy(int(affected)) - if progress.DebugDelay > 0 { - time.Sleep(progress.DebugDelay) - } + progress.DebugDelay() // We're done return diff --git a/libsq/source/cache.go b/libsq/source/cache.go index 751f2c70d..269439d7f 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -24,7 +24,7 @@ import ( "github.com/neilotoole/sq/libsq/core/stringz" ) -// OptCacheLockTimeout is the time allowed to acquire cache lock. +// OptCacheLockTimeout is the time allowed to acquire a cache lock. // // See also: [driver.OptIngestCache]. var OptCacheLockTimeout = options.NewDuration( @@ -35,8 +35,6 @@ var OptCacheLockTimeout = options.NewDuration( "Wait timeout to acquire cache lock", `Wait timeout to acquire cache lock. During this period, retry will occur if the lock is already held by another process. 
If zero, no retry occurs.`, - options.TagSource, - options.TagSQL, ) // CacheDirFor gets the cache dir for handle. It is not guaranteed @@ -188,7 +186,7 @@ func (fs *Files) CachePaths(src *Source) (srcCacheDir, cacheDB, checksums string } checksums = filepath.Join(srcCacheDir, "checksums.txt") - cacheDB = filepath.Join(srcCacheDir, "cached.db") + cacheDB = filepath.Join(srcCacheDir, "cache.sqlite.db") return srcCacheDir, cacheDB, checksums, nil } diff --git a/libsq/source/download.go b/libsq/source/download.go index 99b78d365..28adeaf14 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -23,7 +23,7 @@ var OptHTTPRequestTimeout = options.NewDuration( 0, time.Second*10, "HTTP/S request initial response timeout duration", - `How long to wait for initial response from HTTP/S endpoint before + `How long to wait for initial response from a HTTP/S endpoint before timeout occurs. Reading the body of the response, such as large HTTP file downloads, is not affected by this option. Example: 500ms or 3s. Contrast with http.response.timeout.`, @@ -35,7 +35,7 @@ var OptHTTPResponseTimeout = options.NewDuration( "", 0, 0, - "HTTP/S response completion timeout duration", + "HTTP/S request completion timeout duration", `How long to wait for the entire HTTP transaction to complete. This includes reading the body of the response, such as large HTTP file downloads. Typically this is set to 0, indicating no timeout. Contrast with http.request.timeout.`, @@ -50,10 +50,16 @@ var OptHTTPSInsecureSkipVerify = options.NewBool( false, "Skip HTTPS TLS verification", "Skip HTTPS TLS verification. Useful when downloading against self-signed certs.", + options.TagSource, ) +// downloadFor returns the download.Download for src, creating +// and caching it if necessary. func (fs *Files) downloadFor(ctx context.Context, src *Source) (*download.Download, error) { - // REVISIT: should downloadFor return a cached instance of download.Download? 
+ dl, ok := fs.downloads[src.Handle] + if ok { + return dl, nil + } dlDir, err := fs.downloadDirFor(src) if err != nil { @@ -70,11 +76,10 @@ func (fs *Files) downloadFor(ctx context.Context, src *Source) (*download.Downlo httpz.OptInsecureSkipVerify(OptHTTPSInsecureSkipVerify.Get(o)), ) - dl, err := download.New(src.Handle, c, src.Location, dlDir) - if err != nil { + if dl, err = download.New(src.Handle, c, src.Location, dlDir); err != nil { return nil, err } - + fs.downloads[src.Handle] = dl return dl, nil } @@ -111,7 +116,7 @@ func (fs *Files) openRemoteFile(ctx context.Context, src *Source, checkFresh boo if !checkFresh && fs.fscache.Exists(loc) { // If the download has completed, dl.CacheFile will return the // path to the cached file. - cachedDownload, _, err = dl.CacheFile(ctx) + cachedDownload, err = dl.CacheFile(ctx) if err != nil { return "", nil, err } diff --git a/libsq/source/download_test.go b/libsq/source/download_test.go deleted file mode 100644 index 1b62a5d77..000000000 --- a/libsq/source/download_test.go +++ /dev/null @@ -1,8 +0,0 @@ -package source - -const ( - urlPaymentLargeCSV = "https://sqio-public.s3.amazonaws.com/testdata/payment-large.gen.csv" - urlActorCSV = "https://sq.io/testdata/actor.csv" - sizeActorCSV = int64(7641) - sizeGzipActorCSV = int64(1968) -) diff --git a/libsq/source/files.go b/libsq/source/files.go index fac3b4eac..698371c0d 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -45,6 +45,10 @@ type Files struct { clnup *cleanup.Cleanup optRegistry *options.Registry + // downloads is a map of source handles the download.Download + // for that source. + downloads map[string]*download.Download + // fscache is used to cache files, providing convenient access // to multiple readers via Files.newReader. 
fscache *fscache.FSCache @@ -84,6 +88,7 @@ func NewFiles(ctx context.Context, optReg *options.Registry, tempDir: tmpDir, clnup: cleanup.New(), log: lg.FromContext(ctx), + downloads: map[string]*download.Download{}, } // We want a unique dir for each execution. Note that fcache is deleted @@ -142,8 +147,7 @@ func (fs *Files) Filesize(ctx context.Context, src *Source) (size int64, err err return 0, err } - _, size, err = dl.CacheFile(ctx) - return size, err + return dl.Filesize(ctx) case locTypeSQL: return 0, errz.Errorf("invalid to get size of SQL source: %s", src.Handle) diff --git a/testh/testh.go b/testh/testh.go index 801819d93..1531b665b 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -136,7 +136,9 @@ func New(t testing.TB, opts ...Option) *Helper { // NewWith is a convenience wrapper for New, that also returns // the source.Source for handle, the driver.SQLDriver, driver.Grip, // and the *sql.DB. -func NewWith(t testing.TB, handle string) (*Helper, *source.Source, driver.SQLDriver, driver.Grip, *sql.DB) { +func NewWith(t testing.TB, handle string) (*Helper, *source.Source, //nolint:revive + driver.SQLDriver, driver.Grip, *sql.DB, +) { th := New(t) src := th.Source(handle) grip := th.Open(src) @@ -166,7 +168,7 @@ func (h *Helper) init() { assert.NoError(h.T, err) }) - h.grips = driver.NewGrips(log, h.registry, h.files, sqlite3.NewScratchSource) + h.grips = driver.NewGrips(h.registry, h.files, sqlite3.NewScratchSource) h.Cleanup.AddC(h.grips) h.registry.AddProvider(sqlite3.Type, &sqlite3.Provider{Log: log}) @@ -859,8 +861,11 @@ func DriverDetectors() []source.DriverDetectFunc { return []source.DriverDetectFunc{ source.DetectMagicNumber, xlsx.DetectXLSX, - csv.DetectCSV, csv.DetectTSV, - /*json.DetectJSON,*/ json.DetectJSONA(1000), json.DetectJSONL(1000), // FIXME: enable DetectJSON when it's ready + csv.DetectCSV, + csv.DetectTSV, + // json.DetectJSON(1000), // FIXME: enable DetectJSON when it's ready + json.DetectJSONA(1000), + json.DetectJSONL(1000), } } 
From 7d78a7516699cbb912de9a482c1ba5201d64424a Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 11 Jan 2024 08:23:55 -0700 Subject: [PATCH 153/195] linting --- cli/cli.go | 1 + cli/cmd_add.go | 3 +- cli/cmd_cache.go | 4 +- cli/cmd_config_edit.go | 2 +- cli/cmd_config_set.go | 2 +- cli/cmd_mv.go | 7 +- cli/cmd_remove.go | 1 + cli/cmd_scratch.go | 1 + cli/cmd_version.go | 5 +- cli/cmd_x.go | 117 ++++++++++++++++---------------- cli/run.go | 31 ++++++++- libsq/core/progress/progress.go | 2 +- 12 files changed, 102 insertions(+), 74 deletions(-) diff --git a/cli/cli.go b/cli/cli.go index 8222095aa..c81948780 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -254,6 +254,7 @@ func newCommandTree(ru *run.Run) (rootCmd *cobra.Command) { xCmd := addCmd(ru, rootCmd, newXCmd()) addCmd(ru, xCmd, newXLockSrcCmd()) + addCmd(ru, xCmd, newXLockConfigCmd()) addCmd(ru, xCmd, newXProgressCmd()) addCmd(ru, xCmd, newXDownloadCmd()) diff --git a/cli/cmd_add.go b/cli/cmd_add.go index a2c50e007..19df2e5fe 100644 --- a/cli/cmd_add.go +++ b/cli/cmd_add.go @@ -25,7 +25,7 @@ import ( "github.com/neilotoole/sq/libsq/source/drivertype" ) -func newSrcAddCmd() *cobra.Command { +func newSrcAddCmd() *cobra.Command { //nolint:funlen cmd := &cobra.Command{ Use: "add [--handle @HANDLE] LOCATION", RunE: execSrcAdd, @@ -156,6 +156,7 @@ More examples: Long: `Add data source specified by LOCATION, optionally identified by @HANDLE.`, } + markCmdRequiresConfigLock(cmd) addTextFormatFlags(cmd) cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) cmd.Flags().BoolP(flag.Compact, flag.CompactShort, false, flag.CompactUsage) diff --git a/cli/cmd_cache.go b/cli/cmd_cache.go index 285380a53..79bf1c819 100644 --- a/cli/cmd_cache.go +++ b/cli/cmd_cache.go @@ -105,6 +105,7 @@ func newCacheClearCmd() *cobra.Command { Example: ` $ sq cache clear`, } + markCmdRequiresConfigLock(cmd) return cmd } @@ -155,7 +156,7 @@ func newCacheEnableCmd() *cobra.Command { }, Example: ` $ sq cache enable`, } - + 
markCmdRequiresConfigLock(cmd) return cmd } @@ -173,5 +174,6 @@ func newCacheDisableCmd() *cobra.Command { Example: ` $ sq cache disable`, } + markCmdRequiresConfigLock(cmd) return cmd } diff --git a/cli/cmd_config_edit.go b/cli/cmd_config_edit.go index e7724233b..e9db945f8 100644 --- a/cli/cmd_config_edit.go +++ b/cli/cmd_config_edit.go @@ -51,7 +51,7 @@ in envar $SQ_EDITOR or $EDITOR.`, # Use a different editor $ SQ_EDITOR=nano sq config edit`, } - + markCmdRequiresConfigLock(cmd) return cmd } diff --git a/cli/cmd_config_set.go b/cli/cmd_config_set.go index ada9f28d1..3fbe35191 100644 --- a/cli/cmd_config_set.go +++ b/cli/cmd_config_set.go @@ -41,7 +41,7 @@ Use "sq config ls -v" to list available options.`, # Help for an individual option $ sq config set conn.max-open --help`, } - + markCmdRequiresConfigLock(cmd) addTextFormatFlags(cmd) cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) cmd.Flags().BoolP(flag.Compact, flag.CompactShort, false, flag.CompactUsage) diff --git a/cli/cmd_mv.go b/cli/cmd_mv.go index b6ff6fbdf..34a0c806a 100644 --- a/cli/cmd_mv.go +++ b/cli/cmd_mv.go @@ -42,6 +42,7 @@ source handles are files, and groups are directories.`, $ sq mv production prod`, } + markCmdRequiresConfigLock(cmd) addTextFormatFlags(cmd) cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) cmd.Flags().BoolP(flag.Compact, flag.CompactShort, false, flag.CompactUsage) @@ -51,12 +52,6 @@ source handles are files, and groups are directories.`, } func execMove(cmd *cobra.Command, args []string) error { - if unlock, err := lockReloadConfig(cmd); err != nil { - return err - } else { - defer unlock() - } - switch { case source.IsValidHandle(args[0]) && source.IsValidHandle(args[1]): // Effectively a handle rename diff --git a/cli/cmd_remove.go b/cli/cmd_remove.go index 112def682..a50ce6cda 100644 --- a/cli/cmd_remove.go +++ b/cli/cmd_remove.go @@ -34,6 +34,7 @@ may have changed, if that source or group was removed.`, $ sq rm @staging/sakila_db 
@staging/backup_db dev`, } + markCmdRequiresConfigLock(cmd) addTextFormatFlags(cmd) cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) cmd.Flags().BoolP(flag.Compact, flag.CompactShort, false, flag.CompactUsage) diff --git a/cli/cmd_scratch.go b/cli/cmd_scratch.go index c5dbf133c..a0e88ad35 100644 --- a/cli/cmd_scratch.go +++ b/cli/cmd_scratch.go @@ -36,6 +36,7 @@ importing non-SQL data, or cross-database joins. If no argument provided, get th source. Otherwise, set @HANDLE or an internal db as the scratch data source. The reserved handle "@scratch" resets the `, } + markCmdRequiresConfigLock(cmd) return cmd } diff --git a/cli/cmd_version.go b/cli/cmd_version.go index e037a1b93..483c66809 100644 --- a/cli/cmd_version.go +++ b/cli/cmd_version.go @@ -9,12 +9,13 @@ import ( "strings" "time" + "github.com/neilotoole/sq/cli/buildinfo" + "github.com/neilotoole/sq/cli/hostinfo" + "github.com/spf13/cobra" "golang.org/x/mod/semver" - "github.com/neilotoole/sq/cli/buildinfo" "github.com/neilotoole/sq/cli/flag" - "github.com/neilotoole/sq/cli/hostinfo" "github.com/neilotoole/sq/cli/run" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" diff --git a/cli/cmd_x.go b/cli/cmd_x.go index 226346572..738709675 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -49,53 +49,42 @@ func newXLockSrcCmd() *cobra.Command { } func execXLockSrcCmd(cmd *cobra.Command, args []string) error { - if unlock, err := lockReloadConfig(cmd); err != nil { + ctx := cmd.Context() + ru := run.FromContext(ctx) + src, err := ru.Config.Collection.Get(args[0]) + if err != nil { return err - } else { - defer unlock() } - sleep := time.Second * 10 - fmt.Fprintf(os.Stdout, "huzzah, will sleep for %s\n", sleep) - time.Sleep(sleep) - return nil + timeout := time.Minute * 20 + lock, err := ru.Files.CacheLockFor(src) + if err != nil { + return err + } + fmt.Fprintf(ru.Out, "Locking cache for source %s with timeout %s for %q [%d]\n\n %s\n\n", + 
src.Handle, timeout, os.Args[0], os.Getpid(), lock) - //ctx := cmd.Context() - //ru := run.FromContext(ctx) - //src, err := ru.Config.Collection.Get(args[0]) - //if err != nil { - // return err - //} - // - //timeout := time.Minute * 20 - //lock, err := ru.Files.CacheLockFor(src) - //if err != nil { - // return err - //} - //fmt.Fprintf(ru.Out, "Locking cache for source %s with timeout %s for %q [%d]\n\n %s\n\n", - // src.Handle, timeout, os.Args[0], os.Getpid(), lock) - // - //err = lock.Lock(ctx, timeout) - //if err != nil { - // return err - //} - // - //fmt.Fprintf(ru.Out, "Cache lock acquired for %s\n", src.Handle) - // - //select { - //case <-pressEnter(): - // fmt.Fprintln(ru.Out, "\nENTER received, releasing lock") - //case <-ctx.Done(): - // fmt.Fprintln(ru.Out, "\nContext done, releasing lock") - //} - // - //fmt.Fprintf(ru.Out, "Releasing cache lock for %s\n", src.Handle) - //if err = lock.Unlock(); err != nil { - // return err - //} - // - //fmt.Fprintf(ru.Out, "Cache lock released for %s\n", src.Handle) - //return nil + err = lock.Lock(ctx, timeout) + if err != nil { + return err + } + + fmt.Fprintf(ru.Out, "Cache lock acquired for %s\n", src.Handle) + + select { + case <-pressEnter(): + fmt.Fprintln(ru.Out, "\nENTER received, releasing lock") + case <-ctx.Done(): + fmt.Fprintln(ru.Out, "\nContext done, releasing lock") + } + + fmt.Fprintf(ru.Out, "Releasing cache lock for %s\n", src.Handle) + if err = lock.Unlock(); err != nil { + return err + } + + fmt.Fprintf(ru.Out, "Cache lock released for %s\n", src.Handle) + return nil } func newXProgressCmd() *cobra.Command { @@ -149,15 +138,6 @@ func newXDownloadCmd() *cobra.Command { Hidden: true, Args: cobra.ExactArgs(1), RunE: execXDownloadCmd, - //RunE: func(cmd *cobra.Command, args []string) error { - // err1 := errz.New("inner huzzah") - // time.Sleep(time.Nanosecond) - // err2 := errz.Wrap(err1, "outer huzzah") - // time.Sleep(time.Nanosecond) - // err3 := errz.Wrap(err2, "outer huzzah") - // - // 
return err3 - //}, Example: ` $ sq x download https://sq.io/testdata/actor.csv # Download a big-ass file @@ -203,13 +183,6 @@ func execXDownloadCmd(cmd *cobra.Command, args []string) error { case len(h.Errors) > 0: err1 := errz.Err(h.Errors[0]) return err1 - - //err1 := h.Errors[0] - //err2 := errz.New("another err") - //err3 := errz.Combine(err1, err2) - ////lg.FromContext(ctx).Error("OH NO", lga.Err, err3) - //return err3 - //return nil case len(h.WriteErrors) > 0: return h.WriteErrors[0] case len(h.CachedFiles) > 0: @@ -232,3 +205,29 @@ func pressEnter() <-chan struct{} { }() return done } + +func newXLockConfigCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "lock-config", + Short: "Test config lock", + Hidden: true, + Args: cobra.NoArgs, + ValidArgsFunction: completeHandle(1), + RunE: func(cmd *cobra.Command, args []string) error { + ru := run.FromContext(cmd.Context()) + fmt.Fprintf(ru.Out, "Locking config (pid %d)\n", os.Getpid()) + unlock, err := lockReloadConfig(cmd) + if err != nil { + return err + } + + fmt.Fprintln(ru.Out, "Config locked; ctrl-c to exit") + <-cmd.Context().Done() + unlock() + return nil + }, + Example: ` $ sq x lock-config`, + } + + return cmd +} diff --git a/cli/run.go b/cli/run.go index ce2b04af9..19c0cfd9e 100644 --- a/cli/run.go +++ b/cli/run.go @@ -241,11 +241,11 @@ func preRun(cmd *cobra.Command, ru *run.Run) error { return nil } + ctx := cmd.Context() if ru.Cleanup == nil { ru.Cleanup = cleanup.New() } - ctx := cmd.Context() // If the --output=/some/file flag is set, then we need to // override ru.Out (which is typically stdout) to point it at // the output destination file. 
@@ -277,7 +277,34 @@ func preRun(cmd *cobra.Command, ru *run.Run) error { } ru.Writers, ru.Out, ru.ErrOut = newWriters(ru.Cmd, ru.Cleanup, cmdOpts, ru.Out, ru.ErrOut) - return FinishRunInit(ctx, ru) + if err = FinishRunInit(ctx, ru); err != nil { + return err + } + + if cmdRequiresConfigLock(cmd) { + var unlock func() + if unlock, err = lockReloadConfig(cmd); err != nil { + return err + } + ru.Cleanup.Add(unlock) + } + return nil +} + +// markCmdRequiresConfigLock marks cmd as requiring a config lock. +// Thus, before the command's RunE is invoked, the config lock +// is acquired (in preRun), and released on cleanup. +func markCmdRequiresConfigLock(cmd *cobra.Command) { + if cmd.Annotations == nil { + cmd.Annotations = make(map[string]string) + } + cmd.Annotations["config.lock"] = "true" +} + +// cmdRequiresConfigLock returns true if markCmdRequiresConfigLock was +// previously invoked on cmd. +func cmdRequiresConfigLock(cmd *cobra.Command) bool { + return cmd.Annotations != nil && cmd.Annotations["config.lock"] == "true" } // lockReloadConfig acquires the lock for the config store, and updates the diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index b22a655f9..28d6c67f0 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -32,7 +32,7 @@ import ( // DebugDelay sleeps for a period of time to facilitate testing the // progress impl. It should be removed before release. // -// Deprecated: This is a temporary hack for testing. +// FIXME: Delete this before release. 
func DebugDelay() { const delay = time.Millisecond * 0 if delay > 0 { From 32c2f754e94c3112999318ae13e5e5b8c6169885 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 11 Jan 2024 08:56:42 -0700 Subject: [PATCH 154/195] Fix multi err bugs --- cli/options_test.go | 2 +- cli/run.go | 10 +++++-- libsq/core/errz/internal_test.go | 14 +++++++++ libsq/core/errz/multi.go | 12 +++++--- libsq/core/ioz/lockfile/lockfile_test.go | 36 ------------------------ libsq/core/stringz/stringz_test.go | 2 +- 6 files changed, 31 insertions(+), 45 deletions(-) delete mode 100644 libsq/core/ioz/lockfile/lockfile_test.go diff --git a/cli/options_test.go b/cli/options_test.go index 31cc8fcd7..4319e7a3b 100644 --- a/cli/options_test.go +++ b/cli/options_test.go @@ -20,7 +20,7 @@ func TestRegisterDefaultOpts(t *testing.T) { log.Debug("options.Registry (after)", "reg", reg) keys := reg.Keys() - require.Len(t, keys, 37) + require.Len(t, keys, 46) for _, opt := range reg.Opts() { opt := opt diff --git a/cli/run.go b/cli/run.go index 19c0cfd9e..d010f1b22 100644 --- a/cli/run.go +++ b/cli/run.go @@ -311,15 +311,19 @@ func cmdRequiresConfigLock(cmd *cobra.Command) bool { // run (as found on cmd's context) with a fresh copy of the config, loaded // after lock acquisition. // +// The config lock should be acquired before making any changes to config. +// Timeout and progress options from ctx are honored. +// The caller is responsible for invoking the returned unlock func. +// Example usage: +// // if unlock, err := lockReloadConfig(cmd); err != nil { // return err // } else { // defer unlock() // } // -// The config lock should be acquired before making any changes to config. -// Timeout and progress options from ctx are honored. -// The caller is responsible for invoking the returned unlock func. +// However, in practice, most commands will invoke markCmdRequiresConfigLock +// instead of explicitly invoking lockReloadConfig. 
func lockReloadConfig(cmd *cobra.Command) (unlock func(), err error) { ctx := cmd.Context() ru := run.FromContext(ctx) diff --git a/libsq/core/errz/internal_test.go b/libsq/core/errz/internal_test.go index 27dadd73c..fb827a865 100644 --- a/libsq/core/errz/internal_test.go +++ b/libsq/core/errz/internal_test.go @@ -23,3 +23,17 @@ func TestAlienCause(t *testing.T) { cause = err.(*errz).alienCause() require.Equal(t, context.DeadlineExceeded, cause) } + +func TestAppendNilToMulti(t *testing.T) { + merr := Append(New("a"), New("b")) + _, ok := merr.(*multiErr) + require.True(t, ok) + + got := Append(merr, nil) + _, ok = got.(*multiErr) + require.True(t, ok) + + got = Append(nil, merr) + _, ok = got.(*multiErr) + require.True(t, ok) +} diff --git a/libsq/core/errz/multi.go b/libsq/core/errz/multi.go index db128daa6..faa53e326 100644 --- a/libsq/core/errz/multi.go +++ b/libsq/core/errz/multi.go @@ -462,17 +462,21 @@ func Append(left, right error) error { case left == nil && right == nil: return nil case left == nil: - if _, ok := right.(*errz); !ok { + switch right := right.(type) { + case *multiErr, *errz: + return right + default: // It's not an errz, so we need to wrap it. return &errz{stack: callers(0), error: right} } - return right case right == nil: - if _, ok := left.(*errz); !ok { + switch left := left.(type) { + case *multiErr, *errz: + return left + default: // It's not an errz, so we need to wrap it. 
return &errz{stack: callers(0), error: left} } - return left } if _, ok := right.(*multiErr); !ok { diff --git a/libsq/core/ioz/lockfile/lockfile_test.go b/libsq/core/ioz/lockfile/lockfile_test.go deleted file mode 100644 index 1ac0cf065..000000000 --- a/libsq/core/ioz/lockfile/lockfile_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package lockfile_test - -import ( - "context" - "path/filepath" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "github.com/neilotoole/sq/libsq/core/ioz/lockfile" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lgt" -) - -// FIXME: Duh, this can't work, because we're in the same pid. -func TestLockfile(t *testing.T) { - ctx := lg.NewContext(context.Background(), lgt.New(t)) - - pidfile := filepath.Join(t.TempDir(), "lock.pid") - lock, err := lockfile.New(pidfile) - require.NoError(t, err) - require.Equal(t, pidfile, string(lock)) - - require.NoError(t, lock.Lock(ctx, 0), - "should be able to acquire lock immediately") - time.AfterFunc(time.Second*100, func() { - require.NoError(t, lock.Unlock()) - }) - - err = lock.Lock(ctx, time.Second) - require.Error(t, err, "not enough time to acquire the lock") - - err = lock.Lock(ctx, time.Second*10) - require.NoError(t, err, "should be able to acquire the lock") -} diff --git a/libsq/core/stringz/stringz_test.go b/libsq/core/stringz/stringz_test.go index 021b41340..0a1a4d852 100644 --- a/libsq/core/stringz/stringz_test.go +++ b/libsq/core/stringz/stringz_test.go @@ -691,7 +691,7 @@ func TestSanitizeFilename(t *testing.T) { func TestTypeNames(t *testing.T) { errs := []error{errors.New("stdlib"), errz.New("errz")} names := stringz.TypeNames(errs...) - require.Equal(t, []string{"*errors.errorString", "*errz.fundamental"}, names) + require.Equal(t, []string{"*errors.errorString", "*errz.errz"}, names) a := []any{1, "hello", true, errs} names = stringz.TypeNames(a...) 
From fe905c981bfd55af980097628876ff5fe8f99456 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 11 Jan 2024 08:59:56 -0700 Subject: [PATCH 155/195] linting --- libsq/core/errz/internal_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/libsq/core/errz/internal_test.go b/libsq/core/errz/internal_test.go index fb827a865..19609ec4b 100644 --- a/libsq/core/errz/internal_test.go +++ b/libsq/core/errz/internal_test.go @@ -24,6 +24,7 @@ func TestAlienCause(t *testing.T) { require.Equal(t, context.DeadlineExceeded, cause) } +//nolint:errorlint func TestAppendNilToMulti(t *testing.T) { merr := Append(New("a"), New("b")) _, ok := merr.(*multiErr) From 398e6ce2c22c0cb6e363998e4d35a21d16850905 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 11 Jan 2024 09:28:58 -0700 Subject: [PATCH 156/195] bug fixes --- cli/cmd_sql_test.go | 4 ---- drivers/csv/csv_test.go | 9 --------- drivers/sqlite3/sqlite3.go | 4 ---- drivers/xlsx/ingest.go | 1 - libsq/core/ioz/ioz.go | 1 - libsq/core/progress/progress.go | 9 +++++---- libsq/core/record/record.go | 7 ------- testh/testh.go | 2 +- 8 files changed, 6 insertions(+), 31 deletions(-) diff --git a/cli/cmd_sql_test.go b/cli/cmd_sql_test.go index 3cff93231..0f611ff3b 100644 --- a/cli/cmd_sql_test.go +++ b/cli/cmd_sql_test.go @@ -202,8 +202,6 @@ func TestFlagActiveSource_sql(t *testing.T) { tr = testrun.New(ctx, t, tr) require.NoError(t, tr.Exec("add", proj.Abs(sakila.PathCSVActor), "--handle", "@csv")) - t.Logf("\n\n\n QUERY 1 \n\n\n") // FIXME: delete - tr = testrun.New(ctx, t, tr) require.NoError(t, tr.Exec( "sql", @@ -213,8 +211,6 @@ func TestFlagActiveSource_sql(t *testing.T) { )) require.Len(t, tr.BindCSV(), sakila.TblActorCount) - t.Logf("\n\n\n QUERY 2 \n\n\n") // FIXME: delete - // Now, use flag.ActiveSrc to switch the source. 
tr = testrun.New(ctx, t, tr) require.NoError(t, tr.Exec( diff --git a/drivers/csv/csv_test.go b/drivers/csv/csv_test.go index b486e727b..cbf79da10 100644 --- a/drivers/csv/csv_test.go +++ b/drivers/csv/csv_test.go @@ -331,12 +331,3 @@ func TestDatetime(t *testing.T) { }) } } - -// TestIngestLargeCSV generates a large CSV file. -// At count = 5,000,000, the generated file is ~500MB. -// This test is skipped by default. -// FIXME: Delete TestGenerateLargeCSV. -func TestGenerateLargeCSV(t *testing.T) { - t.Skip() - testh.GenerateLargeCSV(t, "testdata/payment-large.csv") -} diff --git a/drivers/sqlite3/sqlite3.go b/drivers/sqlite3/sqlite3.go index 9bc0d9116..da8f94518 100644 --- a/drivers/sqlite3/sqlite3.go +++ b/drivers/sqlite3/sqlite3.go @@ -146,10 +146,6 @@ func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, erro return nil, err } - if strings.Contains(fp, "checksum") { // FIXME: delete - lg.FromContext(ctx).Warn("This is bad") - } - db, err := sql.Open(dbDrvr, fp) if err != nil { return nil, errz.Wrapf(errw(err), "failed to open sqlite3 source with DSN: %s", fp) diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index d8e195230..29442f2bf 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -287,7 +287,6 @@ func buildSheetTables(ctx context.Context, srcIngestHeader *bool, sheets []*xShe sheetTbl, err := buildSheetTable(gCtx, srcIngestHeader, sheets[i]) if err != nil { if errz.Has[*driver.EmptyDataError](err) { - // if errz.IsErrNoData(err) { // FIXME: remove after testing // If the sheet has no data, we log it and skip it. 
lg.FromContext(ctx).Warn("Excel sheet has no data", laSheet, sheets[i].name, diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 48fff5bd1..961857e42 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -411,7 +411,6 @@ func (c *readCloserNotifier) Close() error { c.once.Do(func() { c.closeErr = c.ReadCloser.Close() c.fn(c.closeErr) - // c.closeErr = errz.New("huzzah") // FIXME: delete }) return c.closeErr } diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 28d6c67f0..cd56482ab 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -115,6 +115,7 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors opts := []mpb.ContainerOption{ mpb.WithOutput(out), mpb.WithWidth(boxWidth), + // FIXME: switch back to auto refresh? // mpb.WithRefreshRate(refreshRate), mpb.WithManualRefresh(p.refreshCh), // mpb.WithAutoRefresh(), // Needed for color in Windows, apparently @@ -219,14 +220,14 @@ func (p *Progress) doStop() { defer lg.FromContext(p.ctx).Debug("Stopped progress widget") if p.pc == nil { close(p.stoppedCh) - close(p.refreshCh) + // close(p.refreshCh) p.cancelFn() return } if len(p.bars) == 0 { close(p.stoppedCh) - close(p.refreshCh) + // close(p.refreshCh) p.cancelFn() return } @@ -249,7 +250,7 @@ func (p *Progress) doStop() { p.refreshCh <- time.Now() close(p.stoppedCh) - close(p.refreshCh) + // close(p.refreshCh) p.pc.Wait() // Important: we must call cancelFn after pc.Wait() or the bars // may not be removed from the terminal. @@ -303,7 +304,7 @@ func (p *Progress) newBar(msg string, total int64, barStoppedCh: make(chan struct{}), } b.barInitFn = func() { - p.mu.Lock() // FIXME: not too sure about locking here? 
+ p.mu.Lock() defer p.mu.Unlock() select { diff --git a/libsq/core/record/record.go b/libsq/core/record/record.go index 6313d6ef3..ecf8f0f16 100644 --- a/libsq/core/record/record.go +++ b/libsq/core/record/record.go @@ -29,13 +29,6 @@ import ( // It is an error for a Record to contain elements of any other type. type Record []any -// Value is the idealized generic type. One day, we'd like to be able -// to do something like this the below. -// FIXME: Delete this type. -type Value interface { - ~int64 | ~float64 | ~bool | ~string | ~[]byte | time.Time | decimal.Decimal -} - // Valid checks that each element of the record vals is // of an acceptable type. On the first unacceptable element, // the index of that element and an error are returned. On diff --git a/testh/testh.go b/testh/testh.go index 1531b665b..e299d5e0b 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -289,7 +289,7 @@ func (h *Helper) Source(handle string) *source.Source { // It might be expected that we would simply use the // collection (h.coll) to return the source, but this // method also uses a cache. This is because this - // method makes a copy the data file of file-based sources + // method makes a copy of the data file of file-based sources // as mentioned in the method godoc. 
h.coll = mustLoadCollection(h.Context, t) h.srcCache = map[string]*source.Source{} From 2de06bb3515ed1e85f7298ca8f0053d4ab1daebd Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 11 Jan 2024 10:08:23 -0700 Subject: [PATCH 157/195] TestRun now creates unique cache dirs for Files --- cli/run.go | 17 +++++++++-------- cli/testrun/testrun.go | 12 ++++++++++++ libsq/source/cache.go | 2 ++ libsq/source/detect.go | 2 ++ 4 files changed, 25 insertions(+), 8 deletions(-) diff --git a/cli/run.go b/cli/run.go index d010f1b22..13a95e1df 100644 --- a/cli/run.go +++ b/cli/run.go @@ -146,6 +146,7 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { var err error if ru.Files == nil { + // The Files instance may already have been created. If not, create it. ru.Files, err = source.NewFiles(ctx, ru.OptionsRegistry, source.DefaultTempDir(), source.DefaultCacheDir(), true) if err != nil { lg.WarnIfFuncError(log, lga.Cleanup, ru.Cleanup.Run) @@ -192,34 +193,34 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { xmlud.Genre: xmlud.Import, } - for i, userDriverDef := range cfg.Ext.UserDrivers { - userDriverDef := userDriverDef + for i, udd := range cfg.Ext.UserDrivers { + udd := udd - errs := userdriver.ValidateDriverDef(userDriverDef) + errs := userdriver.ValidateDriverDef(udd) if len(errs) > 0 { err := errz.Combine(errs...) err = errz.Wrapf(err, "failed validation of user driver definition [%d] {%s} from config", - i, userDriverDef.Name) + i, udd.Name) return err } - importFn, ok := userDriverImporters[userDriverDef.Genre] + importFn, ok := userDriverImporters[udd.Genre] if !ok { return errz.Errorf("unsupported genre {%s} for user driver {%s} specified via config", - userDriverDef.Genre, userDriverDef.Name) + udd.Genre, udd.Name) } // For each user driver definition, we register a // distinct userdriver.Provider instance. 
udp := &userdriver.Provider{ Log: log, - DriverDef: userDriverDef, + DriverDef: udd, ImportFn: importFn, Ingester: ru.Grips, Files: ru.Files, } - ru.DriverRegistry.AddProvider(drivertype.Type(userDriverDef.Name), udp) + ru.DriverRegistry.AddProvider(drivertype.Type(udd.Name), udp) ru.Files.AddDriverDetectors(udp.Detectors()...) } diff --git a/cli/testrun/testrun.go b/cli/testrun/testrun.go index 12dd495b7..d001befda 100644 --- a/cli/testrun/testrun.go +++ b/cli/testrun/testrun.go @@ -109,6 +109,18 @@ func newRun(ctx context.Context, t testing.TB, cfgStore config.Store) (ru *run.R OptionsRegistry: optsReg, } + // The Files instance needs unique dirs for temp and cache because + // the test runs may execute in parallel inside the same test binary + // process, thus breaking the pid-based lockfile mechanism. + ru.Files, err = source.NewFiles( + ctx, + ru.OptionsRegistry, + filepath.Join(t.TempDir(), "sq", "temp"), + filepath.Join(t.TempDir(), "sq", "cache"), + true, + ) + require.NoError(t, err) + require.NoError(t, cli.FinishRunInit(ctx, ru)) return ru, out, errOut } diff --git a/libsq/source/cache.go b/libsq/source/cache.go index 269439d7f..726617b15 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -61,6 +61,8 @@ func (fs *Files) CacheDirFor(src *Source) (dir string, err error) { return dir, nil } +// WriteIngestChecksum is invoked (after successful ingestion) to write the +// checksum for the ingest cache db. 
func (fs *Files) WriteIngestChecksum(ctx context.Context, src, backingSrc *Source) (err error) { log := lg.FromContext(ctx) ingestFilePath, err := fs.filepath(src) diff --git a/libsq/source/detect.go b/libsq/source/detect.go index 1c6422529..748a1df97 100644 --- a/libsq/source/detect.go +++ b/libsq/source/detect.go @@ -89,6 +89,8 @@ func (fs *Files) detectType(ctx context.Context, handle, loc string) (typ driver log := lg.FromContext(ctx).With(lga.Loc, loc) start := time.Now() + // FIXME: we could bypass newReader here for local files (that + // isn't @stdin). openFn := func(ctx context.Context) (io.ReadCloser, error) { src := &Source{Handle: handle, Location: loc} return fs.newReader(ctx, src) From b6902fe9104530525e5e745652ed77508afc762f Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 11 Jan 2024 10:15:41 -0700 Subject: [PATCH 158/195] fixed terminal_windows.go pkg --- cli/terminal_windows.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cli/terminal_windows.go b/cli/terminal_windows.go index 439009d04..98ecdd271 100644 --- a/cli/terminal_windows.go +++ b/cli/terminal_windows.go @@ -1,4 +1,4 @@ -package terminal +package cli import ( "io" From 8fe9f00a3978b2acd52a2c33ea152013c5465a44 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 11 Jan 2024 10:50:47 -0700 Subject: [PATCH 159/195] Fixed bug with barInitFn --- libsq/core/progress/progress.go | 1 + 1 file changed, 1 insertion(+) diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index cd56482ab..ff5a4f62b 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -309,6 +309,7 @@ func (p *Progress) newBar(msg string, total int64, select { case <-p.ctx.Done(): + return case <-p.stoppedCh: return case <-b.barStoppedCh: From ed585832f80a60f410eed2878a88cd97368cd849 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 11 Jan 2024 10:58:12 -0700 Subject: [PATCH 160/195] More terminal_windows.go issues --- cli/terminal_windows.go | 3 
+++ 1 file changed, 3 insertions(+) diff --git a/cli/terminal_windows.go b/cli/terminal_windows.go index 98ecdd271..447683c4c 100644 --- a/cli/terminal_windows.go +++ b/cli/terminal_windows.go @@ -3,6 +3,9 @@ package cli import ( "io" "os" + + "golang.org/x/sys/windows" + "golang.org/x/term" ) // isTerminal returns true if w is a terminal. From 44e0bae09853ccdf2535ac6296cfbd7f2a7344c8 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 11 Jan 2024 13:57:23 -0700 Subject: [PATCH 161/195] files cache dir --- cli/testrun/testrun.go | 6 ++++-- drivers/xlsx/ingest.go | 1 - libsq/query_no_src_test.go | 4 ++++ libsq/source/files.go | 3 +-- libsq/source/files_test.go | 24 +++++++++++++++++++++--- testh/testh.go | 12 ++++++++++-- testh/tu/tu.go | 25 +++++++++++++++++++++---- 7 files changed, 61 insertions(+), 14 deletions(-) diff --git a/cli/testrun/testrun.go b/cli/testrun/testrun.go index d001befda..14c10d26a 100644 --- a/cli/testrun/testrun.go +++ b/cli/testrun/testrun.go @@ -12,6 +12,8 @@ import ( "sync" "testing" + "github.com/neilotoole/sq/testh/tu" + "github.com/stretchr/testify/require" "github.com/neilotoole/sq/cli" @@ -115,8 +117,8 @@ func newRun(ctx context.Context, t testing.TB, cfgStore config.Store) (ru *run.R ru.Files, err = source.NewFiles( ctx, ru.OptionsRegistry, - filepath.Join(t.TempDir(), "sq", "temp"), - filepath.Join(t.TempDir(), "sq", "cache"), + tu.TempDir(t, false), + tu.CacheDir(t, false), true, ) require.NoError(t, err) diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index 29442f2bf..21e8ec06b 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -103,7 +103,6 @@ func ingestXLSX(ctx context.Context, src *source.Source, destGrip driver.Grip, x return err } - lg.FromContext(ctx).Error("count is woah", lga.Count, len(sheetTbls)) bar := progress.FromContext(ctx).NewUnitTotalCounter( "Ingesting sheets", "", diff --git a/libsq/query_no_src_test.go b/libsq/query_no_src_test.go index 38ca9b255..32cc246c9 100644 --- 
a/libsq/query_no_src_test.go +++ b/libsq/query_no_src_test.go @@ -12,6 +12,8 @@ import ( ) func TestQuery_no_source(t *testing.T) { + t.Parallel() + testCases := []struct { in string want string @@ -27,6 +29,8 @@ func TestQuery_no_source(t *testing.T) { for i, tc := range testCases { tc := tc t.Run(tu.Name(i, tc.in), func(t *testing.T) { + t.Parallel() + t.Logf("\nquery: %s\n want: %s", tc.in, tc.want) th := testh.New(t) coll := th.NewCollection() diff --git a/libsq/source/files.go b/libsq/source/files.go index 698371c0d..ec6259897 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -65,8 +65,7 @@ type Files struct { // NewFiles returns a new Files instance. If cleanFscache is true, the fscache // is cleaned on Files.Close. -func NewFiles(ctx context.Context, optReg *options.Registry, - tmpDir, cacheDir string, cleanFscache bool, +func NewFiles(ctx context.Context, optReg *options.Registry, tmpDir, cacheDir string, cleanFscache bool, ) (*Files, error) { log := lg.FromContext(ctx) log.Debug("Creating new Files instance", "tmp_dir", tmpDir, "cache_dir", cacheDir) diff --git a/libsq/source/files_test.go b/libsq/source/files_test.go index b6d54c064..c49d46550 100644 --- a/libsq/source/files_test.go +++ b/libsq/source/files_test.go @@ -56,7 +56,13 @@ func TestFiles_Type(t *testing.T) { t.Run(tu.Name(source.RedactLocation(tc.loc)), func(t *testing.T) { ctx := lg.NewContext(context.Background(), lgt.New(t)) - fs, err := source.NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) + fs, err := source.NewFiles( + ctx, + nil, + tu.TempDir(t, true), + tu.CacheDir(t, true), + true, + ) require.NoError(t, err) fs.AddDriverDetectors(testh.DriverDetectors()...) 
@@ -98,7 +104,13 @@ func TestFiles_DetectType(t *testing.T) { t.Run(filepath.Base(tc.loc), func(t *testing.T) { ctx := lg.NewContext(context.Background(), lgt.New(t)) - fs, err := source.NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) + fs, err := source.NewFiles( + ctx, + nil, + tu.TempDir(t, true), + tu.CacheDir(t, true), + true, + ) require.NoError(t, err) fs.AddDriverDetectors(testh.DriverDetectors()...) @@ -158,7 +170,13 @@ func TestFiles_NewReader(t *testing.T) { Location: proj.Abs(fpath), } - fs, err := source.NewFiles(ctx, nil, tu.TempDir(t), tu.CacheDir(t), true) + fs, err := source.NewFiles( + ctx, + nil, + tu.TempDir(t, true), + tu.CacheDir(t, true), + true, + ) require.NoError(t, err) g := &errgroup.Group{} diff --git a/testh/testh.go b/testh/testh.go index e299d5e0b..5ba7cc7af 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -14,6 +14,8 @@ import ( "testing" "time" + "github.com/neilotoole/sq/testh/tu" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -54,7 +56,6 @@ import ( "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/sakila" "github.com/neilotoole/sq/testh/testsrc" - "github.com/neilotoole/sq/testh/tu" ) // defaultDBOpenTimeout is the timeout for tests to open (and ping) their DBs. 
@@ -159,7 +160,14 @@ func (h *Helper) init() { cfg := config.New() var err error - h.files, err = source.NewFiles(h.Context, optRegistry, tu.TempDir(h.T), tu.CacheDir(h.T), true) + + h.files, err = source.NewFiles( + h.Context, + optRegistry, + tu.TempDir(h.T, false), + tu.TempDir(h.T, false), + true, + ) require.NoError(h.T, err) h.Cleanup.Add(func() { diff --git a/testh/tu/tu.go b/testh/tu/tu.go index 5484a8ca0..8ed9ef7b0 100644 --- a/testh/tu/tu.go +++ b/testh/tu/tu.go @@ -209,6 +209,7 @@ func Name(args ...any) string { s = strings.ReplaceAll(s, "/", "_") s = strings.ReplaceAll(s, ":", "_") s = strings.ReplaceAll(s, `\`, "_") + s = stringz.SanitizeFilename(s) s = stringz.EllipsifyASCII(s, 28) // we don't want it to be too long parts = append(parts, s) } @@ -383,13 +384,29 @@ func MustAbsFilepath(elems ...string) string { } // TempDir is the standard means for obtaining a temp dir for tests. -func TempDir(t testing.TB) string { - return filepath.Join(t.TempDir(), "sq", "tmp") +// If arg clean is true, the temp dir is created via t.TempDir, and +// thus is deleted on test cleanup. +func TempDir(t testing.TB, clean bool) string { + if clean { + return filepath.Join(t.TempDir(), "sq-test", "tmp") + } + + fp := filepath.Join(os.TempDir(), "sq-test", stringz.Uniq8(), "tmp") + require.NoError(t, ioz.RequireDir(fp)) + return fp } // CacheDir is the standard means for obtaining a cache dir for tests. -func CacheDir(t testing.TB) string { - return filepath.Join(t.TempDir(), "sq", "cache") +// If arg clean is true, the cache dir is created via t.TempDir, and +// thus is deleted on test cleanup. 
+func CacheDir(t testing.TB, clean bool) string { + if clean { + return filepath.Join(t.TempDir(), "sq-test", "cache") + } + + fp := filepath.Join(os.TempDir(), "sq-test", stringz.Uniq8(), "cache") + require.NoError(t, ioz.RequireDir(fp)) + return fp } // ReadFileToString invokes ioz.ReadFileToString, failing t if From ffebf18a597cb7546a7959df4f45ea2347a0419a Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 11 Jan 2024 16:18:31 -0700 Subject: [PATCH 162/195] file close, db timeout --- drivers/csv/csv.go | 11 +------ drivers/json/json.go | 12 +------ drivers/mysql/mysql.go | 5 +++ drivers/postgres/postgres.go | 2 ++ drivers/sqlserver/sqlserver.go | 20 +++++++----- drivers/userdriver/userdriver.go | 15 +-------- drivers/xlsx/xlsx.go | 18 +--------- libsq/core/ioz/httpz/httpz.go | 5 +++ libsq/driver/driver_test.go | 7 +++- libsq/driver/opts.go | 2 +- libsq/source/download.go | 24 +++++++++----- libsq/source/files.go | 56 ++++++++++++++++++++++++++++---- 12 files changed, 99 insertions(+), 78 deletions(-) diff --git a/drivers/csv/csv.go b/drivers/csv/csv.go index 0b4bc6871..700399345 100644 --- a/drivers/csv/csv.go +++ b/drivers/csv/csv.go @@ -103,16 +103,7 @@ func (d *driveri) ValidateSource(src *source.Source) (*source.Source, error) { // Ping implements driver.Driver. func (d *driveri) Ping(ctx context.Context, src *source.Source) error { - // FIXME: Does Ping calling d.files.Open cause a full read? - // We probably just want to check that the file exists - // or is accessible. - r, err := d.files.Open(ctx, src) - if err != nil { - return err - } - defer lg.WarnIfCloseError(d.log, lgm.CloseFileReader, r) - - return nil + return d.files.Ping(ctx, src) } // grip implements driver.Grip. diff --git a/drivers/json/json.go b/drivers/json/json.go index f3879e3a5..a40d80310 100644 --- a/drivers/json/json.go +++ b/drivers/json/json.go @@ -135,17 +135,7 @@ func (d *driveri) ValidateSource(src *source.Source) (*source.Source, error) { // Ping implements driver.Driver. 
func (d *driveri) Ping(ctx context.Context, src *source.Source) error { - log := lg.FromContext(ctx).With(lga.Src, src) - log.Debug("Ping source") - - // FIXME: this should call d.files.Ping - r, err := d.files.Open(ctx, src) - if err != nil { - return err - } - defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) - - return nil + return d.files.Ping(ctx, src) } // grip implements driver.Grip. diff --git a/drivers/mysql/mysql.go b/drivers/mysql/mysql.go index 462b5f6f9..3e4e69e5c 100644 --- a/drivers/mysql/mysql.go +++ b/drivers/mysql/mysql.go @@ -486,6 +486,11 @@ func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, erro return nil, errw(err) } + cfg.Timeout = driver.OptConnOpenTimeout.Get(src.Options) + // REVISIT: Perhaps allow setting cfg.ReadTimeout and cfg.WriteTimeout? + // - https://github.com/go-sql-driver/mysql#writetimeout + // - https://github.com/go-sql-driver/mysql#readtimeout + if src.Schema != "" { lg.FromContext(ctx).Debug("Setting default schema for MysQL connection", lga.Src, src, diff --git a/drivers/postgres/postgres.go b/drivers/postgres/postgres.go index 39a98a12c..05fe37bb2 100644 --- a/drivers/postgres/postgres.go +++ b/drivers/postgres/postgres.go @@ -209,6 +209,8 @@ func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, erro })) } + dbCfg.ConnConfig.ConnectTimeout = driver.OptConnOpenTimeout.Get(src.Options) + db := stdlib.OpenDB(*dbCfg.ConnConfig, opts...) 
driver.ConfigureDB(ctx, db, src.Options) diff --git a/drivers/sqlserver/sqlserver.go b/drivers/sqlserver/sqlserver.go index f9336c945..d4ac77673 100644 --- a/drivers/sqlserver/sqlserver.go +++ b/drivers/sqlserver/sqlserver.go @@ -9,7 +9,7 @@ import ( "strconv" "strings" - "github.com/xo/dburl" + "github.com/microsoft/go-mssqldb/msdsn" "github.com/neilotoole/sq/libsq/ast" "github.com/neilotoole/sq/libsq/ast/render" @@ -175,15 +175,14 @@ func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, er func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, error) { log := lg.FromContext(ctx) loc := src.Location + cfg, err := msdsn.Parse(loc) + if err != nil { + return nil, errw(err) + } if src.Catalog != "" { - u, err := dburl.Parse(loc) - if err != nil { - return nil, errw(err) - } - vals := u.Query() - vals.Set("database", src.Catalog) - u.RawQuery = vals.Encode() - loc = u.String() + cfg.Database = src.Catalog + loc = cfg.URL().String() + log.Debug("Using catalog as database in connection string", lga.Src, src, lga.Catalog, src.Catalog, @@ -191,6 +190,9 @@ func (d *driveri) doOpen(ctx context.Context, src *source.Source) (*sql.DB, erro ) } + cfg.DialTimeout = driver.OptConnOpenTimeout.Get(src.Options) + loc = cfg.URL().String() + db, err := sql.Open(dbDrvr, loc) if err != nil { return nil, errw(err) diff --git a/drivers/userdriver/userdriver.go b/drivers/userdriver/userdriver.go index 3b029f93d..5bafb609c 100644 --- a/drivers/userdriver/userdriver.go +++ b/drivers/userdriver/userdriver.go @@ -119,20 +119,7 @@ func (d *driveri) ValidateSource(src *source.Source) (*source.Source, error) { // Ping implements driver.Driver. func (d *driveri) Ping(ctx context.Context, src *source.Source) error { - d.log.Debug("Ping source", - lga.Driver, d.typ, - lga.Src, src, - ) - - r, err := d.files.Open(ctx, src) - if err != nil { - return err - } - - // TODO: possibly do something more useful than just - // getting the reader? 
- - return r.Close() + return d.files.Ping(ctx, src) } // grip implements driver.Grip. diff --git a/drivers/xlsx/xlsx.go b/drivers/xlsx/xlsx.go index 95f427977..fb9a91da1 100644 --- a/drivers/xlsx/xlsx.go +++ b/drivers/xlsx/xlsx.go @@ -108,23 +108,7 @@ func (d *Driver) ValidateSource(src *source.Source) (*source.Source, error) { // Ping implements driver.Driver. func (d *Driver) Ping(ctx context.Context, src *source.Source) (err error) { - log := lg.FromContext(ctx) - - r, err := d.files.Open(ctx, src) - if err != nil { - return err - } - - defer lg.WarnIfCloseError(log, lgm.CloseFileReader, r) - - f, err := excelize.OpenReader(r) - if err != nil { - return errz.Err(err) - } - - lg.WarnIfCloseError(log, lgm.CloseFileReader, f) - - return nil + return d.files.Ping(ctx, src) } func errw(err error) error { diff --git a/libsq/core/ioz/httpz/httpz.go b/libsq/core/ioz/httpz/httpz.go index 630c9ab1c..fc8fcab66 100644 --- a/libsq/core/ioz/httpz/httpz.go +++ b/libsq/core/ioz/httpz/httpz.go @@ -257,3 +257,8 @@ func fixPragmaCacheControl(header http.Header) { } func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } + +// StatusText is like http.StatusText, but also includes the code, e.g. "200/OK". 
+func StatusText(code int) string { + return strconv.Itoa(code) + "/" + http.StatusText(code) +} diff --git a/libsq/driver/driver_test.go b/libsq/driver/driver_test.go index b610ffddc..010cf7da6 100644 --- a/libsq/driver/driver_test.go +++ b/libsq/driver/driver_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "os" "testing" "github.com/stretchr/testify/assert" @@ -236,16 +237,20 @@ func TestDriver_Ping(t *testing.T) { testCases := sakila.AllHandles() testCases = append(testCases, sakila.CSVActor, sakila.CSVActorHTTP) + testCases = []string{sakila.CSVActorHTTP} + for _, handle := range testCases { handle := handle t.Run(handle, func(t *testing.T) { + t.Logf("pid: %d", os.Getpid()) tu.SkipShort(t, handle == sakila.XLSX) + tu.DiffOpenFileCount(t, true) + th := testh.New(t) src := th.Source(handle) drvr := th.DriverFor(src) - err := drvr.Ping(th.Context, src) require.NoError(t, err) }) diff --git a/libsq/driver/opts.go b/libsq/driver/opts.go index a35b58fd2..6421fa8d2 100644 --- a/libsq/driver/opts.go +++ b/libsq/driver/opts.go @@ -83,7 +83,7 @@ If n <= 0, connections are not closed due to a connection's age.`, "conn.open-timeout", "", 0, - time.Second*5, + time.Second*2, "Connection open timeout", "Max time to wait before a connection open timeout occurs.", options.TagSource, diff --git a/libsq/source/download.go b/libsq/source/download.go index 28adeaf14..c248e67ab 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -3,6 +3,7 @@ package source import ( "context" "io" + "net/http" "path/filepath" "time" @@ -69,13 +70,7 @@ func (fs *Files) downloadFor(ctx context.Context, src *Source) (*download.Downlo return nil, err } - o := options.Merge(options.FromContext(ctx), src.Options) - c := httpz.NewClient(httpz.DefaultUserAgent, - httpz.OptRequestTimeout(OptHTTPRequestTimeout.Get(o)), - httpz.OptResponseTimeout(OptHTTPResponseTimeout.Get(o)), - httpz.OptInsecureSkipVerify(OptHTTPSInsecureSkipVerify.Get(o)), - ) - + c := 
fs.httpClientFor(ctx, src) if dl, err = download.New(src.Handle, c, src.Location, dlDir); err != nil { return nil, err } @@ -83,6 +78,15 @@ func (fs *Files) downloadFor(ctx context.Context, src *Source) (*download.Downlo return dl, nil } +func (fs *Files) httpClientFor(ctx context.Context, src *Source) *http.Client { + o := options.Merge(options.FromContext(ctx), src.Options) + return httpz.NewClient(httpz.DefaultUserAgent, + httpz.OptRequestTimeout(OptHTTPRequestTimeout.Get(o)), + httpz.OptResponseTimeout(OptHTTPResponseTimeout.Get(o)), + httpz.OptInsecureSkipVerify(OptHTTPSInsecureSkipVerify.Get(o)), + ) +} + // downloadDirFor gets the download cache dir for src. It is not // guaranteed that the returned dir exists or is accessible. func (fs *Files) downloadDirFor(src *Source) (string, error) { @@ -169,7 +173,11 @@ func (fs *Files) openRemoteFile(ctx context.Context, src *Source, checkFresh boo }, } - go dl.Get(ctx, h) + fs.downloadsWg.Add(1) + go func() { + defer fs.downloadsWg.Done() + dl.Get(ctx, h) + }() select { case <-ctx.Done(): diff --git a/libsq/source/files.go b/libsq/source/files.go index ec6259897..c6af5a1b6 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -4,12 +4,15 @@ import ( "context" "io" "log/slog" + "net/http" "os" "path/filepath" "strconv" "sync" "time" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" + "github.com/neilotoole/sq/libsq/core/ioz/download" "github.com/neilotoole/fscache" @@ -49,6 +52,8 @@ type Files struct { // for that source. downloads map[string]*download.Download + downloadsWg *sync.WaitGroup + // fscache is used to cache files, providing convenient access // to multiple readers via Files.newReader. 
fscache *fscache.FSCache @@ -69,12 +74,6 @@ func NewFiles(ctx context.Context, optReg *options.Registry, tmpDir, cacheDir st ) (*Files, error) { log := lg.FromContext(ctx) log.Debug("Creating new Files instance", "tmp_dir", tmpDir, "cache_dir", cacheDir) - if tmpDir == "" { - return nil, errz.Errorf("tmpDir is empty") - } - if cacheDir == "" { - return nil, errz.Errorf("cacheDir is empty") - } if optReg == nil { optReg = &options.Registry{} @@ -88,6 +87,7 @@ func NewFiles(ctx context.Context, optReg *options.Registry, tmpDir, cacheDir st clnup: cleanup.New(), log: lg.FromContext(ctx), downloads: map[string]*download.Download{}, + downloadsWg: &sync.WaitGroup{}, } // We want a unique dir for each execution. Note that fcache is deleted @@ -386,9 +386,51 @@ func (fs *Files) newReader(ctx context.Context, src *Source) (io.ReadCloser, err return r, err } +// Ping implements a ping mechanism for document +// sources (local or remote files). +func (fs *Files) Ping(ctx context.Context, src *Source) error { + fs.mu.Lock() + defer fs.mu.Unlock() + + switch getLocType(src.Location) { + case locTypeLocalFile: + // It's a filepath + if _, err := os.Stat(src.Location); err != nil { + return errz.Wrapf(err, "ping: failed to stat file source %s: %s", src.Handle, src.Location) + } + return nil + case locTypeRemoteFile: + req, err := http.NewRequestWithContext(ctx, http.MethodHead, src.Location, nil) + if err != nil { + return errz.Wrapf(err, "ping: %s", src.Handle) + } + c := fs.httpClientFor(ctx, src) + resp, err := c.Do(req) //nolint:bodyclose + if err != nil { + return errz.Wrapf(err, "ping: %s", src.Handle) + } + defer lg.WarnIfCloseError(fs.log, lgm.CloseHTTPResponseBody, resp.Body) + if resp.StatusCode != http.StatusOK { + return errz.Errorf("ping: %s: expected %s but got %s", + src.Handle, httpz.StatusText(http.StatusOK), httpz.StatusText(resp.StatusCode)) + } + return nil + default: + return errz.Errorf("ping: unsupport location type for source %s: %s", src.Handle, 
src.RedactedLocation()) + } +} + // Close closes any open resources. func (fs *Files) Close() error { - fs.log.Debug("Files.Close invoked: executing clean funcs", lga.Count, fs.clnup.Len()) + fs.mu.Lock() + defer fs.mu.Unlock() + + // FIXME: should we use a timeout here while waiting for downloads? + fs.log.Debug("Files.Close: waiting any downloads to complete") + fs.downloadsWg.Wait() + + fs.log.Debug("Files.Close: executing clean funcs", lga.Count, fs.clnup.Len()) + return fs.clnup.Run() } From 92dad5c3455872156d8d0294be3e337d497db71e Mon Sep 17 00:00:00 2001 From: neilotoole Date: Thu, 11 Jan 2024 16:44:13 -0700 Subject: [PATCH 163/195] files stdin cache close --- CHANGELOG.md | 6 ++++++ libsq/driver/driver_test.go | 6 ------ libsq/source/download.go | 4 ++-- libsq/source/files.go | 16 +++++++++++----- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b00bcd647..35b127280 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Breaking changes are annotated with ☢️, and alpha/beta features with 🐥. +## Upcoming + +### Fixed + +- Open DB connection now correctly honors [`conn.open-timeout`](https://sq.io/docs/config#connopen-timeout). 
+ ## [v0.46.1] - 2023-12-06 ### Fixed diff --git a/libsq/driver/driver_test.go b/libsq/driver/driver_test.go index 010cf7da6..01b7f883b 100644 --- a/libsq/driver/driver_test.go +++ b/libsq/driver/driver_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "os" "testing" "github.com/stretchr/testify/assert" @@ -237,17 +236,12 @@ func TestDriver_Ping(t *testing.T) { testCases := sakila.AllHandles() testCases = append(testCases, sakila.CSVActor, sakila.CSVActorHTTP) - testCases = []string{sakila.CSVActorHTTP} - for _, handle := range testCases { handle := handle t.Run(handle, func(t *testing.T) { - t.Logf("pid: %d", os.Getpid()) tu.SkipShort(t, handle == sakila.XLSX) - tu.DiffOpenFileCount(t, true) - th := testh.New(t) src := th.Source(handle) drvr := th.DriverFor(src) diff --git a/libsq/source/download.go b/libsq/source/download.go index c248e67ab..fbd050978 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -173,9 +173,9 @@ func (fs *Files) openRemoteFile(ctx context.Context, src *Source, checkFresh boo }, } - fs.downloadsWg.Add(1) + fs.fillerWgs.Add(1) go func() { - defer fs.downloadsWg.Done() + defer fs.fillerWgs.Done() dl.Get(ctx, h) }() diff --git a/libsq/source/files.go b/libsq/source/files.go index c6af5a1b6..8b9710374 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -52,7 +52,9 @@ type Files struct { // for that source. downloads map[string]*download.Download - downloadsWg *sync.WaitGroup + // fillerWgs is used to wait for asynchronous filling of the cache + // to complete (including downloads). + fillerWgs *sync.WaitGroup // fscache is used to cache files, providing convenient access // to multiple readers via Files.newReader. 
@@ -87,7 +89,7 @@ func NewFiles(ctx context.Context, optReg *options.Registry, tmpDir, cacheDir st clnup: cleanup.New(), log: lg.FromContext(ctx), downloads: map[string]*download.Download{}, - downloadsWg: &sync.WaitGroup{}, + fillerWgs: &sync.WaitGroup{}, } // We want a unique dir for each execution. Note that fcache is deleted @@ -234,10 +236,12 @@ func (fs *Files) addStdin(ctx context.Context, handle string, f *os.File) error } fs.fscacheEntryMetas[handle] = entryMeta + fs.fillerWgs.Add(1) start := time.Now() pw := progress.NewWriter(ctx, "Reading "+handle, -1, cacheWrtr) ioz.CopyAsync(pw, contextio.NewReader(ctx, f), func(written int64, err error) { + defer fs.fillerWgs.Done() defer lg.WarnIfCloseError(log, lgm.CloseFileReader, f) entryMeta.written = written entryMeta.err = err @@ -394,11 +398,11 @@ func (fs *Files) Ping(ctx context.Context, src *Source) error { switch getLocType(src.Location) { case locTypeLocalFile: - // It's a filepath if _, err := os.Stat(src.Location); err != nil { return errz.Wrapf(err, "ping: failed to stat file source %s: %s", src.Handle, src.Location) } return nil + case locTypeRemoteFile: req, err := http.NewRequestWithContext(ctx, http.MethodHead, src.Location, nil) if err != nil { @@ -415,8 +419,10 @@ func (fs *Files) Ping(ctx context.Context, src *Source) error { src.Handle, httpz.StatusText(http.StatusOK), httpz.StatusText(resp.StatusCode)) } return nil + default: - return errz.Errorf("ping: unsupport location type for source %s: %s", src.Handle, src.RedactedLocation()) + // Shouldn't happen + return errz.Errorf("ping: %s is not a document source", src.Handle) } } @@ -427,7 +433,7 @@ func (fs *Files) Close() error { // FIXME: should we use a timeout here while waiting for downloads? 
fs.log.Debug("Files.Close: waiting any downloads to complete") - fs.downloadsWg.Wait() + fs.fillerWgs.Wait() fs.log.Debug("Files.Close: executing clean funcs", lga.Count, fs.clnup.Len()) From 1a231f62bc7eddf392eef9f59d7d90c74391a542 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 12 Jan 2024 05:30:29 -0700 Subject: [PATCH 164/195] Minor cleanup --- libsq/source/cache.go | 10 +++++++--- libsq/source/files.go | 10 +++++----- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/libsq/source/cache.go b/libsq/source/cache.go index 726617b15..b6fc2464d 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -195,7 +195,9 @@ func (fs *Files) CachePaths(src *Source) (srcCacheDir, cacheDB, checksums string // sourceHash generates a hash for src. The hash is based on the // member fields of src, with special handling for src.Options. // Only the opts that affect data ingestion (options.TagIngestMutate) -// are incorporated in the hash. +// are incorporated in the hash. The generated hash is used to +// determine the cache dir for src. Thus, if a source is mutated +// (e.g. the remote http location changes), a new cache dir results. func (fs *Files) sourceHash(src *Source) string { if src == nil { return "" @@ -208,9 +210,7 @@ func (fs *Files) sourceHash(src *Source) string { buf.WriteString(src.Catalog) buf.WriteString(src.Schema) - // FIXME: Revisit this mUsedKeys := make(map[string]any) - if src.Options != nil { keys := src.Options.Keys() for _, k := range keys { @@ -292,6 +292,10 @@ func (fs *Files) CacheClear(ctx context.Context) error { // CacheSweep sweeps the cache dir, making a best-effort attempt // to remove any empty directories. Note that this operation is // distinct from [Files.CacheClear]. +// +// REVISIT: This doesn't really do anything useful. It should instead +// sweep any abandoned cache dirs, i.e. cache dirs that don't have +// an associated source. 
func (fs *Files) CacheSweep(ctx context.Context) { fs.mu.Lock() defer fs.mu.Unlock() diff --git a/libsq/source/files.go b/libsq/source/files.go index 8b9710374..0e359739b 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -187,7 +187,8 @@ type fscacheEntryMeta struct { // AddStdin copies f to fs's cache: the stdin data in f // is later accessible via fs.Open(src) where src.Handle // is StdinHandle; f's type can be detected via DetectStdinType. -// Note that f is closed by this method. +// Note that f is ultimately closed by a goroutine spawned by +// this method, but may not be closed at the time of return. func (fs *Files) AddStdin(ctx context.Context, f *os.File) error { fs.mu.Lock() defer fs.mu.Unlock() @@ -426,17 +427,16 @@ func (fs *Files) Ping(ctx context.Context, src *Source) error { } } -// Close closes any open resources. +// Close closes any open resources and waits for any goroutines +// to complete. func (fs *Files) Close() error { fs.mu.Lock() defer fs.mu.Unlock() - // FIXME: should we use a timeout here while waiting for downloads? 
- fs.log.Debug("Files.Close: waiting any downloads to complete") + fs.log.Debug("Files.Close: waiting for goroutines to complete") fs.fillerWgs.Wait() fs.log.Debug("Files.Close: executing clean funcs", lga.Count, fs.clnup.Len()) - return fs.clnup.Run() } From e0c6c24b78ebdc528c010ede2caa960e297eeebf Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 12 Jan 2024 08:45:27 -0700 Subject: [PATCH 165/195] Switch to using log.format instead of log.devmode --- CHANGELOG.md | 46 +++++++++++++++- cli/cmd_add.go | 3 +- cli/cmd_cache.go | 2 +- cli/cmd_root.go | 19 +++++++ cli/cmd_version.go | 5 +- cli/cmd_x.go | 5 +- cli/complete.go | 10 ++-- cli/config/config.go | 3 ++ cli/config/yamlstore/yamlstore.go | 3 +- cli/diff/table.go | 3 +- cli/flag/flag.go | 3 ++ cli/flags.go | 2 + cli/logging.go | 87 ++++++++++++++++++++++++++---- cli/options.go | 5 +- cli/output/format/opt.go | 5 ++ cli/run.go | 3 +- cli/terminal_windows.go | 1 - cli/testrun/testrun.go | 3 +- drivers/mysql/errors.go | 3 +- drivers/postgres/errors.go | 3 +- drivers/sqlite3/errors.go | 3 +- drivers/sqlserver/errors.go | 3 +- libsq/core/ioz/httpz/httpz.go | 1 - libsq/core/lg/lga/lga.go | 1 + libsq/core/stringz/stringz_test.go | 3 +- libsq/driver/grips.go | 3 +- libsq/source/cache.go | 13 ++--- libsq/source/download.go | 7 ++- libsq/source/files.go | 7 ++- testh/testh.go | 3 +- 30 files changed, 192 insertions(+), 66 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 35b127280..4eb4a4f96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,9 +9,53 @@ Breaking changes are annotated with ☢️, and alpha/beta features with 🐥. ## Upcoming +### Added + +- Long-running operations (such as data ingestion, or file download) now result + in a progress bar being displayed. Display of the progress bar is controlled + by the new config options [`progress`](https://sq.io/docs/config#progress) + and [`progress.delay`](https://sq.io/docs/config#progressdelay). 
+- Ingested [document sources](https://sq.io/docs/concepts#document-source) (such as + [CSV](https://sq.io/docs/drivers/csv) or [Excel](https://sq.io/docs/drivers/xlsx)) + now make use of an ingest cache DB. Previously, ingestion of document source data occurred + on each `sq` command. It is now a one-time cost; subsequent use of the document source utilizes + the cache DB. If the source document changes, the ingest cache DB is invalidated and + ingested again. This is a massively improved experience for large document sources. +- There's several new commands to interact with the cache. + - [`sq cache enable`](https://sq.io/docs/cmd/cache_enable) and + [`sq cache disable`](https://sq.io/docs/cmd/cache_disable) control cache usage. + You can also instead use the new [`ingest.cache`](https://sq.io/docs/config#ingestcache) + config option. + - [`sq cache clear`](https://sq.io/docs/cmd/cache_clear) clears the cache. + - [`sq cache location`](https://sq.io/docs/cmd/cache_location) prints the cache location on disk. + - [`sq cache stat`](https://sq.io/docs/cmd/cache_stat) shows stats about the cache. + - [`sq cache tree`](https://sq.io/docs/cmd/cache_location) shows a tree view of the cache. +- Downloading of remote document sources (e.g. a CSV file at + [`https://sq.io/testdata/actor.csv`](https://sq.io/testdata/actor.csv)) has been completely + overhauled. Previously, `sq` would re-download the remote file on every command. Now, the + remote file is downloaded and cached locally. Subsequent commands check for staleness of + the cached download, and re-download if necessary. +- As part of the download revamp, new config options have been introduced: + - [`http.request.timeout`](https://sq.io/docs/config#httprequesttimeout) and + [`http.response.timeout`](https://sq.io/docs/config#httpresponsetimeout) control HTTP timeout. 
+ - [`https.insecure-skip-verify`](https://sq.io/docs/config#httpsinsecureskipverify) controls + whether HTTPS connections verify the server's certificate. This is useful for remote files served + with a self-signed certificate. +- There are two more new config options introduced as part of the above work. + - [`cache.lock.timeout`](https://sq.io/docs/config#cachelocktimeout) controls the time that + `sq` will wait for a lock on the cache DB. The cache lock is introduced for when you have + multiple `sq` commands running concurrently, and you want to avoid them stepping on each other. + - Similarly, [`config.lock.timeout`](https://sq.io/docs/config#configlocktimeout) controls the + timeout for acquiring the (newly-introduced) lock on `sq`'s config file. This helps prevent + issues with multiple `sq` processes mutating the config concurrently. +- `sq`'s own [logs](https://sq.io/docs/config#logging) previously outputted in JSON + format. Now there's a new [`log.format`](https://sq.io/docs/config#logformat) config option + that permits setting the log format to `json` or `text`. The `text` format is more human-friendly, and + is now the default. + ### Fixed -- Open DB connection now correctly honors [`conn.open-timeout`](https://sq.io/docs/config#connopen-timeout). +- Opening a DB connection now correctly honors [`conn.open-timeout`](https://sq.io/docs/config#connopen-timeout). 
## [v0.46.1] - 2023-12-06 diff --git a/cli/cmd_add.go b/cli/cmd_add.go index 19df2e5fe..a6c8ca09d 100644 --- a/cli/cmd_add.go +++ b/cli/cmd_add.go @@ -8,8 +8,6 @@ import ( "os" "strings" - "github.com/neilotoole/sq/libsq/core/stringz" - "github.com/spf13/cobra" "golang.org/x/term" @@ -21,6 +19,7 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" ) diff --git a/cli/cmd_cache.go b/cli/cmd_cache.go index 79bf1c819..63d2dacc5 100644 --- a/cli/cmd_cache.go +++ b/cli/cmd_cache.go @@ -147,7 +147,7 @@ func newCacheEnableCmd() *cobra.Command { cmd := &cobra.Command{ Use: "enable", Short: "Enable caching", - Long: `Disable caching. This is equivalent to: + Long: `Enable caching. This is equivalent to: $ sq config set ingest.cache true`, Args: cobra.ExactArgs(0), diff --git a/cli/cmd_root.go b/cli/cmd_root.go index 9978c8bd3..5b4abf856 100644 --- a/cli/cmd_root.go +++ b/cli/cmd_root.go @@ -1,9 +1,12 @@ package cli import ( + "log/slog" + "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/flag" + "github.com/neilotoole/sq/cli/output/format" _ "github.com/neilotoole/sq/drivers" // Load drivers ) @@ -103,7 +106,23 @@ See docs and more: https://sq.io`, cmd.PersistentFlags().String(flag.Config, "", flag.ConfigUsage) cmd.PersistentFlags().Bool(flag.LogEnabled, false, flag.LogEnabledUsage) + panicOn(cmd.RegisterFlagCompletionFunc(flag.LogEnabled, completeBool)) cmd.PersistentFlags().String(flag.LogFile, "", flag.LogFileUsage) + cmd.PersistentFlags().String(flag.LogLevel, "", flag.LogLevelUsage) + panicOn(cmd.RegisterFlagCompletionFunc(flag.LogLevel, completeStrings( + 1, + slog.LevelDebug.String(), + slog.LevelInfo.String(), + slog.LevelWarn.String(), + slog.LevelError.String(), + ))) + + 
cmd.PersistentFlags().String(flag.LogFormat, "", flag.LogFormatUsage) + panicOn(cmd.RegisterFlagCompletionFunc(flag.LogFormat, completeStrings( + 1, + string(format.Text), + string(format.JSON), + ))) return cmd } diff --git a/cli/cmd_version.go b/cli/cmd_version.go index 483c66809..e037a1b93 100644 --- a/cli/cmd_version.go +++ b/cli/cmd_version.go @@ -9,13 +9,12 @@ import ( "strings" "time" - "github.com/neilotoole/sq/cli/buildinfo" - "github.com/neilotoole/sq/cli/hostinfo" - "github.com/spf13/cobra" "golang.org/x/mod/semver" + "github.com/neilotoole/sq/cli/buildinfo" "github.com/neilotoole/sq/cli/flag" + "github.com/neilotoole/sq/cli/hostinfo" "github.com/neilotoole/sq/cli/run" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" diff --git a/cli/cmd_x.go b/cli/cmd_x.go index 738709675..a8b2a73b8 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -7,12 +7,11 @@ import ( "os" "time" - "github.com/neilotoole/sq/cli/flag" - "github.com/neilotoole/sq/libsq/core/errz" - "github.com/spf13/cobra" + "github.com/neilotoole/sq/cli/flag" "github.com/neilotoole/sq/cli/run" + "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/ioz/download" "github.com/neilotoole/sq/libsq/core/ioz/httpz" diff --git a/cli/complete.go b/cli/complete.go index abd0218cd..dc118c430 100644 --- a/cli/complete.go +++ b/cli/complete.go @@ -2,6 +2,7 @@ package cli import ( "context" + "log/slog" "slices" "strings" "time" @@ -45,7 +46,7 @@ var ( ) // completeStrings completes from a slice of string. 
-func completeStrings(max int, a ...string) completionFunc { //nolint:unparam +func completeStrings(max int, a ...string) completionFunc { return func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { if max > 0 && len(args) >= max { return nil, cobra.ShellCompDirectiveNoFileComp @@ -238,11 +239,12 @@ func completeOptValue(cmd *cobra.Command, args []string, toComplete string) ([]s } case LogLevelOpt: - a = []string{"debug", "DEBUG", "info", "INFO", "warn", "WARN", "error", "ERROR"} + a = []string{slog.LevelDebug.String(), slog.LevelInfo.String(), slog.LevelWarn.String(), slog.LevelError.String()} case format.Opt: - if opt.Key() == OptErrorFormat.Key() { + switch opt.Key() { + case OptErrorFormat.Key(), OptLogFormat.Key(): a = []string{string(format.Text), string(format.JSON)} - } else { + default: a = stringz.Strings(format.All()) } case options.Bool: diff --git a/cli/config/config.go b/cli/config/config.go index f16375f44..667e07d3b 100644 --- a/cli/config/config.go +++ b/cli/config/config.go @@ -19,6 +19,9 @@ const ( // EnvarLogLevel is the log level. It maps to a slog.Level. EnvarLogLevel = "SQ_LOG_LEVEL" + // EnvarLogFormat is the log format. It maps to a slog.Level. + EnvarLogFormat = "SQ_LOG_FORMAT" + // EnvarLogEnabled turns logging on or off. 
EnvarLogEnabled = "SQ_LOG" diff --git a/cli/config/yamlstore/yamlstore.go b/cli/config/yamlstore/yamlstore.go index b46fa190f..66d59eecb 100644 --- a/cli/config/yamlstore/yamlstore.go +++ b/cli/config/yamlstore/yamlstore.go @@ -8,12 +8,11 @@ import ( "os" "path/filepath" - "github.com/neilotoole/sq/libsq/core/ioz/lockfile" - "github.com/neilotoole/sq/cli/buildinfo" "github.com/neilotoole/sq/cli/config" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/lockfile" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/options" diff --git a/cli/diff/table.go b/cli/diff/table.go index 2d9954a50..57aa2447c 100644 --- a/cli/diff/table.go +++ b/cli/diff/table.go @@ -4,14 +4,13 @@ import ( "context" "fmt" - "github.com/neilotoole/sq/libsq/driver" - "golang.org/x/sync/errgroup" udiff "github.com/neilotoole/sq/cli/diff/internal/go-udiff" "github.com/neilotoole/sq/cli/diff/internal/go-udiff/myers" "github.com/neilotoole/sq/cli/run" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/metadata" ) diff --git a/cli/flag/flag.go b/cli/flag/flag.go index b99a459a4..6e7d34a6b 100644 --- a/cli/flag/flag.go +++ b/cli/flag/flag.go @@ -175,6 +175,9 @@ const ( LogLevel = "log.level" LogLevelUsage = "Log level: one of DEBUG, INFO, WARN, ERROR" + LogFormat = "log.format" + LogFormatUsage = `Log format: one of "text" or "json"` + DiffOverview = "overview" DiffOverviewShort = "O" DiffOverviewUsage = "Compare source overview" diff --git a/cli/flags.go b/cli/flags.go index a92fb89ce..af8c8db01 100644 --- a/cli/flags.go +++ b/cli/flags.go @@ -71,6 +71,8 @@ func cmdFlagBool(cmd *cobra.Command, name string) bool { // getBootstrapFlagValue parses osArgs looking for flg. 
The flag is always // treated as string. This function exists because some components, such // as logging and config, interrogate flags before cobra has loaded. +// +//nolint:unparam func getBootstrapFlagValue(flg, flgShort, flgUsage string, osArgs []string) (val string, ok bool, err error) { fs := pflag.NewFlagSet("bootstrap", pflag.ContinueOnError) fs.ParseErrorsWhitelist.UnknownFlags = true diff --git a/cli/logging.go b/cli/logging.go index 95764fbb9..d544a1aa0 100644 --- a/cli/logging.go +++ b/cli/logging.go @@ -2,6 +2,7 @@ package cli import ( "context" + "fmt" "io" "log/slog" "os" @@ -13,6 +14,7 @@ import ( "github.com/neilotoole/sq/cli/config" "github.com/neilotoole/sq/cli/flag" + "github.com/neilotoole/sq/cli/output/format" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/devlog" @@ -50,14 +52,21 @@ var ( "Log level, one of: DEBUG, INFO, WARN, ERROR.", ) - OptLogDevMode = options.NewBool( - "log.devmode", + OptLogFormat = format.NewOpt( + "log.format", "", - false, 0, - false, - "Log in devmode", - "Log in devmode.", + format.Text, + func(f format.Format) error { + if f == format.Text || f == format.JSON { + return nil + } + + return errz.Errorf("option {log.format} allows only %q or %q", format.Text, format.JSON) + }, + "Log output format", + fmt.Sprintf( + `Log output format. Allowed formats are %q (human-friendly) or %q.`, format.Text, format.JSON), ) ) @@ -102,15 +111,23 @@ func defaultLogging(ctx context.Context, osArgs []string, cfg *config.Config, } closer = logFile.Close - devMode := OptLogDevMode.Get(cfg.Options) + // Determine if we're logging dev mode (format.Text). 
+ devMode := OptLogFormat.Default() != format.JSON + switch getLogFormat(ctx, osArgs, cfg) { //nolint:exhaustive + case format.Text: + devMode = true + case format.JSON: + devMode = false + default: + // Shouldn't happen + } if devMode { h = devlog.NewHandler(logFile, lvl) } else { h = newJSONHandler(logFile, lvl) } - // h = devlog.NewHandler(logFile, lvl) - // h = newJSONHandler(logFile, lvl) + return slog.New(h), h, closer, nil } @@ -281,6 +298,58 @@ func getLogLevel(ctx context.Context, osArgs []string, cfg *config.Config) slog. return lvl } +// getLogEnabled gets the log format based on flags, envars, or config. +// Any error is logged to the ctx logger. The returned value is guaranteed +// to be one of [format.Text] or [format.JSON]. +func getLogFormat(ctx context.Context, osArgs []string, cfg *config.Config) format.Format { + bootLog := lg.FromContext(ctx) + + val, ok, err := getBootstrapFlagValue(flag.LogFormat, "", flag.LogFormatUsage, osArgs) + if err != nil { + bootLog.Error("Error reading log format from flag", lga.Flag, flag.LogFormat, lga.Err, err) + } + if ok { + bootLog.Debug("Using log format specified via flag", lga.Flag, flag.LogFormat, lga.Val, val) + + f := new(format.Format) + if err = f.UnmarshalText([]byte(val)); err == nil { + switch *f { //nolint:exhaustive + case format.Text, format.JSON: + return *f + default: + } + } + bootLog.Error("Invalid log format specified via flag", + lga.Flag, flag.LogFormat, lga.Val, val, lga.Err, err) + } + + val, ok = os.LookupEnv(config.EnvarLogFormat) + if ok { + bootLog.Debug("Using log level specified via envar", + lga.Env, config.EnvarLogFormat, lga.Val, val) + + f := new(format.Format) + if err = f.UnmarshalText([]byte(val)); err == nil { + switch *f { //nolint:exhaustive + case format.Text, format.JSON: + return *f + default: + } + } + bootLog.Error("Invalid log format specified by envar", + lga.Env, config.EnvarLogLevel, lga.Val, val, lga.Err, err) + } + + var o options.Options + if cfg != nil { + o = 
cfg.Options + } + + f := OptLogFormat.Get(o) + bootLog.Debug("Using log format specified via config", lga.Key, OptLogFormat.Key(), lga.Val, f) + return f +} + // getLogFilePath gets the log file path, based on flags, envars, or config. // If a log file is not specified (and thus logging is disabled), empty string // is returned. diff --git a/cli/options.go b/cli/options.go index a6f24c907..d1659503d 100644 --- a/cli/options.go +++ b/cli/options.go @@ -4,12 +4,11 @@ import ( "fmt" "strings" - "github.com/neilotoole/sq/cli/config" - "github.com/samber/lo" "github.com/spf13/cobra" "github.com/spf13/pflag" + "github.com/neilotoole/sq/cli/config" "github.com/neilotoole/sq/cli/output/xlsxw" "github.com/neilotoole/sq/cli/run" "github.com/neilotoole/sq/drivers/csv" @@ -175,7 +174,7 @@ func RegisterDefaultOpts(reg *options.Registry) { OptLogEnabled, OptLogFile, OptLogLevel, - OptLogDevMode, + OptLogFormat, OptDiffNumLines, OptDiffDataFormat, source.OptHTTPRequestTimeout, diff --git a/cli/output/format/opt.go b/cli/output/format/opt.go index f2ff496dd..b7d6765de 100644 --- a/cli/output/format/opt.go +++ b/cli/output/format/opt.go @@ -77,6 +77,11 @@ func (op Opt) GetAny(o options.Options) any { return op.Get(o) } +// Default returns the default value of op. +func (op Opt) Default() Format { + return op.defaultVal +} + // DefaultAny implements options.Opt. 
func (op Opt) DefaultAny() any { return op.defaultVal diff --git a/cli/run.go b/cli/run.go index 13a95e1df..970f14b3c 100644 --- a/cli/run.go +++ b/cli/run.go @@ -8,8 +8,6 @@ import ( "path/filepath" "time" - "github.com/neilotoole/sq/libsq/core/progress" - "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/config" @@ -32,6 +30,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/slogbuf" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" diff --git a/cli/terminal_windows.go b/cli/terminal_windows.go index 447683c4c..9a33dd6d0 100644 --- a/cli/terminal_windows.go +++ b/cli/terminal_windows.go @@ -4,7 +4,6 @@ import ( "io" "os" - "golang.org/x/sys/windows" "golang.org/x/term" ) diff --git a/cli/testrun/testrun.go b/cli/testrun/testrun.go index 14c10d26a..85b56eea9 100644 --- a/cli/testrun/testrun.go +++ b/cli/testrun/testrun.go @@ -12,8 +12,6 @@ import ( "sync" "testing" - "github.com/neilotoole/sq/testh/tu" - "github.com/stretchr/testify/require" "github.com/neilotoole/sq/cli" @@ -25,6 +23,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/source" + "github.com/neilotoole/sq/testh/tu" ) // TestRun is a helper for testing sq commands. diff --git a/drivers/mysql/errors.go b/drivers/mysql/errors.go index 93e75659d..eae0738fe 100644 --- a/drivers/mysql/errors.go +++ b/drivers/mysql/errors.go @@ -3,11 +3,10 @@ package mysql import ( "errors" - "github.com/neilotoole/sq/libsq/driver" - "github.com/go-sql-driver/mysql" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/driver" ) // errw wraps any error from the db. 
It should be called at diff --git a/drivers/postgres/errors.go b/drivers/postgres/errors.go index 1baa1674c..760cdb1c8 100644 --- a/drivers/postgres/errors.go +++ b/drivers/postgres/errors.go @@ -3,11 +3,10 @@ package postgres import ( "errors" - "github.com/neilotoole/sq/libsq/driver" - "github.com/jackc/pgx/v5/pgconn" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/driver" ) // errw wraps any error from the db. It should be called at diff --git a/drivers/sqlite3/errors.go b/drivers/sqlite3/errors.go index 8ea505655..b4186c959 100644 --- a/drivers/sqlite3/errors.go +++ b/drivers/sqlite3/errors.go @@ -3,9 +3,8 @@ package sqlite3 import ( "strings" - "github.com/neilotoole/sq/libsq/driver" - "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/driver" ) // errw wraps any error from the db. It should be called at diff --git a/drivers/sqlserver/errors.go b/drivers/sqlserver/errors.go index 50c174bd9..4b626a698 100644 --- a/drivers/sqlserver/errors.go +++ b/drivers/sqlserver/errors.go @@ -3,11 +3,10 @@ package sqlserver import ( "errors" - "github.com/neilotoole/sq/libsq/driver" - mssql "github.com/microsoft/go-mssqldb" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/driver" ) // mssql error codes diff --git a/libsq/core/ioz/httpz/httpz.go b/libsq/core/ioz/httpz/httpz.go index fc8fcab66..827a74d15 100644 --- a/libsq/core/ioz/httpz/httpz.go +++ b/libsq/core/ioz/httpz/httpz.go @@ -24,7 +24,6 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" - "github.com/neilotoole/sq/libsq/core/stringz" ) diff --git a/libsq/core/lg/lga/lga.go b/libsq/core/lg/lga/lga.go index 610ac1e87..7ec05b8e4 100644 --- a/libsq/core/lg/lga/lga.go +++ b/libsq/core/lg/lga/lga.go @@ -21,6 +21,7 @@ const ( Dest = "dest" Dir = "dir" Driver = "driver" + Default = "default" DefaultTo = "default_to" Elapsed = "elapsed" Env = "env" diff 
--git a/libsq/core/stringz/stringz_test.go b/libsq/core/stringz/stringz_test.go index 0a1a4d852..d9c5f070c 100644 --- a/libsq/core/stringz/stringz_test.go +++ b/libsq/core/stringz/stringz_test.go @@ -6,13 +6,12 @@ import ( "strings" "testing" - "github.com/neilotoole/sq/libsq/core/errz" - "github.com/samber/lo" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/testh/tu" ) diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index a11355d43..4fd2865fe 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -6,8 +6,6 @@ import ( "sync" "time" - "github.com/neilotoole/sq/libsq/core/progress" - "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" @@ -15,6 +13,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" ) diff --git a/libsq/source/cache.go b/libsq/source/cache.go index b6fc2464d..338b6d65a 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -9,19 +9,16 @@ import ( "strings" "time" - "github.com/neilotoole/sq/libsq/core/ioz/lockfile" - - "github.com/neilotoole/sq/libsq/core/lg/lgm" - + "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" + "github.com/neilotoole/sq/libsq/core/ioz/lockfile" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" - "github.com/neilotoole/sq/libsq/source/drivertype" - - "github.com/neilotoole/sq/libsq/core/errz" - 
"github.com/neilotoole/sq/libsq/core/ioz/checksum" + "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/core/stringz" + "github.com/neilotoole/sq/libsq/source/drivertype" ) // OptCacheLockTimeout is the time allowed to acquire a cache lock. diff --git a/libsq/source/download.go b/libsq/source/download.go index fbd050978..05d326c36 100644 --- a/libsq/source/download.go +++ b/libsq/source/download.go @@ -14,7 +14,6 @@ import ( "github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" - "github.com/neilotoole/sq/libsq/core/options" ) @@ -25,8 +24,8 @@ var OptHTTPRequestTimeout = options.NewDuration( time.Second*10, "HTTP/S request initial response timeout duration", `How long to wait for initial response from a HTTP/S endpoint before -timeout occurs. Reading the body of the response, such as large HTTP file -downloads, is not affected by this option. Example: 500ms or 3s. +timeout occurs. Reading the body of the response, such as a large HTTP file +download, is not affected by this option. Example: 500ms or 3s. Contrast with http.response.timeout.`, options.TagSource, ) @@ -38,7 +37,7 @@ var OptHTTPResponseTimeout = options.NewDuration( 0, "HTTP/S request completion timeout duration", `How long to wait for the entire HTTP transaction to complete. This includes -reading the body of the response, such as large HTTP file downloads. Typically +reading the body of the response, such as a large HTTP file download. Typically this is set to 0, indicating no timeout. 
Contrast with http.request.timeout.`, options.TagSource, ) diff --git a/libsq/source/files.go b/libsq/source/files.go index 0e359739b..515539965 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -11,16 +11,15 @@ import ( "sync" "time" - "github.com/neilotoole/sq/libsq/core/ioz/httpz" - - "github.com/neilotoole/sq/libsq/core/ioz/download" - "github.com/neilotoole/fscache" + "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/ioz/contextio" + "github.com/neilotoole/sq/libsq/core/ioz/download" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" diff --git a/testh/testh.go b/testh/testh.go index 5ba7cc7af..14e9e14b3 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -14,8 +14,6 @@ import ( "testing" "time" - "github.com/neilotoole/sq/testh/tu" - "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -56,6 +54,7 @@ import ( "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/sakila" "github.com/neilotoole/sq/testh/testsrc" + "github.com/neilotoole/sq/testh/tu" ) // defaultDBOpenTimeout is the timeout for tests to open (and ping) their DBs. 
From 34e77341ce72b56502e546731715e6bd344cd82f Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 12 Jan 2024 20:52:01 -0700 Subject: [PATCH 166/195] More refactoring --- .golangci.yml | 2 +- CHANGELOG.md | 3 +- cli/cmd_cache.go | 54 ++-- cli/cmd_diff.go | 2 + cli/cmd_inspect.go | 4 +- cli/cmd_root.go | 4 +- cli/cmd_slq.go | 1 + cli/cmd_x.go | 13 +- cli/config/yamlstore/upgrade.go | 2 +- cli/config/yamlstore/yamlstore.go | 15 +- cli/diff/internal/go-udiff/myers/diff.go | 2 +- cli/output.go | 4 +- cli/run.go | 25 +- cli/testrun/testrun.go | 3 + drivers/sqlite3/sqlite3.go | 46 ++-- drivers/userdriver/xmlud/xmlimport_test.go | 8 +- drivers/xlsx/xlsx_test.go | 2 +- libsq/core/ioz/checksum/checksum.go | 4 +- libsq/core/ioz/download/cache.go | 2 +- libsq/core/ioz/download/download.go | 7 +- libsq/core/ioz/download/http.go | 15 ++ libsq/core/ioz/httpz/httpz.go | 45 ++-- libsq/core/ioz/ioz.go | 7 +- libsq/core/ioz/lockfile/lockfile.go | 3 + libsq/core/lg/devlog/devlog.go | 4 +- libsq/core/options/options.go | 7 +- libsq/driver/driver.go | 16 +- libsq/driver/grip.go | 6 - libsq/driver/grips.go | 281 +++++++++++++++------ libsq/driver/ingest.go | 6 +- libsq/pipeline.go | 14 +- libsq/query_no_src_test.go | 2 +- libsq/query_test.go | 2 +- libsq/source/cache.go | 214 ++++++++++++++-- libsq/source/files.go | 26 +- libsq/source/files_test.go | 10 + libsq/source/internal_test.go | 1 + testh/testh.go | 40 ++- testh/tu/tu.go | 8 + 39 files changed, 654 insertions(+), 256 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index addbcd38b..3d6fcc2c9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -231,7 +231,7 @@ linters-settings: arguments: [ 7 ] # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#comment-spacings - name: comment-spacings - disabled: false + disabled: true arguments: - mypragma - otherpragma diff --git a/CHANGELOG.md b/CHANGELOG.md index 4eb4a4f96..2c141d2de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,8 @@ Breaking changes 
are annotated with ☢️, and alpha/beta features with 🐥. - Long-running operations (such as data ingestion, or file download) now result in a progress bar being displayed. Display of the progress bar is controlled by the new config options [`progress`](https://sq.io/docs/config#progress) - and [`progress.delay`](https://sq.io/docs/config#progressdelay). + and [`progress.delay`](https://sq.io/docs/config#progressdelay). You can also use + the `--no-progress` flag to disable the progress bar. - Ingested [document sources](https://sq.io/docs/concepts#document-source) (such as [CSV](https://sq.io/docs/drivers/csv) or [Excel](https://sq.io/docs/drivers/xlsx)) now make use of an ingest cache DB. Previously, ingestion of document source data occurred diff --git a/cli/cmd_cache.go b/cli/cmd_cache.go index 63d2dacc5..66ed5826f 100644 --- a/cli/cmd_cache.go +++ b/cli/cmd_cache.go @@ -9,7 +9,6 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/driver" - "github.com/neilotoole/sq/libsq/source" ) func newCacheCmd() *cobra.Command { @@ -49,7 +48,7 @@ func newCacheLocationCmd() *cobra.Command { Args: cobra.ExactArgs(0), RunE: execCacheLocation, Example: ` $ sq cache location - /Users/neilotoole/Library/Caches/sq`, + /Users/neilotoole/Library/Caches/sq/f36ac695`, } addTextFormatFlags(cmd) @@ -59,9 +58,8 @@ func newCacheLocationCmd() *cobra.Command { } func execCacheLocation(cmd *cobra.Command, _ []string) error { - dir := source.DefaultCacheDir() ru := run.FromContext(cmd.Context()) - return ru.Writers.Config.CacheLocation(dir) + return ru.Writers.Config.CacheLocation(ru.Files.CacheDir()) } func newCacheInfoCmd() *cobra.Command { @@ -72,7 +70,7 @@ func newCacheInfoCmd() *cobra.Command { Args: cobra.ExactArgs(0), RunE: execCacheInfo, Example: ` $ sq cache stat - /Users/neilotoole/Library/Caches/sq enabled (472.8MB)`, + /Users/neilotoole/Library/Caches/sq/f36ac695 enabled (472.8MB)`, } 
addTextFormatFlags(cmd) @@ -82,8 +80,9 @@ func newCacheInfoCmd() *cobra.Command { } func execCacheInfo(cmd *cobra.Command, _ []string) error { - dir := source.DefaultCacheDir() ru := run.FromContext(cmd.Context()) + dir := ru.Files.CacheDir() + size, err := ioz.DirSize(dir) if err != nil { lg.FromContext(cmd.Context()).Warn("Could not determine cache size", @@ -97,21 +96,42 @@ func execCacheInfo(cmd *cobra.Command, _ []string) error { func newCacheClearCmd() *cobra.Command { cmd := &cobra.Command{ - Use: "clear", - Short: "Clear cache", - Long: "Clear cache. May cause issues if another sq instance is running.", - Args: cobra.ExactArgs(0), - RunE: execCacheClear, - Example: ` $ sq cache clear`, + Use: "clear [@HANDLE]", + Short: "Clear cache", + Long: "Clear cache for source or entire cache.", + Args: cobra.MaximumNArgs(1), + ValidArgsFunction: completeHandle(1), + RunE: execCacheClear, + Example: ` # Clear entire cache + $ sq cache clear + + # Clear cache for @sakila + $ sq cache clear @sakila`, } markCmdRequiresConfigLock(cmd) return cmd } -func execCacheClear(cmd *cobra.Command, _ []string) error { - ru := run.FromContext(cmd.Context()) - return ru.Files.CacheClear(cmd.Context()) +func execCacheClear(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + ru := run.FromContext(ctx) + if len(args) == 0 { + return ru.Files.CacheClearAll(ctx) + } + + src, err := ru.Config.Collection.Get(args[0]) + if err != nil { + return err + } + + unlock, err := ru.Files.CacheLockAcquire(ctx, src) + if err != nil { + return err + } + defer unlock() + + return ru.Files.CacheClearSource(ctx, src, true) } func newCacheTreeCmd() *cobra.Command { @@ -128,13 +148,14 @@ func newCacheTreeCmd() *cobra.Command { $ sq cache tree --size`, } + markCmdRequiresConfigLock(cmd) _ = cmd.Flags().BoolP(flag.CacheTreeSize, flag.CacheTreeSizeShort, false, flag.CacheTreeSizeUsage) return cmd } func execCacheTree(cmd *cobra.Command, _ []string) error { ru := run.FromContext(cmd.Context()) - 
cacheDir := source.DefaultCacheDir() + cacheDir := ru.Files.CacheDir() if !ioz.DirExists(cacheDir) { return nil } @@ -156,6 +177,7 @@ func newCacheEnableCmd() *cobra.Command { }, Example: ` $ sq cache enable`, } + markCmdRequiresConfigLock(cmd) return cmd } diff --git a/cli/cmd_diff.go b/cli/cmd_diff.go index 789fb0910..ee0fe4b04 100644 --- a/cli/cmd_diff.go +++ b/cli/cmd_diff.go @@ -1,6 +1,7 @@ package cli import ( + "github.com/neilotoole/sq/libsq/driver" "github.com/samber/lo" "github.com/spf13/cobra" @@ -189,6 +190,7 @@ The default (3) can be changed via: completeStrings(-1, stringz.Strings(diffFormats)...), )) + addOptionFlag(cmd.Flags(), driver.OptIngestCache) return cmd } diff --git a/cli/cmd_inspect.go b/cli/cmd_inspect.go index 3a6b4789d..712eb0a8a 100644 --- a/cli/cmd_inspect.go +++ b/cli/cmd_inspect.go @@ -5,6 +5,8 @@ import ( "database/sql" "slices" + "github.com/neilotoole/sq/libsq/driver" + "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/flag" @@ -102,7 +104,7 @@ formats both show extensive detail.`, cmd.Flags().String(flag.ActiveSchema, "", flag.ActiveSchemaUsage) panicOn(cmd.RegisterFlagCompletionFunc(flag.ActiveSchema, activeSchemaCompleter{getActiveSourceViaArgs}.complete)) - + addOptionFlag(cmd.Flags(), driver.OptIngestCache) return cmd } diff --git a/cli/cmd_root.go b/cli/cmd_root.go index 5b4abf856..5bafc3970 100644 --- a/cli/cmd_root.go +++ b/cli/cmd_root.go @@ -25,8 +25,8 @@ database table. You can query using sq's own jq-like syntax, or in native SQL. Use "sq inspect" to view schema metadata. Use the "sq tbl" commands -to copy, truncate and drop tables. Use "sq diff" to compare source metadata -and row data. +to copy, truncate and drop tables. Use "sq diff" to compare source +metadata and row data. See docs and more: https://sq.io`, Example: ` # Add Postgres source. 
diff --git a/cli/cmd_slq.go b/cli/cmd_slq.go index 343117795..5e3f42ca8 100644 --- a/cli/cmd_slq.go +++ b/cli/cmd_slq.go @@ -343,6 +343,7 @@ func addQueryCmdFlags(cmd *cobra.Command) { panicOn(cmd.RegisterFlagCompletionFunc(flag.IngestDriver, completeDriverType)) cmd.Flags().Bool(flag.IngestHeader, false, flag.IngestHeaderUsage) + addOptionFlag(cmd.Flags(), driver.OptIngestCache) cmd.Flags().Bool(flag.CSVEmptyAsNull, true, flag.CSVEmptyAsNullUsage) cmd.Flags().String(flag.CSVDelim, flag.CSVDelimDefault, flag.CSVDelimUsage) panicOn(cmd.RegisterFlagCompletionFunc(flag.CSVDelim, completeStrings(-1, csv.NamedDelims()...))) diff --git a/cli/cmd_x.go b/cli/cmd_x.go index a8b2a73b8..0ab9fb385 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -56,14 +56,9 @@ func execXLockSrcCmd(cmd *cobra.Command, args []string) error { } timeout := time.Minute * 20 - lock, err := ru.Files.CacheLockFor(src) - if err != nil { - return err - } - fmt.Fprintf(ru.Out, "Locking cache for source %s with timeout %s for %q [%d]\n\n %s\n\n", - src.Handle, timeout, os.Args[0], os.Getpid(), lock) + ru.Config.Options[source.OptCacheLockTimeout.Key()] = timeout - err = lock.Lock(ctx, timeout) + unlock, err := ru.Files.CacheLockAcquire(ctx, src) if err != nil { return err } @@ -78,9 +73,7 @@ func execXLockSrcCmd(cmd *cobra.Command, args []string) error { } fmt.Fprintf(ru.Out, "Releasing cache lock for %s\n", src.Handle) - if err = lock.Unlock(); err != nil { - return err - } + unlock() fmt.Fprintf(ru.Out, "Cache lock released for %s\n", src.Handle) return nil diff --git a/cli/config/yamlstore/upgrade.go b/cli/config/yamlstore/upgrade.go index 6c46f9ed2..f955903d8 100644 --- a/cli/config/yamlstore/upgrade.go +++ b/cli/config/yamlstore/upgrade.go @@ -60,7 +60,7 @@ func (fs *Store) doUpgrade(ctx context.Context, startVersion, targetVersion stri log.Debug("Config upgrade step successful") } - if err = fs.Write(data); err != nil { + if err = fs.write(data); err != nil { return nil, err } diff --git 
a/cli/config/yamlstore/yamlstore.go b/cli/config/yamlstore/yamlstore.go index 66d59eecb..40b371364 100644 --- a/cli/config/yamlstore/yamlstore.go +++ b/cli/config/yamlstore/yamlstore.go @@ -55,7 +55,7 @@ type Store struct { // Lockfile implements Store.Lockfile. func (fs *Store) Lockfile() (lockfile.Lockfile, error) { - fp := filepath.Join(filepath.Dir(fs.Path), "config.lock.pid") + fp := filepath.Join(filepath.Dir(fs.Path), "config.pid.lock") fp, err := filepath.Abs(fp) if err != nil { return "", errz.Wrap(err, "failed to get abs path for lockfile") @@ -192,20 +192,17 @@ func (fs *Store) Save(_ context.Context, cfg *config.Config) error { return err } - return fs.Write(data) + return fs.write(data) } // Write writes the config bytes to disk. -func (fs *Store) Write(data []byte) error { +func (fs *Store) write(data []byte) error { // It's possible that the parent dir of fs.Path doesn't exist. - dir := filepath.Dir(fs.Path) - err := os.MkdirAll(dir, 0o750) - if err != nil { - return errz.Wrapf(err, "failed to make parent dir of sq config file: %s", dir) + if err := ioz.RequireDir(filepath.Dir(fs.Path)); err != nil { + return errz.Wrapf(err, "failed to make parent dir of config file: %s", filepath.Dir(fs.Path)) } - err = os.WriteFile(fs.Path, data, 0o600) - if err != nil { + if err := os.WriteFile(fs.Path, data, ioz.RWPerms); err != nil { return errz.Wrap(err, "failed to save config file") } diff --git a/cli/diff/internal/go-udiff/myers/diff.go b/cli/diff/internal/go-udiff/myers/diff.go index e29e02744..028e63a0e 100644 --- a/cli/diff/internal/go-udiff/myers/diff.go +++ b/cli/diff/internal/go-udiff/myers/diff.go @@ -11,7 +11,7 @@ import ( diff "github.com/neilotoole/sq/cli/diff/internal/go-udiff" ) -// Sources: +// Grips: // https://blog.jcoglan.com/2017/02/17/the-myers-diff-algorithm-part-3/ // https://www.codeproject.com/Articles/42279/%2FArticles%2F42279%2FInvestigating-Myers-diff-algorithm-Part-1-of-2 diff --git a/cli/output.go b/cli/output.go index 
bffe2e3dd..129e1ba76 100644 --- a/cli/output.go +++ b/cli/output.go @@ -107,8 +107,8 @@ command, sq falls back to "text". Available formats: true, 0, true, - "Progress bar shown for long-running operations", - `Progress bar shown for long-running operations.`, + "Progress bar for long-running operations", + `Progress bar for long-running operations.`, options.TagOutput, ) diff --git a/cli/run.go b/cli/run.go index 970f14b3c..924a9b4e9 100644 --- a/cli/run.go +++ b/cli/run.go @@ -8,6 +8,9 @@ import ( "path/filepath" "time" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" + "github.com/neilotoole/sq/libsq/core/ioz/lockfile" + "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/config" @@ -144,9 +147,27 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { } var err error + // The Files instance may already have been created. If not, create it. if ru.Files == nil { - // The Files instance may already have been created. If not, create it. - ru.Files, err = source.NewFiles(ctx, ru.OptionsRegistry, source.DefaultTempDir(), source.DefaultCacheDir(), true) + var cfgLock lockfile.Lockfile + if cfgLock, err = ru.ConfigStore.Lockfile(); err != nil { + return err + } + cfgLockFunc := source.NewLockFunc(cfgLock, "acquire config lock", config.OptConfigLockTimeout) + + // We use cache and temp dirs with paths based on a hash of the config's + // location. This ensures that multiple sq instances using different + // configs don't share the same cache/temp dir. 
+ sum := checksum.Sum([]byte(ru.ConfigStore.Location())) + + ru.Files, err = source.NewFiles( + ctx, + ru.OptionsRegistry, + cfgLockFunc, + filepath.Join(source.DefaultTempDir(), sum), + filepath.Join(source.DefaultCacheDir(), sum), + true, + ) if err != nil { lg.WarnIfFuncError(log, lga.Cleanup, ru.Cleanup.Run) return err diff --git a/cli/testrun/testrun.go b/cli/testrun/testrun.go index 85b56eea9..7ad3e851a 100644 --- a/cli/testrun/testrun.go +++ b/cli/testrun/testrun.go @@ -12,6 +12,8 @@ import ( "sync" "testing" + "github.com/neilotoole/sq/testh" + "github.com/stretchr/testify/require" "github.com/neilotoole/sq/cli" @@ -116,6 +118,7 @@ func newRun(ctx context.Context, t testing.TB, cfgStore config.Store) (ru *run.R ru.Files, err = source.NewFiles( ctx, ru.OptionsRegistry, + testh.TempLockFunc(t), tu.TempDir(t, false), tu.CacheDir(t, false), true, diff --git a/drivers/sqlite3/sqlite3.go b/drivers/sqlite3/sqlite3.go index da8f94518..2ef66f1a5 100644 --- a/drivers/sqlite3/sqlite3.go +++ b/drivers/sqlite3/sqlite3.go @@ -918,10 +918,11 @@ type grip struct { src *source.Source drvr *driveri - // DEBUG: closeMu and closed exist while debugging close behavior. - // We should be able to get rid of them eventually. - closeMu sync.Mutex - closed bool + // closeOnce and closeErr are used to ensure that Close is only called once. + // This is particularly relevant to sqlite, as calling Close multiple times + // can cause problems on Windows. + closeOnce sync.Once + closeErr error } // DB implements driver.Grip. @@ -1006,20 +1007,15 @@ func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou return md, nil } -// Close implements driver.Grip. +// Close implements driver.Grip. Subsequent calls to Close are no-op and +// return the same error. 
func (g *grip) Close() error { - g.closeMu.Lock() - defer g.closeMu.Unlock() + g.closeOnce.Do(func() { + g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) + g.closeErr = errw(g.db.Close()) + }) - if g.closed { - g.log.Warn("SQLite DB already closed", lga.Src, g.src) - return nil - } - - g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) - err := errw(g.db.Close()) - g.closed = true - return err + return g.closeErr } var _ driver.ScratchSrcFunc = NewScratchSource @@ -1027,28 +1023,26 @@ var _ driver.ScratchSrcFunc = NewScratchSource // NewScratchSource returns a new scratch src. The supplied fpath // must be the absolute path to the location to create the SQLite DB file, // typically in the user cache dir. -// The returned clnup func will delete the file. +// The returned clnup func will delete the dB file. func NewScratchSource(ctx context.Context, fpath string) (src *source.Source, clnup func() error, err error) { log := lg.FromContext(ctx) - - log.Debug("Created sqlite3 scratchdb data file", lga.Path, fpath) - src = &source.Source{ Type: Type, Handle: source.ScratchHandle, Location: Prefix + fpath, } - fn := func() error { - log.Debug("Deleting sqlite3 scratchdb file", lga.Src, src, lga.Path, fpath) - rmErr := errz.Err(os.Remove(fpath)) - if rmErr != nil { - log.Warn("Delete sqlite3 scratchdb file", lga.Err, rmErr) + clnup = func() error { + log.Debug("Delete sqlite3 scratchdb file", lga.Src, src, lga.Path, fpath) + if err := os.Remove(fpath); err != nil { + log.Warn("Delete sqlite3 scratchdb file", lga.Err, err) + return errz.Err(err) } + return nil } - return src, fn, nil + return src, clnup, nil } // PathFromLocation returns the absolute file path from the source location, diff --git a/drivers/userdriver/xmlud/xmlimport_test.go b/drivers/userdriver/xmlud/xmlimport_test.go index cae6b6348..4e21e70ec 100644 --- a/drivers/userdriver/xmlud/xmlimport_test.go +++ b/drivers/userdriver/xmlud/xmlimport_test.go @@ -11,8 +11,6 @@ import ( 
"github.com/neilotoole/sq/drivers/userdriver/xmlud" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/stringz" - "github.com/neilotoole/sq/libsq/source" - "github.com/neilotoole/sq/libsq/source/drivertype" "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/testsrc" @@ -33,8 +31,7 @@ func TestImport_Ppl(t *testing.T) { require.Equal(t, driverPpl, udDef.Name) require.Equal(t, xmlud.Genre, udDef.Genre) - src := &source.Source{Handle: "@ppl_" + stringz.Uniq8(), Type: drivertype.None} - scratchDB, err := th.Sources().OpenScratch(th.Context, src) + scratchDB, err := th.Grips().OpenEphemeral(th.Context) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, scratchDB.Close()) @@ -79,8 +76,7 @@ func TestImport_RSS(t *testing.T) { require.Equal(t, driverRSS, udDef.Name) require.Equal(t, xmlud.Genre, udDef.Genre) - src := &source.Source{Handle: "@rss_" + stringz.Uniq8(), Type: drivertype.None} - scratchDB, err := th.Sources().OpenScratch(th.Context, src) + scratchDB, err := th.Grips().OpenEphemeral(th.Context) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, scratchDB.Close()) diff --git a/drivers/xlsx/xlsx_test.go b/drivers/xlsx/xlsx_test.go index ceb2e9ab1..f5c9b4491 100644 --- a/drivers/xlsx/xlsx_test.go +++ b/drivers/xlsx/xlsx_test.go @@ -162,7 +162,7 @@ func TestOpenFileFormats(t *testing.T) { Location: filepath.Join("testdata", "file_formats", tc.filename), }) - grip, err := th.Sources().Open(th.Context, src) + grip, err := th.Grips().Open(th.Context, src) if tc.wantErr { require.Error(t, err) return diff --git a/libsq/core/ioz/checksum/checksum.go b/libsq/core/ioz/checksum/checksum.go index 1bfaf8b09..34be28855 100644 --- a/libsq/core/ioz/checksum/checksum.go +++ b/libsq/core/ioz/checksum/checksum.go @@ -12,6 +12,8 @@ import ( "strconv" "strings" + "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/errz" ) @@ 
-53,7 +55,7 @@ func Write(w io.Writer, sum Checksum, name string) error { // // See: Write. func WriteFile(path string, sum Checksum, name string) error { - f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600) + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, ioz.RWPerms) if err != nil { return errz.Wrap(err, "write checksum file") } diff --git a/libsq/core/ioz/download/cache.go b/libsq/core/ioz/download/cache.go index f7d7a38f4..46287a26c 100644 --- a/libsq/core/ioz/download/cache.go +++ b/libsq/core/ioz/download/cache.go @@ -281,7 +281,7 @@ func (c *cache) write(ctx context.Context, resp *http.Response, return 0, nil } - cacheFile, err := os.OpenFile(fpBody, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.ModePerm) + cacheFile, err := os.OpenFile(fpBody, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, ioz.RWPerms) if err != nil { return 0, err } diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index 91ceb2773..d4b0244a4 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -15,6 +15,7 @@ import ( "os" "path/filepath" "sync" + "time" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" @@ -314,9 +315,11 @@ func (dl *Download) get(req *http.Request, h Handler) { //nolint:funlen,gocognit // do executes the request. 
func (dl *Download) do(req *http.Request) (*http.Response, error) { - bar := progress.FromContext(req.Context()).NewWaiter(dl.name+": start download", true) + ctx := req.Context() + bar := progress.FromContext(ctx).NewWaiter(dl.name+": start download", true) + start := time.Now() resp, err := dl.c.Do(req) - httpz.Log(req, resp, err) + logResp(resp, time.Since(start), err) bar.Stop() if err != nil { // Download timeout errors are typically wrapped in an url.Error, resulting diff --git a/libsq/core/ioz/download/http.go b/libsq/core/ioz/download/http.go index 0e00649d3..1b815fc8c 100644 --- a/libsq/core/ioz/download/http.go +++ b/libsq/core/ioz/download/http.go @@ -7,6 +7,10 @@ import ( "net/http" "strings" "time" + + "github.com/neilotoole/sq/libsq/core/ioz/httpz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lga" ) // errNoDateHeader indicates that the HTTP headers contained no Date header. @@ -284,3 +288,14 @@ func varyMatches(cachedResp *http.Response, req *http.Request) bool { } return true } + +func logResp(resp *http.Response, elapsed time.Duration, err error) { + ctx := resp.Request.Context() + log := lg.FromContext(ctx).With("response_time", elapsed) + if err != nil { + log.Warn("HTTP request error", lga.Err, err) + return + } + + log.Info("HTTP request completed", lga.Resp, httpz.ResponseLogValue(resp)) +} diff --git a/libsq/core/ioz/httpz/httpz.go b/libsq/core/ioz/httpz/httpz.go index 827a74d15..e1502d3ab 100644 --- a/libsq/core/ioz/httpz/httpz.go +++ b/libsq/core/ioz/httpz/httpz.go @@ -22,8 +22,6 @@ import ( "strconv" "strings" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/stringz" ) @@ -98,49 +96,50 @@ func ResponseLogValue(resp *http.Response) slog.Value { return slog.Value{} } - attrs := []slog.Attr{ - slog.String("proto", resp.Proto), - slog.String("status", resp.Status), + var attrs []slog.Attr + if resp.Request != 
nil { + attrs = append(attrs, + slog.String("method", resp.Request.Method), + slog.String("url", resp.Request.URL.String()), + ) } + attrs = append(attrs, + slog.String("proto", resp.Proto), + slog.String("status", resp.Status)) h := resp.Header + var hAttrs []slog.Attr for k := range h { vals := h.Values(k) if len(vals) == 1 { - attrs = append(attrs, slog.String(k, vals[0])) + hAttrs = append(hAttrs, slog.String(k, vals[0])) continue } - attrs = append(attrs, slog.Any(k, h.Get(k))) + hAttrs = append(hAttrs, slog.Any(k, h.Get(k))) } - if resp.Request != nil { - attrs = append(attrs, slog.Any("req", RequestLogValue(resp.Request))) + if len(hAttrs) > 0 { + attrs = append(attrs, slog.Any("headers", slog.GroupValue(hAttrs...))) } return slog.GroupValue(attrs...) } -// Log logs req, resp, and err via the logger on req.Context(). -func Log(req *http.Request, resp *http.Response, err error) { - log := lg.FromContext(req.Context()).With(lga.Method, req.Method, lga.URL, req.URL) - if err != nil { - log.Warn("HTTP request error", lga.Err, err) - return - } - - log.Debug("HTTP request completed", lga.Resp, ResponseLogValue(resp)) -} - // RequestLogValue implements slog.LogValuer for req. func RequestLogValue(req *http.Request) slog.Value { if req == nil { return slog.Value{} } + p := req.URL.Path + if p == "" { + p = req.URL.RawPath + } + attrs := []slog.Attr{ slog.String("method", req.Method), - slog.String("path", req.URL.RawPath), + slog.String("path", p), } if req.Proto != "" { @@ -257,7 +256,7 @@ func fixPragmaCacheControl(header http.Header) { func badStringError(what, val string) error { return fmt.Errorf("%s %q", what, val) } -// StatusText is like http.StatusText, but also includes the code, e.g. "200/OK". +// StatusText is like http.StatusText, but also includes the code, e.g. "200 OK". 
func StatusText(code int) string { - return strconv.Itoa(code) + "/" + http.StatusText(code) + return strconv.Itoa(code) + " " + http.StatusText(code) } diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 961857e42..be0ac82dd 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -22,6 +22,9 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" ) +// RWPerms is the default file mode used for creating files. +const RWPerms = os.FileMode(0o600) + // Close is a convenience function to close c, logging a warning // if c.Close returns an error. This is useful in defer, e.g. // @@ -348,7 +351,7 @@ func DirSize(path string) (int64, error) { // RequireDir ensures that dir exists and is a directory, creating // it if necessary. func RequireDir(dir string) error { - return errz.Err(os.MkdirAll(dir, 0o750)) + return errz.Err(os.MkdirAll(dir, 0o700)) } // ReadFileToString reads the file at name and returns its contents @@ -423,7 +426,7 @@ func WriteToFile(ctx context.Context, fp string, r io.Reader) (written int64, er return 0, err } - f, err := os.Create(fp) + f, err := os.OpenFile(fp, os.O_RDWR|os.O_CREATE|os.O_TRUNC, RWPerms) if err != nil { return 0, errz.Err(err) } diff --git a/libsq/core/ioz/lockfile/lockfile.go b/libsq/core/ioz/lockfile/lockfile.go index 6d7fc9e50..78921731f 100644 --- a/libsq/core/ioz/lockfile/lockfile.go +++ b/libsq/core/ioz/lockfile/lockfile.go @@ -94,3 +94,6 @@ func (l Lockfile) Unlock() error { func (l Lockfile) String() string { return string(l) } + +// LockFunc is a function that encapsulates locking and unlocking. 
+type LockFunc func(ctx context.Context) (unlock func(), err error) diff --git a/libsq/core/lg/devlog/devlog.go b/libsq/core/lg/devlog/devlog.go index fd4fa020a..4940a41ed 100644 --- a/libsq/core/lg/devlog/devlog.go +++ b/libsq/core/lg/devlog/devlog.go @@ -20,8 +20,8 @@ func NewHandler(w io.Writer, lvl slog.Leveler) slog.Handler { AddSource: true, ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { switch a.Key { - case "pid": - return slog.Attr{} + // case "pid": + // return slog.Attr{} case "error": a.Key = "err" return a diff --git a/libsq/core/options/options.go b/libsq/core/options/options.go index 282933742..f0f5004d0 100644 --- a/libsq/core/options/options.go +++ b/libsq/core/options/options.go @@ -9,7 +9,7 @@ // - New types of Opt can be defined, near where they are used. // // It is noted that these requirements could probably largely be met using -// packages such as spf13/viper. AGain, this is largely an experiment. +// packages such as spf13/viper. Again, this is largely an experiment. package options import ( @@ -27,9 +27,6 @@ type contextKey struct{} // NewContext returns a context that contains the given Options. // Use FromContext to retrieve the Options. -// -// NOTE: It's questionable whether we need to engage in this context -// business with Options. This is a bit of an experiment. func NewContext(ctx context.Context, o Options) context.Context { return context.WithValue(ctx, contextKey{}, o) } @@ -253,7 +250,7 @@ type Processor interface { Process(o Options) (Options, error) } -// DeleteNil deletes any keys with nil values. +// DeleteNil returns a new Options that has any nil values removed. func DeleteNil(o Options) Options { if o == nil { return nil diff --git a/libsq/driver/driver.go b/libsq/driver/driver.go index a696ed0a9..9ad19096d 100644 --- a/libsq/driver/driver.go +++ b/libsq/driver/driver.go @@ -31,7 +31,13 @@ type Provider interface { // Driver is the core interface that must be implemented for each type // of data source. 
type Driver interface { - GripOpener + // Open returns a Grip instance for src. + Open(ctx context.Context, src *source.Source) (Grip, error) + + // Ping verifies that the source is reachable, or returns an error if not. + // The exact behavior of Ping is driver-dependent. Even if Ping does not + // return an error, the source may still be bad for other reasons. + Ping(ctx context.Context, src *source.Source) error // DriverMetadata returns driver metadata. DriverMetadata() Metadata @@ -41,10 +47,6 @@ type Driver interface { // the return value (the original source is not changed). An error // is returned if the source is invalid. ValidateSource(src *source.Source) (*source.Source, error) - - // Ping verifies that the source is reachable, or returns an error if not. - // The exact behavior of Ping() is driver-dependent. - Ping(ctx context.Context, src *source.Source) error } // SQLDriver is implemented by Driver instances for SQL databases. @@ -115,8 +117,8 @@ type SQLDriver interface { // // Note that db must guarantee a single connection: that is, db // must be a sql.Conn or sql.Tx. - PrepareInsertStmt(ctx context.Context, db sqlz.DB, destTbl string, destColNames []string, numRows int) (*StmtExecer, - error) + PrepareInsertStmt(ctx context.Context, db sqlz.DB, destTbl string, destColNames []string, + numRows int) (*StmtExecer, error) // PrepareUpdateStmt prepares a statement for updating destColNames in // destTbl, using the supplied where clause (which may be empty). diff --git a/libsq/driver/grip.go b/libsq/driver/grip.go index a93846efb..f1c642f8a 100644 --- a/libsq/driver/grip.go +++ b/libsq/driver/grip.go @@ -43,12 +43,6 @@ type Grip interface { Close() error } -// GripOpener opens a Grip. -type GripOpener interface { - // Open returns a Grip instance for src. - Open(ctx context.Context, src *source.Source) (Grip, error) -} - // GripOpenIngester opens a Grip via an ingest function. 
type GripOpenIngester interface { // OpenIngest opens a Grip for src by executing ingestFn, which is diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 4fd2865fe..8849efc8d 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -1,11 +1,17 @@ package driver import ( + "bytes" "context" - "strings" + "fmt" + "os" + "path/filepath" "sync" "time" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" + "github.com/neilotoole/sq/libsq/core/stringz" + "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" @@ -13,13 +19,10 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" - "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" ) -var _ GripOpener = (*Grips)(nil) - // ScratchSrcFunc is a function that returns a scratch source. // The caller is responsible for invoking cleanFn. type ScratchSrcFunc func(ctx context.Context, name string) (src *source.Source, cleanFn func() error, err error) @@ -62,7 +65,13 @@ func (gs *Grips) Open(ctx context.Context, src *source.Source) (Grip, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) gs.mu.Lock() defer gs.mu.Unlock() - return gs.doOpen(ctx, src) + + g, err := gs.doOpen(ctx, src) + if err != nil { + return nil, err + } + gs.clnup.AddC(g) + return g, nil } // DriverFor returns the driver for typ. @@ -115,107 +124,135 @@ func (gs *Grips) doOpen(ctx context.Context, src *source.Source) (Grip, error) { return nil, err } - gs.clnup.AddC(grip) - gs.grips[key] = grip return grip, nil } -// OpenScratch returns a scratch database instance. It is not +// OpenEphemeral returns an ephemeral scratch Grip instance. 
It is not // necessary for the caller to close the returned Grip as // its Close method will be invoked by Grips.Close. -func (gs *Grips) OpenScratch(ctx context.Context, src *source.Source) (Grip, error) { - const msgCloseScratch = "Close scratch db" +func (gs *Grips) OpenEphemeral(ctx context.Context) (Grip, error) { + const msgCloseDB = "Close ephemeral db" + gs.mu.Lock() + defer gs.mu.Unlock() log := lg.FromContext(ctx) - cacheDir, srcCacheDBFilepath, _, err := gs.files.CachePaths(src) + dir := filepath.Join(gs.files.TempDir(), fmt.Sprintf("ephemeraldb_%s_%d", stringz.Uniq8(), os.Getpid())) + if err := ioz.RequireDir(dir); err != nil { + return nil, err + } + + clnup := cleanup.New() + clnup.AddE(func() error { + return errz.Wrap(os.RemoveAll(dir), "remove ephemeral db dir") + }) + + fp := filepath.Join(dir, "ephemeral.sqlite.db") + src, cleanFn, err := gs.scratchSrcFn(ctx, fp) if err != nil { + // if err is non-nil, cleanup is guaranteed to be nil return nil, err } + src.Handle = "@ephemeral_" + stringz.Uniq8() - if err = ioz.RequireDir(cacheDir); err != nil { + clnup.AddE(cleanFn) + drvr, err := gs.drvrs.DriverFor(src.Type) + if err != nil { + lg.WarnIfFuncError(log, msgCloseDB, clnup.Run) return nil, err } + var grip Grip + if grip, err = drvr.Open(ctx, src); err != nil { + lg.WarnIfFuncError(log, msgCloseDB, clnup.Run) + return nil, err + } + + g := &cleanOnCloseGrip{ + Grip: grip, + clnup: clnup, + } + gs.clnup.AddC(g) + log.Info("Opened ephemeral db", lga.Src, g.Source()) + return g, nil +} + +func (gs *Grips) openNewCacheGrip(ctx context.Context, src *source.Source) (grip Grip, + cleanFn func() error, err error, +) { + const msgRemoveScratch = "Remove cache db" + log := lg.FromContext(ctx) + + cacheDir, srcCacheDBFilepath, _, err := gs.files.CachePaths(src) + if err != nil { + return nil, nil, err + } + + if err = ioz.RequireDir(cacheDir); err != nil { + return nil, nil, err + } + scratchSrc, cleanFn, err := gs.scratchSrcFn(ctx, srcCacheDBFilepath) if 
err != nil { // if err is non-nil, cleanup is guaranteed to be nil - return nil, err + return nil, nil, err } - log.Debug("Opening scratch src", lga.Src, scratchSrc) + log.Debug("Opening scratch cache src", lga.Src, scratchSrc) backingDrvr, err := gs.drvrs.DriverFor(scratchSrc.Type) if err != nil { - lg.WarnIfFuncError(log, msgCloseScratch, cleanFn) - return nil, err + lg.WarnIfFuncError(log, msgRemoveScratch, cleanFn) + return nil, nil, err } var backingGrip Grip - backingGrip, err = backingDrvr.Open(ctx, scratchSrc) - if err != nil { - lg.WarnIfFuncError(log, msgCloseScratch, cleanFn) - return nil, err - } - - allowCache := OptIngestCache.Get(options.FromContext(ctx)) - if !allowCache { - // If the ingest cache is disabled, we add the cleanup func - // so the scratch DB is deleted when the session ends. - gs.clnup.AddE(cleanFn) + if backingGrip, err = backingDrvr.Open(ctx, scratchSrc); err != nil { + lg.WarnIfFuncError(log, msgRemoveScratch, cleanFn) + // The os.Remove call may be unnecessary, but doesn't hurt. + lg.WarnIfError(log, msgRemoveScratch, os.Remove(srcCacheDBFilepath)) + return nil, nil, err } - return backingGrip, nil + log.Info("Opened new cache db", lga.Src, backingGrip.Source()) + return backingGrip, cleanFn, nil } // OpenIngest implements driver.GripOpenIngester. It opens a Grip, ingesting -// the source into the Grip. If allowCache is true, the ingest cache DB +// the source's data into the Grip. If allowCache is true, the ingest cache DB // is used if possible. If allowCache is false, any existing ingest cache DB // is not utilized, and is overwritten by the ingestion process. func (gs *Grips) OpenIngest(ctx context.Context, src *source.Source, allowCache bool, ingestFn func(ctx context.Context, dest Grip) error, ) (Grip, error) { - // Get the cache lock for src, no matter if we're making - // use of the ingest cache DB or not. We do this to prevent - // another process from overwriting the cache DB while it's - // being written to. 
- lock, err := gs.files.CacheLockFor(src) + log := lg.FromContext(ctx).With(lga.Handle, src.Handle) + ctx = lg.NewContext(ctx, log) + unlock, err := gs.files.CacheLockAcquire(ctx, src) if err != nil { return nil, err } - - lockTimeout := source.OptCacheLockTimeout.Get(options.FromContext(ctx)) - bar := progress.FromContext(ctx).NewTimeoutWaiter( - src.Handle+": acquire lock", - time.Now().Add(lockTimeout), - ) - - err = lock.Lock(ctx, lockTimeout) - bar.Stop() - if err != nil { - return nil, errz.Wrap(err, src.Handle+": acquire cache lock") - } - - defer func() { - if err = lock.Unlock(); err != nil { - lg.FromContext(ctx).Warn("Failed to release cache lock", - lga.Lock, lock, lga.Err, err) - } - }() + defer unlock() if !allowCache || src.Handle == source.StdinHandle { // Note that we can never cache stdin, because it's a stream // that is effectively unique each time. - return gs.openIngestNoCache(ctx, src, ingestFn) + return gs.openIngestGripNoCache(ctx, src, ingestFn) } - return gs.openIngestCache(ctx, src, ingestFn) + return gs.openIngestGripCache(ctx, src, ingestFn) } -func (gs *Grips) openIngestNoCache(ctx context.Context, src *source.Source, +func (gs *Grips) openIngestGripNoCache(ctx context.Context, src *source.Source, ingestFn func(ctx context.Context, destGrip Grip) error, ) (Grip, error) { log := lg.FromContext(ctx) - impl, err := gs.OpenScratch(ctx, src) + + // Clear any existing ingest cache (but don't delete any downloads). + // Note that we have already acquired the cache lock at this point. 
+ if err := gs.files.CacheClearSource(ctx, src, false); err != nil { + return nil, err + } + + impl, cleanFn, err := gs.openNewCacheGrip(ctx, src) if err != nil { return nil, err } @@ -230,25 +267,32 @@ func (gs *Grips) openIngestNoCache(ctx context.Context, src *source.Source, lga.Elapsed, elapsed, lga.Err, err, ) lg.WarnIfCloseError(log, lgm.CloseDB, impl) + lg.WarnIfFuncError(log, "Remove cache DB after failed ingest", cleanFn) return nil, err } + // Because this is a no-cache situation, we need to clear the + // cache db on close. + g := &cleanOnCloseGrip{ + Grip: impl, + clnup: cleanup.New().AddE(cleanFn), + } + log.Info("Ingest completed", - lga.Src, src, lga.Dest, impl.Source(), + lga.Src, src, lga.Dest, g.Source(), lga.Elapsed, elapsed) - return impl, nil + return g, nil } -func (gs *Grips) openIngestCache(ctx context.Context, src *source.Source, +func (gs *Grips) openIngestGripCache(ctx context.Context, src *source.Source, ingestFn func(ctx context.Context, destGrip Grip) error, ) (Grip, error) { - log := lg.FromContext(ctx).With(lga.Handle, src.Handle) - ctx = lg.NewContext(ctx, log) + log := lg.FromContext(ctx) var impl Grip var foundCached bool var err error - if impl, foundCached, err = gs.openCachedFor(ctx, src); err != nil { + if impl, foundCached, err = gs.openCachedGripFor(ctx, src); err != nil { return nil, err } if foundCached { @@ -260,7 +304,8 @@ func (gs *Grips) openIngestCache(ctx context.Context, src *source.Source, log.Debug("Ingest cache MISS: no cache for source", lga.Src, src) - impl, err = gs.OpenScratch(ctx, src) + var cleanFn func() error + impl, cleanFn, err = gs.openNewCacheGrip(ctx, src) if err != nil { return nil, err } @@ -270,27 +315,28 @@ func (gs *Grips) openIngestCache(ctx context.Context, src *source.Source, elapsed := time.Since(start) if err != nil { - log.Error("Ingest failed", - lga.Src, src, lga.Dest, impl.Source(), - lga.Elapsed, elapsed, lga.Err, err, - ) + log.Error("Ingest failed", lga.Src, src, lga.Dest, 
impl.Source(), lga.Elapsed, elapsed, lga.Err, err) lg.WarnIfCloseError(log, lgm.CloseDB, impl) + lg.WarnIfFuncError(log, "Remove cache DB after failed ingest", cleanFn) return nil, err } log.Info("Ingest completed", lga.Src, src, lga.Dest, impl.Source(), lga.Elapsed, elapsed) if err = gs.files.WriteIngestChecksum(ctx, src, impl.Source()); err != nil { - log.Warn("Failed to write checksum for source file; caching not in effect", + log.Error("Failed to write checksum for cache DB", lga.Src, src, lga.Dest, impl.Source(), lga.Err, err) + lg.WarnIfCloseError(log, lgm.CloseDB, impl) + lg.WarnIfFuncError(log, "Remove cache DB after failed ingest checksum write", cleanFn) + return nil, err } return impl, nil } -// openCachedFor returns the cached backing grip for src. +// openCachedGripFor returns the cached backing grip for src. // If not cached, exists returns false. -func (gs *Grips) openCachedFor(ctx context.Context, src *source.Source) (backingGrip Grip, exists bool, err error) { +func (gs *Grips) openCachedGripFor(ctx context.Context, src *source.Source) (backingGrip Grip, exists bool, err error) { var backingSrc *source.Source backingSrc, exists, err = gs.files.CachedBackingSourceFor(ctx, src) if err != nil { @@ -307,26 +353,95 @@ func (gs *Grips) openCachedFor(ctx context.Context, src *source.Source) (backing return backingGrip, true, nil } -// OpenJoin opens an appropriate database for use as +// OpenJoin opens an appropriate Grip for use as // a work DB for joining across sources. // -// Note: There is much work to be done on this method. At this time, only -// two sources are supported. Ultimately OpenJoin should be able to -// inspect the join srcs and use heuristics to determine the best -// location for the join to occur (to minimize copying of data for -// the join etc.). Currently the implementation simply delegates -// to OpenScratch. +// REVISIT: There is much work to be done on this method. 
Ultimately OpenJoin +// should be able to inspect the join srcs and use heuristics to determine +// the best location for the join to occur (to minimize copying of data for +// the join etc.). func (gs *Grips) OpenJoin(ctx context.Context, srcs ...*source.Source) (Grip, error) { - var names []string + const msgCloseJoinDB = "Close join db" + gs.mu.Lock() + defer gs.mu.Unlock() + + log := lg.FromContext(ctx) + + var buf bytes.Buffer for _, src := range srcs { - names = append(names, src.Handle[1:]) + buf.WriteString(src.Handle) + } + + sum := checksum.Sum(buf.Bytes()) + dir := filepath.Join(gs.files.TempDir(), fmt.Sprintf("joindb_%s_%s_%d", sum, stringz.Uniq8(), os.Getpid())) + if err := ioz.RequireDir(dir); err != nil { + return nil, err + } + + clnup := cleanup.New() + clnup.AddE(func() error { + err := errz.Wrap(os.RemoveAll(dir), "remove join db dir") + if err != nil { + lg.FromContext(ctx).Warn("Failed to remove join db dir", lga.Path, dir, lga.Err, err) + return err + } + + lg.FromContext(ctx).Debug("Removed join db dir", lga.Path, dir) + return nil + }) + + fp := filepath.Join(dir, "join.sqlite.db") + joinSrc, cleanFn, err := gs.scratchSrcFn(ctx, fp) + if err != nil { + // if err is non-nil, cleanup is guaranteed to be nil + return nil, err } + joinSrc.Handle = "@join_" + stringz.Uniq8() - lg.FromContext(ctx).Debug("OpenJoin", "sources", strings.Join(names, ",")) - return gs.OpenScratch(ctx, srcs[0]) + clnup.AddE(cleanFn) + drvr, err := gs.drvrs.DriverFor(joinSrc.Type) + if err != nil { + lg.WarnIfFuncError(log, msgCloseJoinDB, clnup.Run) + return nil, err + } + + log.Debug("Opening join db", lga.Path, fp) + var grip Grip + if grip, err = drvr.Open(ctx, joinSrc); err != nil { + lg.WarnIfFuncError(log, msgCloseJoinDB, clnup.Run) + return nil, err + } + + g := &cleanOnCloseGrip{ + Grip: grip, + clnup: clnup, + } + gs.clnup.AddC(g) + return g, nil } -// Close closes d, invoking Close on any instances opened via d.Open. 
+// Close closes gs, invoking any cleanup funcs. func (gs *Grips) Close() error { return gs.clnup.Run() } + +var _ Grip = (*cleanOnCloseGrip)(nil) + +// cleanOnCloseGrip is Grip decorator, invoking clnup after the backing Grip is +// closed, thus permitting arbitrary cleanup on Grip.Close. Subsequent +// invocations of Close are no-ops and return the same error. +type cleanOnCloseGrip struct { + Grip + once sync.Once + closeErr error + clnup *cleanup.Cleanup +} + +// Close implements Grip. It invokes the underlying Grip's Close +// method, and then the closeFn, returning a combined error. +func (g *cleanOnCloseGrip) Close() error { + g.once.Do(func() { + g.closeErr = errz.Append(g.Grip.Close(), g.clnup.Run()) + }) + return g.closeErr +} diff --git a/libsq/driver/ingest.go b/libsq/driver/ingest.go index 49c426052..6388c6afb 100644 --- a/libsq/driver/ingest.go +++ b/libsq/driver/ingest.go @@ -28,11 +28,11 @@ to detect the header.`, // OptIngestCache specifies whether ingested data is cached or not. var OptIngestCache = options.NewBool( "ingest.cache", - "", - false, + "no-cache", + true, 0, true, - "Ingest data is cached", + "Cache ingest data", `Specifies whether ingested data is cached or not. When data is ingested from a document source, it is stored in a cache DB. 
Subsequent uses of that same source will use that cached DB instead of ingesting the data again, unless this diff --git a/libsq/pipeline.go b/libsq/pipeline.go index 7be25a929..b4df1d2a6 100644 --- a/libsq/pipeline.go +++ b/libsq/pipeline.go @@ -18,11 +18,9 @@ import ( "github.com/neilotoole/sq/libsq/core/record" "github.com/neilotoole/sq/libsq/core/sqlmodel" "github.com/neilotoole/sq/libsq/core/sqlz" - "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/core/tablefq" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" - "github.com/neilotoole/sq/libsq/source/drivertype" ) // pipeline is used to execute a SLQ query, @@ -186,16 +184,8 @@ func (p *pipeline) prepareNoTable(ctx context.Context, qm *queryModel) error { if handle == "" { src = p.qc.Collection.Active() if src == nil || !p.qc.Grips.IsSQLSource(src) { - log.Debug("No active SQL source, will use scratchdb.") - // REVISIT: Grips.OpenScratch needs a source, so we just make one up. - ephemeralSrc := &source.Source{ - Type: drivertype.None, - Handle: "@scratch_" + stringz.Uniq8(), - } - - // FIXME: We really want to change the signature of OpenScratch to - // just need a name, not a source. 
- p.targetGrip, err = p.qc.Grips.OpenScratch(ctx, ephemeralSrc) + log.Debug("No active SQL source, will use an ephemeral db.") + p.targetGrip, err = p.qc.Grips.OpenEphemeral(ctx) if err != nil { return err } diff --git a/libsq/query_no_src_test.go b/libsq/query_no_src_test.go index 32cc246c9..dece7ea33 100644 --- a/libsq/query_no_src_test.go +++ b/libsq/query_no_src_test.go @@ -34,7 +34,7 @@ func TestQuery_no_source(t *testing.T) { t.Logf("\nquery: %s\n want: %s", tc.in, tc.want) th := testh.New(t) coll := th.NewCollection() - sources := th.Sources() + sources := th.Grips() qc := &libsq.QueryContext{ Collection: coll, diff --git a/libsq/query_test.go b/libsq/query_test.go index ea1bbeb0c..1a4dbb0d0 100644 --- a/libsq/query_test.go +++ b/libsq/query_test.go @@ -163,7 +163,7 @@ func doExecQueryTestCase(t *testing.T, tc queryTestCase) { require.NoError(t, err) th := testh.New(t) - sources := th.Sources() + sources := th.Grips() qc := &libsq.QueryContext{ Collection: coll, diff --git a/libsq/source/cache.go b/libsq/source/cache.go index 338b6d65a..03137d117 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -9,6 +9,8 @@ import ( "strings" "time" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" @@ -34,6 +36,18 @@ var OptCacheLockTimeout = options.NewDuration( if the lock is already held by another process. If zero, no retry occurs.`, ) +// CacheDir returns the cache dir. It is not guaranteed that the +// returned dir exists. +func (fs *Files) CacheDir() string { + return fs.cacheDir +} + +// TempDir returns the temp dir. It is not guaranteed that the +// returned dir exists. +func (fs *Files) TempDir() string { + return fs.tempDir +} + // CacheDirFor gets the cache dir for handle. It is not guaranteed // that the returned dir exists or is accessible. 
func (fs *Files) CacheDirFor(src *Source) (dir string, err error) { @@ -231,8 +245,8 @@ func (fs *Files) sourceHash(src *Source) string { return sum } -// CacheLockFor returns the lock file for src's cache. -func (fs *Files) CacheLockFor(src *Source) (lockfile.Lockfile, error) { +// cacheLockFor returns the lock file for src's cache. +func (fs *Files) cacheLockFor(src *Source) (lockfile.Lockfile, error) { cacheDir, err := fs.CacheDirFor(src) if err != nil { return "", errz.Wrapf(err, "cache lock for %s", src.Handle) @@ -246,13 +260,88 @@ func (fs *Files) CacheLockFor(src *Source) (lockfile.Lockfile, error) { return lf, nil } -// CacheClear clears the cache dir. This wipes the entire contents -// of the cache dir, so it should be used with caution. Note that -// this operation is distinct from [Files.CacheSweep]. -func (fs *Files) CacheClear(ctx context.Context) error { +// CacheLockAcquire acquires the cache lock for src. The caller must invoke +// the returned unlock func. +func (fs *Files) CacheLockAcquire(ctx context.Context, src *Source) (unlock func(), err error) { + lock, err := fs.cacheLockFor(src) + if err != nil { + return nil, err + } + + lockTimeout := OptCacheLockTimeout.Get(options.FromContext(ctx)) + log := lg.FromContext(ctx).With(lga.Src, src, lga.Timeout, lockTimeout, lga.Lock, lock) + log.Debug("Acquiring cache lock for source") + + bar := progress.FromContext(ctx).NewTimeoutWaiter( + src.Handle+": acquire lock", + time.Now().Add(lockTimeout), + ) + + err = lock.Lock(ctx, lockTimeout) + bar.Stop() + if err != nil { + return nil, errz.Wrap(err, src.Handle+": acquire cache lock") + } + + return func() { + if err = lock.Unlock(); err != nil { + log.Warn("Failed to release cache lock", lga.Err, err) + } + }, nil +} + +// CacheClearAll clears the entire cache dir. +// Note that this operation is distinct from [Files.doCacheSweep]. 
+func (fs *Files) CacheClearAll(ctx context.Context) error { fs.mu.Lock() defer fs.mu.Unlock() + return fs.doCacheClearAll(ctx) +} + +// CacheClearSource clears the ingest cache for src. If arg downloads is true, +// the source's download dir is also cleared. The caller should typically +// first acquire the cache lock for src via [Files.cacheLockFor]. +func (fs *Files) CacheClearSource(ctx context.Context, src *Source, clearDownloads bool) error { + fs.mu.Lock() + defer fs.mu.Unlock() + return fs.doCacheClearSource(ctx, src, clearDownloads) +} + +func (fs *Files) doCacheClearSource(ctx context.Context, src *Source, clearDownloads bool) error { + cacheDir, err := fs.CacheDirFor(src) + if err != nil { + return err + } + + entries, err := os.ReadDir(cacheDir) + if err != nil { + return errz.Wrapf(err, "%s: clear cache", src.Handle) + } + + for _, entry := range entries { + switch entry.Name() { + case "pid.lock": + continue + case "download": + if !clearDownloads { + continue + } + default: + } + + if err = os.RemoveAll(filepath.Join(cacheDir, entry.Name())); err != nil { + return errz.Wrapf(err, "%s: clear cache", src.Handle) + } + } + + lg.FromContext(ctx). + With("clear_downloads", clearDownloads, lga.Src, src, lga.Dir, cacheDir). + Info("Cleared source cache") + return nil +} + +func (fs *Files) doCacheClearAll(ctx context.Context) error { log := lg.FromContext(ctx).With(lga.Dir, fs.cacheDir) log.Debug("Clearing cache dir") if !ioz.DirExists(fs.cacheDir) { @@ -286,20 +375,24 @@ func (fs *Files) CacheClear(ctx context.Context) error { return nil } -// CacheSweep sweeps the cache dir, making a best-effort attempt +// doCacheSweep sweeps the cache dir, making a best-effort attempt // to remove any empty directories. Note that this operation is -// distinct from [Files.CacheClear]. +// distinct from [Files.CacheClearAll]. // -// REVISIT: This doesn't really do anything useful. It should instead -// sweep any abandoned cache dirs, i.e. 
cache dirs that don't have -// an associated source. -func (fs *Files) CacheSweep(ctx context.Context) { - fs.mu.Lock() - defer fs.mu.Unlock() - +// REVISIT: This doesn't really do as much as desired. It should +// also be able to detect orphaned src cache dirs and delete those. +func (fs *Files) doCacheSweep(ctx context.Context) { dir := fs.cacheDir log := lg.FromContext(ctx).With(lga.Dir, dir) - log.Debug("Sweeping cache dir") + log.Debug("Sweep cache dir: acquiring config lock") + + if unlock, err := fs.cfgLockFn(ctx); err != nil { + log.Error("Sweep cache dir: failed to acquire config lock", lga.Lock, fs.cfgLockFn, lga.Err, err) + return + } else { + defer unlock() + } + var count int err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { select { @@ -362,3 +455,92 @@ func DefaultCacheDir() (dir string) { func DefaultTempDir() (dir string) { return filepath.Join(os.TempDir(), "sq") } + +func NewLockFunc(lock lockfile.Lockfile, msg string, timeoutOpt options.Duration) lockfile.LockFunc { + return func(ctx context.Context) (unlock func(), err error) { + lockTimeout := timeoutOpt.Get(options.FromContext(ctx)) + bar := progress.FromContext(ctx).NewTimeoutWaiter( + msg, + time.Now().Add(lockTimeout), + ) + err = lock.Lock(ctx, lockTimeout) + bar.Stop() + if err != nil { + return nil, errz.Wrap(err, msg) + } + return func() { + if err = lock.Unlock(); err != nil { + lg.FromContext(ctx).With(lga.Lock, lock, "for", msg). + Warn("Failed to release lock", lga.Err, err) + } + }, nil + } +} + +// pruneEmptyDirTree prunes empty dirs, and dirs that contain only +// other empty dirs, from the directory tree rooted at dir. Arg dir +// must be an absolute path. 
+func pruneEmptyDirTree(ctx context.Context, dir string) (count int, err error) { + return doPruneEmptyDirTree(ctx, dir, true) +} + +func doPruneEmptyDirTree(ctx context.Context, dir string, isRoot bool) (count int, err error) { + if !filepath.IsAbs(dir) { + return 0, errz.Errorf("dir must be absolute: %s", dir) + } + + select { + case <-ctx.Done(): + return 0, errz.Err(ctx.Err()) + default: + } + + var entries []os.DirEntry + if entries, err = os.ReadDir(dir); err != nil { + return 0, errz.Err(err) + } + + if len(entries) == 0 { + if isRoot { + return 0, nil + } + err = os.RemoveAll(dir) + if err != nil { + return 0, errz.Err(err) + } + return 1, nil + } + + // We've got some entries... let's check what they are. + if countNonDirs(entries) != 0 { + // There are some non-dir entries, so this dir doesn't get deleted. + return 0, nil + } + + // Each of the entries is a dir. Recursively prune. + var n int + for _, entry := range entries { + select { + case <-ctx.Done(): + return count, errz.Err(ctx.Err()) + default: + } + + n, err = doPruneEmptyDirTree(ctx, filepath.Join(dir, entry.Name()), false) + count += n + if err != nil { + return count, err + } + } + + return count, nil +} + +func countNonDirs(entries []os.DirEntry) (count int) { + for _, entry := range entries { + if !entry.IsDir() { + count++ + } + } + return count +} diff --git a/libsq/source/files.go b/libsq/source/files.go index 515539965..d24724b3e 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -11,6 +11,8 @@ import ( "sync" "time" + "github.com/neilotoole/sq/libsq/core/ioz/lockfile" + "github.com/neilotoole/fscache" "github.com/neilotoole/sq/libsq/core/cleanup" @@ -47,6 +49,9 @@ type Files struct { clnup *cleanup.Cleanup optRegistry *options.Registry + // cfgLockFn is the lock func for sq's config. + cfgLockFn lockfile.LockFunc + // downloads is a map of source handles the download.Download // for that source. 
downloads map[string]*download.Download @@ -71,7 +76,11 @@ type Files struct { // NewFiles returns a new Files instance. If cleanFscache is true, the fscache // is cleaned on Files.Close. -func NewFiles(ctx context.Context, optReg *options.Registry, tmpDir, cacheDir string, cleanFscache bool, +func NewFiles(ctx context.Context, + optReg *options.Registry, + cfgLock lockfile.LockFunc, + tmpDir, cacheDir string, + cleanFscache bool, ) (*Files, error) { log := lg.FromContext(ctx) log.Debug("Creating new Files instance", "tmp_dir", tmpDir, "cache_dir", cacheDir) @@ -83,12 +92,13 @@ func NewFiles(ctx context.Context, optReg *options.Registry, tmpDir, cacheDir st fs := &Files{ optRegistry: optReg, cacheDir: cacheDir, - fscacheEntryMetas: make(map[string]*fscacheEntryMeta), + cfgLockFn: cfgLock, tempDir: tmpDir, clnup: cleanup.New(), log: lg.FromContext(ctx), downloads: map[string]*download.Download{}, fillerWgs: &sync.WaitGroup{}, + fscacheEntryMetas: make(map[string]*fscacheEntryMeta), } // We want a unique dir for each execution. Note that fcache is deleted @@ -117,8 +127,11 @@ func NewFiles(ctx context.Context, optReg *options.Registry, tmpDir, cacheDir st } fs.clnup.AddE(fs.fscache.Clean) - // REVISIT: We could automatically sweep the cache dir on Close? 
- // fs.clnup.Add(func() { fs.CacheSweep(ctx) }) + fs.clnup.AddE(func() error { + return errz.Wrap(os.RemoveAll(fs.tempDir), "remove files temp dir") + }) + + fs.clnup.Add(func() { fs.doCacheSweep(ctx) }) return fs, nil } @@ -415,7 +428,7 @@ func (fs *Files) Ping(ctx context.Context, src *Source) error { } defer lg.WarnIfCloseError(fs.log, lgm.CloseHTTPResponseBody, resp.Body) if resp.StatusCode != http.StatusOK { - return errz.Errorf("ping: %s: expected %s but got %s", + return errz.Errorf("ping: %s: expected {%s} but got {%s}", src.Handle, httpz.StatusText(http.StatusOK), httpz.StatusText(resp.StatusCode)) } return nil @@ -435,6 +448,9 @@ func (fs *Files) Close() error { fs.log.Debug("Files.Close: waiting for goroutines to complete") fs.fillerWgs.Wait() + // TODO: Should delete the tmp dir + // TODO: Should sweep the cache + fs.log.Debug("Files.Close: executing clean funcs", lga.Count, fs.clnup.Len()) return fs.clnup.Run() } diff --git a/libsq/source/files_test.go b/libsq/source/files_test.go index c49d46550..e4af7fb47 100644 --- a/libsq/source/files_test.go +++ b/libsq/source/files_test.go @@ -59,6 +59,7 @@ func TestFiles_Type(t *testing.T) { fs, err := source.NewFiles( ctx, nil, + testh.TempLockFunc(t), tu.TempDir(t, true), tu.CacheDir(t, true), true, @@ -107,6 +108,7 @@ func TestFiles_DetectType(t *testing.T) { fs, err := source.NewFiles( ctx, nil, + testh.TempLockFunc(t), tu.TempDir(t, true), tu.CacheDir(t, true), true, @@ -173,6 +175,7 @@ func TestFiles_NewReader(t *testing.T) { fs, err := source.NewFiles( ctx, nil, + testh.TempLockFunc(t), tu.TempDir(t, true), tu.CacheDir(t, true), true, @@ -281,3 +284,10 @@ func TestFiles_Size(t *testing.T) { require.NoError(t, err) require.Equal(t, wantSize, gotSize2) } + +func TestPruneEmptyDirTree(t *testing.T) { + const dir = "/Users/neilotoole/Library/Caches/sq/f36ac695" + count, err := source.PruneEmptyDirTree(context.Background(), dir) + t.Logf("pruned %d empty dirs from: %s", count, dir) + assert.NoError(t, err) +} 
diff --git a/libsq/source/internal_test.go b/libsq/source/internal_test.go index b2524a4b9..d7a8ca72c 100644 --- a/libsq/source/internal_test.go +++ b/libsq/source/internal_test.go @@ -17,6 +17,7 @@ var ( return fs.detectType(ctx, "@test", loc) } GroupsFilterOnlyDirectChildren = groupsFilterOnlyDirectChildren + PruneEmptyDirTree = pruneEmptyDirTree ) func TestParseLoc(t *testing.T) { diff --git a/testh/testh.go b/testh/testh.go index 14e9e14b3..d95640e1b 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -14,6 +14,8 @@ import ( "testing" "time" + "github.com/neilotoole/sq/libsq/core/ioz/lockfile" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -127,6 +129,9 @@ func New(t testing.TB, opts ...Option) *Helper { // situations with running tests in parallel with caching enabled, // due to the fact that caching uses pid-based locking, and parallel tests // share the same pid. + // + // REVISIT: The above statement regarding pid-based locking may no longer + // be applicable, as a new cache dir is created for each test run. o := options.Options{driver.OptIngestCache.Key(): false} h.Context = options.NewContext(h.Context, o) t.Cleanup(h.Close) @@ -163,6 +168,7 @@ func (h *Helper) init() { h.files, err = source.NewFiles( h.Context, optRegistry, + TempLockFunc(h.T), tu.TempDir(h.T, false), tu.TempDir(h.T, false), true, @@ -365,7 +371,7 @@ func (h *Helper) NewCollection(handles ...string) *source.Collection { return coll } -// Open opens a driver.Grip for src via h's internal Sources +// Open opens a driver.Grip for src via h's internal Grips // instance: thus subsequent calls to Open may return the // same driver.Grip instance. The opened driver.Grip will be closed // during h.Close. 
@@ -373,7 +379,7 @@ func (h *Helper) Open(src *source.Source) driver.Grip { ctx, cancelFn := context.WithTimeout(h.Context, h.dbOpenTimeout) defer cancelFn() - grip, err := h.Sources().Open(ctx, src) + grip, err := h.Grips().Open(ctx, src) require.NoError(h.T, err) db, err := grip.DB(ctx) @@ -754,13 +760,13 @@ func (h *Helper) IsMonotable(src *source.Source) bool { return h.DriverFor(src).DriverMetadata().Monotable } -// Sources returns the helper's driver.Grips instance. -func (h *Helper) Sources() *driver.Grips { +// Grips returns the helper's [driver.Grips] instance. +func (h *Helper) Grips() *driver.Grips { h.init() return h.grips } -// Files returns the helper's Files instance. +// Files returns the helper's [source.Files] instance. func (h *Helper) Files() *source.Files { h.init() return h.files @@ -768,7 +774,7 @@ func (h *Helper) Files() *source.Files { // SourceMetadata returns metadata for src. func (h *Helper) SourceMetadata(src *source.Source) (*metadata.Source, error) { - grip, err := h.Sources().Open(h.Context, src) + grip, err := h.Grips().Open(h.Context, src) if err != nil { return nil, err } @@ -778,7 +784,7 @@ func (h *Helper) SourceMetadata(src *source.Source) (*metadata.Source, error) { // TableMetadata returns metadata for src's table. 
func (h *Helper) TableMetadata(src *source.Source, tbl string) (*metadata.Table, error) { - grip, err := h.Sources().Open(h.Context, src) + grip, err := h.Grips().Open(h.Context, src) if err != nil { return nil, err } @@ -886,3 +892,23 @@ func SetBuildVersion(t testing.TB, vers string) { buildinfo.Version = prevVers }) } + +func TempLockfile(t testing.TB) lockfile.Lockfile { + return lockfile.Lockfile(tu.TempFile(t, "pid.lock", false)) +} + +func TempLockFunc(t testing.TB) lockfile.LockFunc { + return func(ctx context.Context) (unlock func(), err error) { + lock := lockfile.Lockfile(tu.TempFile(t, "pid.lock", false)) + timeout := config.OptConfigLockTimeout.Default() + if err = lock.Lock(ctx, timeout); err != nil { + return nil, err + } + + return func() { + if err := lock.Unlock(); err != nil { + t.Logf("failed to release temp lock: %v", err) + } + }, nil + } +} diff --git a/testh/tu/tu.go b/testh/tu/tu.go index 8ed9ef7b0..ed99e5c35 100644 --- a/testh/tu/tu.go +++ b/testh/tu/tu.go @@ -396,6 +396,14 @@ func TempDir(t testing.TB, clean bool) string { return fp } +// TempFile returns the path to a temp file with the given name, in a unique +// temp dir. The file is not created. If arg clean is true, the parent temp +// dir is created via t.TempDir, and thus is deleted on test cleanup. +func TempFile(t testing.TB, name string, clean bool) string { + fp := filepath.Join(TempDir(t, clean), name) + return fp +} + // CacheDir is the standard means for obtaining a cache dir for tests. // If arg clean is true, the cache dir is created via t.TempDir, and // thus is deleted on test cleanup. 
From e2231afc35e1f77e4c919bb809ec13259fd07a6f Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 12 Jan 2024 21:16:20 -0700 Subject: [PATCH 167/195] Files.doCacheClearSource no longer fails if the cache dir is empty --- cli/cmd_inspect_test.go | 4 + cli/run.go | 166 +++++++++++++++++++++++----------------- libsq/source/cache.go | 28 ++----- testh/testh.go | 6 +- 4 files changed, 109 insertions(+), 95 deletions(-) diff --git a/cli/cmd_inspect_test.go b/cli/cmd_inspect_test.go index dbbeea2d8..bbe7b5796 100644 --- a/cli/cmd_inspect_test.go +++ b/cli/cmd_inspect_test.go @@ -281,6 +281,8 @@ func TestCmdInspect_smoke(t *testing.T) { } func TestCmdInspect_stdin(t *testing.T) { + t.Parallel() + testCases := []struct { fpath string wantErr bool @@ -303,6 +305,8 @@ func TestCmdInspect_stdin(t *testing.T) { tc := tc t.Run(tu.Name(tc.fpath), func(t *testing.T) { + t.Parallel() + ctx := context.Background() f, err := os.Open(tc.fpath) // No need to close f require.NoError(t, err) diff --git a/cli/run.go b/cli/run.go index 924a9b4e9..925b5a9bc 100644 --- a/cli/run.go +++ b/cli/run.go @@ -8,6 +8,8 @@ import ( "path/filepath" "time" + "github.com/neilotoole/sq/cli/flag" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/ioz/lockfile" @@ -16,7 +18,6 @@ import ( "github.com/neilotoole/sq/cli/config" "github.com/neilotoole/sq/cli/config/yamlstore" v0_34_0 "github.com/neilotoole/sq/cli/config/yamlstore/upgrades/v0.34.0" //nolint:revive - "github.com/neilotoole/sq/cli/flag" "github.com/neilotoole/sq/cli/run" "github.com/neilotoole/sq/drivers/csv" "github.com/neilotoole/sq/drivers/json" @@ -56,7 +57,7 @@ func getRun(cmd *cobra.Command) *run.Run { // newRun returns a run.Run configured with standard values for logging, // config, etc. This effectively is the bootstrap mechanism for sq. // Note that the run.Run is not fully configured for use by a command -// until preRun is executed on it. 
+// until preRun and FinishRunInit are executed on it. // // Note: This func always returns a Run, even if an error occurs during // bootstrap of the Run (for example if there's a config error). We do this @@ -91,7 +92,6 @@ func newRun(ctx context.Context, stdin *os.File, stdout, stderr io.Writer, args ru.Cleanup = cleanup.New() // FIXME: re-enable log closing ru.LogCloser = logCloser - _ = logCloser if logErr != nil { stderrLog, h := stderrLogger() _ = logbuf.Flush(ctx, h) @@ -122,6 +122,71 @@ func newRun(ctx context.Context, stdin *os.File, stdout, stderr io.Writer, args return ru, log, nil } +// preRun is invoked by cobra prior to the command's RunE being +// invoked. It sets up the driver registry, databases, writers and related +// fundamental components. Subsequent invocations of this method +// are no-op. +func preRun(cmd *cobra.Command, ru *run.Run) error { + if ru == nil { + return errz.New("Run is nil") + } + + if ru.Writers != nil { + // If ru.Writers is already set, then this function has already been + // called on ru. That's ok, just return. + return nil + } + + ctx := cmd.Context() + if ru.Cleanup == nil { + ru.Cleanup = cleanup.New() + } + + // If the --output=/some/file flag is set, then we need to + // override ru.Out (which is typically stdout) to point it at + // the output destination file. 
+ if cmdFlagChanged(ru.Cmd, flag.Output) { + fpath, _ := ru.Cmd.Flags().GetString(flag.Output) + fpath, err := filepath.Abs(fpath) + if err != nil { + return errz.Wrapf(err, "failed to get absolute path for --%s", flag.Output) + } + + // Ensure the parent dir exists + err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm) + if err != nil { + return errz.Wrapf(err, "failed to make parent dir for --%s", flag.Output) + } + + f, err := os.Create(fpath) + if err != nil { + return errz.Wrapf(err, "failed to open file specified by flag --%s", flag.Output) + } + + ru.Cleanup.AddC(f) // Make sure the file gets closed eventually + ru.Out = f + } + + cmdOpts, err := getOptionsFromCmd(ru.Cmd) + if err != nil { + return err + } + ru.Writers, ru.Out, ru.ErrOut = newWriters(ru.Cmd, ru.Cleanup, cmdOpts, ru.Out, ru.ErrOut) + + if err = FinishRunInit(ctx, ru); err != nil { + return err + } + + if cmdRequiresConfigLock(cmd) { + var unlock func() + if unlock, err = lockReloadConfig(cmd); err != nil { + return err + } + ru.Cleanup.Add(unlock) + } + return nil +} + // FinishRunInit finishes setting up ru. // // TODO: This run.Run initialization mechanism is a bit of a mess. @@ -153,7 +218,11 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { if cfgLock, err = ru.ConfigStore.Lockfile(); err != nil { return err } - cfgLockFunc := source.NewLockFunc(cfgLock, "acquire config lock", config.OptConfigLockTimeout) + cfgLockFunc := newProgressLockFunc( + cfgLock, + "acquire config lock", + config.OptConfigLockTimeout.Get(options.FromContext(ctx)), + ) // We use cache and temp dirs with paths based on a hash of the config's // location. This ensures that multiple sq instances using different @@ -218,7 +287,7 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { errs := userdriver.ValidateDriverDef(udd) if len(errs) > 0 { - err := errz.Combine(errs...) + err = errz.Combine(errs...) 
err = errz.Wrapf(err, "failed validation of user driver definition [%d] {%s} from config", i, udd.Name) return err @@ -247,71 +316,6 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { return nil } -// preRun is invoked by cobra prior to the command's RunE being -// invoked. It sets up the driver registry, databases, writers and related -// fundamental components. Subsequent invocations of this method -// are no-op. -func preRun(cmd *cobra.Command, ru *run.Run) error { - if ru == nil { - return errz.New("Run is nil") - } - - if ru.Writers != nil { - // If ru.Writers is already set, then this function has already been - // called on ru. That's ok, just return. - return nil - } - - ctx := cmd.Context() - if ru.Cleanup == nil { - ru.Cleanup = cleanup.New() - } - - // If the --output=/some/file flag is set, then we need to - // override ru.Out (which is typically stdout) to point it at - // the output destination file. - if cmdFlagChanged(ru.Cmd, flag.Output) { - fpath, _ := ru.Cmd.Flags().GetString(flag.Output) - fpath, err := filepath.Abs(fpath) - if err != nil { - return errz.Wrapf(err, "failed to get absolute path for --%s", flag.Output) - } - - // Ensure the parent dir exists - err = os.MkdirAll(filepath.Dir(fpath), os.ModePerm) - if err != nil { - return errz.Wrapf(err, "failed to make parent dir for --%s", flag.Output) - } - - f, err := os.Create(fpath) - if err != nil { - return errz.Wrapf(err, "failed to open file specified by flag --%s", flag.Output) - } - - ru.Cleanup.AddC(f) // Make sure the file gets closed eventually - ru.Out = f - } - - cmdOpts, err := getOptionsFromCmd(ru.Cmd) - if err != nil { - return err - } - ru.Writers, ru.Out, ru.ErrOut = newWriters(ru.Cmd, ru.Cleanup, cmdOpts, ru.Out, ru.ErrOut) - - if err = FinishRunInit(ctx, ru); err != nil { - return err - } - - if cmdRequiresConfigLock(cmd) { - var unlock func() - if unlock, err = lockReloadConfig(cmd); err != nil { - return err - } - ru.Cleanup.Add(unlock) - } - return nil -} - 
// markCmdRequiresConfigLock marks cmd as requiring a config lock. // Thus, before the command's RunE is invoked, the config lock // is acquired (in preRun), and released on cleanup. @@ -389,3 +393,25 @@ func lockReloadConfig(cmd *cobra.Command) (unlock func(), err error) { } }, nil } + +// newProgressLockFunc returns a new lockfile.LockFunc that acquires lock, +// and displays a progress bar while doing so. +func newProgressLockFunc(lock lockfile.Lockfile, msg string, timeout time.Duration) lockfile.LockFunc { + return func(ctx context.Context) (unlock func(), err error) { + bar := progress.FromContext(ctx).NewTimeoutWaiter( + msg, + time.Now().Add(timeout), + ) + err = lock.Lock(ctx, timeout) + bar.Stop() + if err != nil { + return nil, errz.Wrap(err, msg) + } + return func() { + if err = lock.Unlock(); err != nil { + lg.FromContext(ctx).With(lga.Lock, lock, "for", msg). + Warn("Failed to release lock", lga.Err, err) + } + }, nil + } +} diff --git a/libsq/source/cache.go b/libsq/source/cache.go index 03137d117..c30e22840 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -314,6 +314,10 @@ func (fs *Files) doCacheClearSource(ctx context.Context, src *Source, clearDownl return err } + if !ioz.DirExists(cacheDir) { + return nil + } + entries, err := os.ReadDir(cacheDir) if err != nil { return errz.Wrapf(err, "%s: clear cache", src.Handle) } @@ -456,29 +460,9 @@ func DefaultTempDir() (dir string) { return filepath.Join(os.TempDir(), "sq") } -func NewLockFunc(lock lockfile.Lockfile, msg string, timeoutOpt options.Duration) lockfile.LockFunc { - return func(ctx context.Context) (unlock func(), err error) { - lockTimeout := timeoutOpt.Get(options.FromContext(ctx)) - bar := progress.FromContext(ctx).NewTimeoutWaiter( - msg, - time.Now().Add(lockTimeout), - ) - err = lock.Lock(ctx, lockTimeout) - bar.Stop() - if err != nil { - return nil, errz.Wrap(err, msg) - } - return func() { - if err = lock.Unlock(); err != nil { - lg.FromContext(ctx).With(lga.Lock, 
lock, "for", msg). - Warn("Failed to release lock", lga.Err, err) - } - }, nil - } -} - // pruneEmptyDirTree prunes empty dirs, and dirs that contain only -// other empty dirs, from the directory tree rooted at dir. Arg dir +// other empty dirs, from the directory tree rooted at dir. If a dir +// contains at least one non-dir entry, that dir is spared. Arg dir // must be an absolute path. func pruneEmptyDirTree(ctx context.Context, dir string) (count int, err error) { return doPruneEmptyDirTree(ctx, dir, true) diff --git a/testh/testh.go b/testh/testh.go index d95640e1b..b79207d99 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -843,14 +843,14 @@ func mustLoadCollection(ctx context.Context, t testing.TB) *source.Collection { return []byte(proj.Expand(string(data))), nil } - fs := &yamlstore.Store{ + store := &yamlstore.Store{ Path: proj.Rel(testsrc.PathSrcsConfig), OptionsRegistry: &options.Registry{}, HookLoad: hookExpand, } - cli.RegisterDefaultOpts(fs.OptionsRegistry) + cli.RegisterDefaultOpts(store.OptionsRegistry) - cfg, err := fs.Load(ctx) + cfg, err := store.Load(ctx) require.NoError(t, err) require.NotNil(t, cfg) require.NotNil(t, cfg.Collection) From 0badff869a2908fbe3436deb4dbc679ea02fe46f Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 12 Jan 2024 21:17:40 -0700 Subject: [PATCH 168/195] terminal_windows.go issue, again --- cli/terminal_windows.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cli/terminal_windows.go b/cli/terminal_windows.go index 9a33dd6d0..447683c4c 100644 --- a/cli/terminal_windows.go +++ b/cli/terminal_windows.go @@ -4,6 +4,7 @@ import ( "io" "os" + "golang.org/x/sys/windows" "golang.org/x/term" ) From 28e007e84ca0c5cdbe8ff172c51056953c4fa011 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 12 Jan 2024 21:31:32 -0700 Subject: [PATCH 169/195] Removed temp TestPruneEmptyDirTree test --- libsq/core/ioz/ioz.go | 69 +++++++++++++++++++++++++++++++++++ libsq/source/cache.go | 68 ---------------------------------- 
libsq/source/files_test.go | 7 ---- libsq/source/internal_test.go | 1 - 4 files changed, 69 insertions(+), 76 deletions(-) diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index be0ac82dd..e1a503355 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -469,3 +469,72 @@ func (w *writeErrorCloser) Error(err error) { func NewFuncWriteErrorCloser(w io.WriteCloser, fn func(error)) WriteErrorCloser { return &writeErrorCloser{WriteCloser: w, fn: fn} } + +// PruneEmptyDirTree prunes empty dirs, and dirs that contain only +// other empty dirs, from the directory tree rooted at dir. If a dir +// contains at least one non-dir entry, that dir is spared. Arg dir +// must be an absolute path. +func PruneEmptyDirTree(ctx context.Context, dir string) (count int, err error) { + return doPruneEmptyDirTree(ctx, dir, true) +} + +func doPruneEmptyDirTree(ctx context.Context, dir string, isRoot bool) (count int, err error) { + if !filepath.IsAbs(dir) { + return 0, errz.Errorf("dir must be absolute: %s", dir) + } + + select { + case <-ctx.Done(): + return 0, errz.Err(ctx.Err()) + default: + } + + var entries []os.DirEntry + if entries, err = os.ReadDir(dir); err != nil { + return 0, errz.Err(err) + } + + if len(entries) == 0 { + if isRoot { + return 0, nil + } + err = os.RemoveAll(dir) + if err != nil { + return 0, errz.Err(err) + } + return 1, nil + } + + // We've got some entries... let's check what they are. + if countNonDirs(entries) != 0 { + // There are some non-dir entries, so this dir doesn't get deleted. + return 0, nil + } + + // Each of the entries is a dir. Recursively prune. 
+ var n int + for _, entry := range entries { + select { + case <-ctx.Done(): + return count, errz.Err(ctx.Err()) + default: + } + + n, err = doPruneEmptyDirTree(ctx, filepath.Join(dir, entry.Name()), false) + count += n + if err != nil { + return count, err + } + } + + return count, nil +} + +func countNonDirs(entries []os.DirEntry) (count int) { + for _, entry := range entries { + if !entry.IsDir() { + count++ + } + } + return count +} diff --git a/libsq/source/cache.go b/libsq/source/cache.go index c30e22840..cbb7f833e 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -460,71 +460,3 @@ func DefaultTempDir() (dir string) { return filepath.Join(os.TempDir(), "sq") } -// pruneEmptyDirTree prunes empty dirs, and dirs that contain only -// other empty dirs, from the directory tree rooted at dir. If a dir -// contains at least one non-dir entry, that dir is spared. Arg dir -// must be an absolute path. -func pruneEmptyDirTree(ctx context.Context, dir string) (count int, err error) { - return doPruneEmptyDirTree(ctx, dir, true) -} - -func doPruneEmptyDirTree(ctx context.Context, dir string, isRoot bool) (count int, err error) { - if !filepath.IsAbs(dir) { - return 0, errz.Errorf("dir must be absolute: %s", dir) - } - - select { - case <-ctx.Done(): - return 0, errz.Err(ctx.Err()) - default: - } - - var entries []os.DirEntry - if entries, err = os.ReadDir(dir); err != nil { - return 0, errz.Err(err) - } - - if len(entries) == 0 { - if isRoot { - return 0, nil - } - err = os.RemoveAll(dir) - if err != nil { - return 0, errz.Err(err) - } - return 1, nil - } - - // We've got some entries... let's check what they are. - if countNonDirs(entries) != 0 { - // There are some non-dir entries, so this dir doesn't get deleted. - return 0, nil - } - - // Each of the entries is a dir. Recursively prune. 
- var n int - for _, entry := range entries { - select { - case <-ctx.Done(): - return count, errz.Err(ctx.Err()) - default: - } - - n, err = doPruneEmptyDirTree(ctx, filepath.Join(dir, entry.Name()), false) - count += n - if err != nil { - return count, err - } - } - - return count, nil -} - -func countNonDirs(entries []os.DirEntry) (count int) { - for _, entry := range entries { - if !entry.IsDir() { - count++ - } - } - return count -} diff --git a/libsq/source/files_test.go b/libsq/source/files_test.go index e4af7fb47..729ca2324 100644 --- a/libsq/source/files_test.go +++ b/libsq/source/files_test.go @@ -284,10 +284,3 @@ func TestFiles_Size(t *testing.T) { require.NoError(t, err) require.Equal(t, wantSize, gotSize2) } - -func TestPruneEmptyDirTree(t *testing.T) { - const dir = "/Users/neilotoole/Library/Caches/sq/f36ac695" - count, err := source.PruneEmptyDirTree(context.Background(), dir) - t.Logf("pruned %d empty dirs from: %s", count, dir) - assert.NoError(t, err) -} diff --git a/libsq/source/internal_test.go b/libsq/source/internal_test.go index d7a8ca72c..b2524a4b9 100644 --- a/libsq/source/internal_test.go +++ b/libsq/source/internal_test.go @@ -17,7 +17,6 @@ var ( return fs.detectType(ctx, "@test", loc) } GroupsFilterOnlyDirectChildren = groupsFilterOnlyDirectChildren - PruneEmptyDirTree = pruneEmptyDirTree ) func TestParseLoc(t *testing.T) { From 402e20c641b8b3610df682be634186ac522b4fed Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 12 Jan 2024 22:15:28 -0700 Subject: [PATCH 170/195] More sqlite file closing --- drivers/userdriver/xmlud/xmlimport.go | 50 ++++++++++------------ drivers/userdriver/xmlud/xmlimport_test.go | 22 +++++++--- libsq/driver/grips.go | 5 ++- libsq/source/cache.go | 1 - testh/testh.go | 5 ++- 5 files changed, 47 insertions(+), 36 deletions(-) diff --git a/drivers/userdriver/xmlud/xmlimport.go b/drivers/userdriver/xmlud/xmlimport.go index 5d1cc8888..e28883c53 100644 --- a/drivers/userdriver/xmlud/xmlimport.go +++ 
b/drivers/userdriver/xmlud/xmlimport.go @@ -1,6 +1,8 @@ // Package xmlud provides user driver XML import functionality. // Note that this implementation is experimental, not well-tested, // inefficient, possibly incomprehensible, and subject to change. +// +// Also, it's really old, and just generally embarrassing. Don't look. package xmlud import ( @@ -12,6 +14,8 @@ import ( "strconv" "strings" + "github.com/neilotoole/sq/libsq/core/sqlz" + "github.com/neilotoole/sq/drivers/userdriver" "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" @@ -32,8 +36,17 @@ func Import(ctx context.Context, def *userdriver.DriverDef, data io.Reader, dest return errz.Errorf("xmlud.Import does not support genre {%s}", def.Genre) } + log := lg.FromContext(ctx) + db, err := destGrip.DB(ctx) + if err != nil { + return err + } + im := &importer{ - log: lg.FromContext(ctx), + log: log, + destGrip: destGrip, + destDB: db, + data: data, def: def, selStack: newSelStack(), rowStack: newRowStack(), @@ -45,13 +58,12 @@ func Import(ctx context.Context, def *userdriver.DriverDef, data io.Reader, dest msgOnce: map[string]struct{}{}, } - err := im.execImport(ctx, data, destGrip) - err2 := im.clnup.Run() - if err != nil { - return errz.Wrap(err, "xml import") + if err = im.execIngest(ctx); err != nil { + lg.WarnIfFuncError(log, "xml ingest: cleanup", im.clnup.Run) + return errz.Wrap(err, "xml ingest") } - return errz.Wrap(err2, "xml import: cleanup") + return errz.Wrap(im.clnup.Run(), "xml ingest: cleanup") } // importer does the work of importing data from XML. 
@@ -60,6 +72,7 @@ type importer struct { def *userdriver.DriverDef data io.Reader destGrip driver.Grip + destDB sqlz.DB selStack *selStack rowStack *rowStack tblDefs map[string]*sqlmodel.TableDef @@ -86,9 +99,7 @@ type importer struct { msgOnce map[string]struct{} } -func (im *importer) execImport(ctx context.Context, r io.Reader, destGrip driver.Grip) error { //nolint:gocognit - im.data, im.destGrip = r, destGrip - +func (im *importer) execIngest(ctx context.Context) error { //nolint:gocognit err := im.createTables(ctx) if err != nil { return err @@ -429,13 +440,8 @@ func (im *importer) dbInsert(ctx context.Context, row *rowState) error { execInsertFn, ok := im.execInsertFns[cacheKey] if !ok { - db, err := im.destGrip.DB(ctx) - if err != nil { - return err - } - // Nothing cached, prepare the insert statement and insert munge func - stmtExecer, err := im.destGrip.SQLDriver().PrepareInsertStmt(ctx, db, tblName, colNames, 1) + stmtExecer, err := im.destGrip.SQLDriver().PrepareInsertStmt(ctx, im.destDB, tblName, colNames, 1) if err != nil { return err } @@ -506,13 +512,8 @@ func (im *importer) dbUpdate(ctx context.Context, row *rowState) error { cacheKey := "##update_func__" + tblName + "__" + strings.Join(colNames, ",") + whereClause execUpdateFn, ok := im.execUpdateFns[cacheKey] if !ok { - db, err := im.destGrip.DB(ctx) - if err != nil { - return err - } - // Nothing cached, prepare the update statement and munge func - stmtExecer, err := drvr.PrepareUpdateStmt(ctx, db, tblName, colNames, whereClause) + stmtExecer, err := drvr.PrepareUpdateStmt(ctx, im.destDB, tblName, colNames, whereClause) if err != nil { return err } @@ -576,12 +577,7 @@ func (im *importer) createTables(ctx context.Context) error { im.tblDefs[tblDef.Name] = tblDef - db, err := im.destGrip.DB(ctx) - if err != nil { - return err - } - - err = im.destGrip.SQLDriver().CreateTable(ctx, db, tblDef) + err = im.destGrip.SQLDriver().CreateTable(ctx, im.destDB, tblDef) if err != nil { return err } diff 
--git a/drivers/userdriver/xmlud/xmlimport_test.go b/drivers/userdriver/xmlud/xmlimport_test.go index 4e21e70ec..ee816f29f 100644 --- a/drivers/userdriver/xmlud/xmlimport_test.go +++ b/drivers/userdriver/xmlud/xmlimport_test.go @@ -4,6 +4,8 @@ import ( "bytes" "testing" + "github.com/neilotoole/sq/testh/tu" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -22,6 +24,10 @@ const ( ) func TestImport_Ppl(t *testing.T) { + t.Cleanup(func() { + tu.OpenFileCount(t, true) + }) + th := testh.New(t) ext := &config.Ext{} @@ -31,23 +37,27 @@ func TestImport_Ppl(t *testing.T) { require.Equal(t, driverPpl, udDef.Name) require.Equal(t, xmlud.Genre, udDef.Genre) - scratchDB, err := th.Grips().OpenEphemeral(th.Context) + grip, err := th.Grips().OpenEphemeral(th.Context) require.NoError(t, err) t.Cleanup(func() { - assert.NoError(t, scratchDB.Close()) + assert.NoError(t, grip.Close()) }) + tu.OpenFileCount(t, true) + data := proj.ReadFile("drivers/userdriver/xmlud/testdata/people.xml") - err = xmlud.Import(th.Context, udDef, bytes.NewReader(data), scratchDB) + err = xmlud.Import(th.Context, udDef, bytes.NewReader(data), grip) require.NoError(t, err) - srcMeta, err := scratchDB.SourceMetadata(th.Context, false) + tu.OpenFileCount(t, true) + + srcMeta, err := grip.SourceMetadata(th.Context, false) require.NoError(t, err) require.Equal(t, 2, len(srcMeta.Tables)) require.Equal(t, "person", srcMeta.Tables[0].Name) require.Equal(t, "skill", srcMeta.Tables[1].Name) - sink, err := th.QuerySQL(scratchDB.Source(), nil, "SELECT * FROM person") + sink, err := th.QuerySQL(grip.Source(), nil, "SELECT * FROM person") require.NoError(t, err) require.Equal(t, 3, len(sink.Recs)) require.Equal(t, "Nikola", stringz.Val(sink.Recs[0][1])) @@ -56,7 +66,7 @@ func TestImport_Ppl(t *testing.T) { require.Equal(t, int64(i+1), stringz.Val(rec[0])) } - sink, err = th.QuerySQL(scratchDB.Source(), nil, "SELECT * FROM skill") + sink, err = th.QuerySQL(grip.Source(), nil, 
"SELECT * FROM skill") require.NoError(t, err) require.Equal(t, 6, len(sink.Recs)) require.Equal(t, "Electrifying", stringz.Val(sink.Recs[0][2])) diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 8849efc8d..b3df4cb42 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -174,6 +174,7 @@ func (gs *Grips) OpenEphemeral(ctx context.Context) (Grip, error) { } gs.clnup.AddC(g) log.Info("Opened ephemeral db", lga.Src, g.Source()) + gs.grips[g.Source().Handle] = g return g, nil } @@ -441,7 +442,9 @@ type cleanOnCloseGrip struct { // method, and then the closeFn, returning a combined error. func (g *cleanOnCloseGrip) Close() error { g.once.Do(func() { - g.closeErr = errz.Append(g.Grip.Close(), g.clnup.Run()) + err1 := g.Grip.Close() + err2 := g.clnup.Run() + g.closeErr = errz.Append(err1, err2) }) return g.closeErr } diff --git a/libsq/source/cache.go b/libsq/source/cache.go index cbb7f833e..060d32d55 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -459,4 +459,3 @@ func DefaultCacheDir() (dir string) { func DefaultTempDir() (dir string) { return filepath.Join(os.TempDir(), "sq") } - diff --git a/testh/testh.go b/testh/testh.go index b79207d99..8a999211c 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -182,7 +182,10 @@ func (h *Helper) init() { }) h.grips = driver.NewGrips(h.registry, h.files, sqlite3.NewScratchSource) - h.Cleanup.AddC(h.grips) + // h.Cleanup.AddC(h.grips) + h.Cleanup.AddE(func() error { + return h.grips.Close() + }) h.registry.AddProvider(sqlite3.Type, &sqlite3.Provider{Log: log}) h.registry.AddProvider(postgres.Type, &postgres.Provider{Log: log}) From d2a6840dc271d4df2c3dff2e186538edf521b7a9 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 12 Jan 2024 22:29:13 -0700 Subject: [PATCH 171/195] More sqlite file closing, but it's already working --- libsq/driver/grips.go | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 
b3df4cb42..187181592 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -71,6 +71,7 @@ func (gs *Grips) Open(ctx context.Context, src *source.Source) (Grip, error) { return nil, err } gs.clnup.AddC(g) + gs.grips[src.Handle] = g return g, nil } @@ -97,15 +98,9 @@ func (gs *Grips) IsSQLSource(src *source.Source) bool { return false } -func (gs *Grips) getKey(src *source.Source) string { - return src.Handle -} - func (gs *Grips) doOpen(ctx context.Context, src *source.Source) (Grip, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) - key := gs.getKey(src) - - grip, ok := gs.grips[key] + grip, ok := gs.grips[src.Handle] if ok { return grip, nil } @@ -124,7 +119,6 @@ func (gs *Grips) doOpen(ctx context.Context, src *source.Source) (Grip, error) { return nil, err } - gs.grips[key] = grip return grip, nil } @@ -418,6 +412,7 @@ func (gs *Grips) OpenJoin(ctx context.Context, srcs ...*source.Source) (Grip, er clnup: clnup, } gs.clnup.AddC(g) + gs.grips[g.Source().Handle] = g return g, nil } From 738751ab90b71d4042390a87db75b2371016907e Mon Sep 17 00:00:00 2001 From: neilotoole Date: Fri, 12 Jan 2024 23:06:55 -0700 Subject: [PATCH 172/195] sqlite type test cleanup --- drivers/sqlite3/db_type_test.go | 1 - drivers/sqlite3/sqlite3.go | 3 +- drivers/sqlite3/testdata/type_test.ddl | 40 +++++++++++------------ libsq/driver/driver.go | 3 ++ libsq/source/cache.go | 44 ++++---------------------- 5 files changed, 31 insertions(+), 60 deletions(-) diff --git a/drivers/sqlite3/db_type_test.go b/drivers/sqlite3/db_type_test.go index 4024e41c6..4b2951867 100644 --- a/drivers/sqlite3/db_type_test.go +++ b/drivers/sqlite3/db_type_test.go @@ -148,7 +148,6 @@ func createTypeTestTbls(th *testh.Helper, src *source.Source, nTimes int, withDa } insertStmt := fmt.Sprintf(insertTpl, actualTblName, strings.Join(typeTestColNames, ", "), placeholders) - for _, insertVals := range typeTestVals { _, err = db.Exec(insertStmt, insertVals...) 
require.NoError(t, err) diff --git a/drivers/sqlite3/sqlite3.go b/drivers/sqlite3/sqlite3.go index 2ef66f1a5..2650da3ca 100644 --- a/drivers/sqlite3/sqlite3.go +++ b/drivers/sqlite3/sqlite3.go @@ -1034,7 +1034,8 @@ func NewScratchSource(ctx context.Context, fpath string) (src *source.Source, cl clnup = func() error { log.Debug("Delete sqlite3 scratchdb file", lga.Src, src, lga.Path, fpath) - if err := os.Remove(fpath); err != nil { + lg.WarnIfError(log, "Delete sqlite3 db journal file", os.RemoveAll(filepath.Join(fpath, ".db-journal"))) + if err := os.RemoveAll(fpath); err != nil { log.Warn("Delete sqlite3 scratchdb file", lga.Err, err) return errz.Err(err) } diff --git a/drivers/sqlite3/testdata/type_test.ddl b/drivers/sqlite3/testdata/type_test.ddl index b943667fd..68f2d23b4 100644 --- a/drivers/sqlite3/testdata/type_test.ddl +++ b/drivers/sqlite3/testdata/type_test.ddl @@ -1,22 +1,22 @@ CREATE TABLE type_test ( - col_id INTEGER NOT NULL PRIMARY KEY, - col_int INT NOT NULL, - col_int_n INT, - col_double REAL NOT NULL, - col_double_n REAL, - col_boolean BOOLEAN DEFAULT FALSE NOT NULL, - col_boolean_n BOOLEAN, - col_text TEXT NOT NULL, - col_text_n TEXT, - col_blob BLOB NOT NULL, - col_blob_n BLOB, - col_datetime DATETIME DEFAULT '1970-01-01 00:00:00' NOT NULL, - col_datetime_n DATETIME, - col_date DATE DEFAULT '1970-01-01' NOT NULL, - col_date_n DATE, - col_time TIME NOT NULL, - col_time_n TIME, - col_decimal DECIMAL DEFAULT 0 NOT NULL, - col_decimal_n DECIMAL -) \ No newline at end of file + col_id INTEGER NOT NULL PRIMARY KEY, + col_int INT NOT NULL, + col_int_n INT, + col_double REAL NOT NULL, + col_double_n REAL, + col_boolean BOOLEAN DEFAULT FALSE NOT NULL, + col_boolean_n BOOLEAN, + col_text TEXT NOT NULL, + col_text_n TEXT, + col_blob BLOB NOT NULL, + col_blob_n BLOB, + col_datetime DATETIME DEFAULT '1970-01-01 00:00:00' NOT NULL, + col_datetime_n DATETIME, + col_date DATE DEFAULT '1970-01-01' NOT NULL, + col_date_n DATE, + col_time TIME NOT NULL, + 
col_time_n TIME, + col_decimal DECIMAL(10,5) DEFAULT 0 NOT NULL, + col_decimal_n DECIMAL(10,5) +) diff --git a/libsq/driver/driver.go b/libsq/driver/driver.go index 9ad19096d..e9b33fcc6 100644 --- a/libsq/driver/driver.go +++ b/libsq/driver/driver.go @@ -218,6 +218,9 @@ type Metadata struct { // OpeningPing is a standardized mechanism to ping db using // driver.OptConnOpenTimeout. This should be invoked by each SQL // driver impl in its Open method. If the ping fails, db is closed. +// +// REVISIT: now that driver.OptConnOpenTimeout is applied to each DB +// in each driver impl, do we still need OpeningPing? func OpeningPing(ctx context.Context, src *source.Source, db *sql.DB) error { o := options.FromContext(ctx) timeout := OptConnOpenTimeout.Get(o) diff --git a/libsq/source/cache.go b/libsq/source/cache.go index 060d32d55..c070fed19 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -386,8 +386,7 @@ func (fs *Files) doCacheClearAll(ctx context.Context) error { // REVISIT: This doesn't really do as much as desired. It should // also be able to detect orphaned src cache dirs and delete those. 
func (fs *Files) doCacheSweep(ctx context.Context) { - dir := fs.cacheDir - log := lg.FromContext(ctx).With(lga.Dir, dir) + log := lg.FromContext(ctx).With(lga.Dir, fs.cacheDir) log.Debug("Sweep cache dir: acquiring config lock") if unlock, err := fs.cfgLockFn(ctx); err != nil { @@ -397,44 +396,13 @@ func (fs *Files) doCacheSweep(ctx context.Context) { defer unlock() } - var count int - err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - if err != nil { - log.Warn("Problem sweeping cache dir", lga.Path, path, lga.Err, err) - return nil - } - - if !info.IsDir() { - return nil - } - - files, err := os.ReadDir(path) - if err != nil { - log.Warn("Problem reading dir", lga.Dir, path, lga.Err, err) - return nil - } - - if len(files) != 0 { - return nil - } - - err = os.Remove(path) - if err != nil { - log.Warn("Problem removing empty dir", lga.Dir, path, lga.Err, err) - } - count++ - - return nil - }) + count, err := ioz.PruneEmptyDirTree(ctx, fs.cacheDir) if err != nil { - log.Warn("Problem sweeping cache dir", lga.Dir, dir, lga.Err, err) + log.Warn("Problem sweeping cache dir", lga.Err, err, "deleted_dirs", count) + return } - log.Info("Swept cache dir", lga.Dir, dir, lga.Count, count) + + log.Info("Swept cache dir", "deleted_dirs", count) } // DefaultCacheDir returns the sq cache dir. 
This is generally From 17a6944f48245c493da58193b545383d21876538 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 13 Jan 2024 07:10:15 -0700 Subject: [PATCH 173/195] Improved logging --- cli/cli.go | 2 +- cli/cmd_cache.go | 98 ++++++++++++++++++++-------- cli/logging.go | 16 +++++ libsq/core/ioz/download/cache.go | 3 +- libsq/core/ioz/download/http.go | 7 +- libsq/core/lg/devlog/devlog.go | 2 - libsq/core/lg/devlog/tint/handler.go | 66 +++++++++++++++---- 7 files changed, 147 insertions(+), 47 deletions(-) diff --git a/cli/cli.go b/cli/cli.go index c81948780..eb7046d13 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -242,7 +242,7 @@ func newCommandTree(ru *run.Run) (rootCmd *cobra.Command) { cacheCmd := addCmd(ru, rootCmd, newCacheCmd()) addCmd(ru, cacheCmd, newCacheLocationCmd()) - addCmd(ru, cacheCmd, newCacheInfoCmd()) + addCmd(ru, cacheCmd, newCacheStatCmd()) addCmd(ru, cacheCmd, newCacheEnableCmd()) addCmd(ru, cacheCmd, newCacheDisableCmd()) addCmd(ru, cacheCmd, newCacheClearCmd()) diff --git a/cli/cmd_cache.go b/cli/cmd_cache.go index 66ed5826f..129d725dd 100644 --- a/cli/cmd_cache.go +++ b/cli/cmd_cache.go @@ -1,6 +1,7 @@ package cli import ( + "github.com/neilotoole/sq/libsq/core/options" "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/flag" @@ -28,10 +29,14 @@ func newCacheCmd() *cobra.Command { $ sq cache enable + $ sq cache enable @sakila + $ sq cache disable $ sq cache clear + $ sq cache clear @sakila + # Print tree view of cache dir. 
$ sq cache tree`, } @@ -46,7 +51,10 @@ func newCacheLocationCmd() *cobra.Command { Short: "Print cache location", Long: "Print cache location.", Args: cobra.ExactArgs(0), - RunE: execCacheLocation, + RunE: func(cmd *cobra.Command, args []string) error { + ru := run.FromContext(cmd.Context()) + return ru.Writers.Config.CacheLocation(ru.Files.CacheDir()) + }, Example: ` $ sq cache location /Users/neilotoole/Library/Caches/sq/f36ac695`, } @@ -57,29 +65,25 @@ func newCacheLocationCmd() *cobra.Command { return cmd } -func execCacheLocation(cmd *cobra.Command, _ []string) error { - ru := run.FromContext(cmd.Context()) - return ru.Writers.Config.CacheLocation(ru.Files.CacheDir()) -} - -func newCacheInfoCmd() *cobra.Command { +func newCacheStatCmd() *cobra.Command { cmd := &cobra.Command{ Use: "stat", Short: "Show cache info", Long: "Show cache info, including location and size.", Args: cobra.ExactArgs(0), - RunE: execCacheInfo, + RunE: execCacheStat, Example: ` $ sq cache stat /Users/neilotoole/Library/Caches/sq/f36ac695 enabled (472.8MB)`, } + markCmdRequiresConfigLock(cmd) addTextFormatFlags(cmd) cmd.Flags().BoolP(flag.JSON, flag.JSONShort, false, flag.JSONUsage) cmd.Flags().BoolP(flag.YAML, flag.YAMLShort, false, flag.YAMLUsage) return cmd } -func execCacheInfo(cmd *cobra.Command, _ []string) error { +func execCacheStat(cmd *cobra.Command, _ []string) error { ru := run.FromContext(cmd.Context()) dir := ru.Files.CacheDir() @@ -164,36 +168,76 @@ func execCacheTree(cmd *cobra.Command, _ []string) error { return ioz.PrintTree(ru.Out, cacheDir, showSize, !ru.Writers.Printing.IsMonochrome()) } -func newCacheEnableCmd() *cobra.Command { +func newCacheEnableCmd() *cobra.Command { //nolint:dupl cmd := &cobra.Command{ - Use: "enable", - Short: "Enable caching", - Long: `Enable caching. 
This is equivalent to: - - $ sq config set ingest.cache true`, - Args: cobra.ExactArgs(0), + Use: "enable [@HANDLE]", + Short: "Enable caching", + Long: `Enable caching by default or for a specific source.`, + Args: cobra.MaximumNArgs(1), + ValidArgsFunction: completeHandle(1), RunE: func(cmd *cobra.Command, args []string) error { - return execConfigSet(cmd, []string{driver.OptIngestCache.Key(), "true"}) + ru := run.FromContext(cmd.Context()) + var o options.Options + + if len(args) == 0 { + o = ru.Config.Options + } else { + src, err := ru.Config.Collection.Get(args[0]) + if err != nil { + return err + } + if src.Options == nil { + src.Options = options.Options{} + } + o = src.Options + } + + o[driver.OptIngestCache.Key()] = true + return ru.ConfigStore.Save(cmd.Context(), ru.Config) }, - Example: ` $ sq cache enable`, + Example: ` # Enable caching by default + $ sq cache enable + + # Enable caching for a particular source + $ sq cache enable @sakila`, } markCmdRequiresConfigLock(cmd) return cmd } -func newCacheDisableCmd() *cobra.Command { +func newCacheDisableCmd() *cobra.Command { //nolint:dupl cmd := &cobra.Command{ - Use: "disable", - Short: "Disable caching", - Long: `Disable caching. 
This is equivalent to: - - $ sq config set ingest.cache false`, - Args: cobra.ExactArgs(0), + Use: "disable [@HANDLE]", + Short: "Disable caching", + Long: `Disable caching by default or for a specific source.`, + Args: cobra.MaximumNArgs(1), + ValidArgsFunction: completeHandle(1), RunE: func(cmd *cobra.Command, args []string) error { - return execConfigSet(cmd, []string{driver.OptIngestCache.Key(), "false"}) + ru := run.FromContext(cmd.Context()) + var o options.Options + + if len(args) == 0 { + o = ru.Config.Options + } else { + src, err := ru.Config.Collection.Get(args[0]) + if err != nil { + return err + } + if src.Options == nil { + src.Options = options.Options{} + } + o = src.Options + } + + o[driver.OptIngestCache.Key()] = false + return ru.ConfigStore.Save(cmd.Context(), ru.Config) }, - Example: ` $ sq cache disable`, + Example: ` # Disable caching by default + $ sq cache disable + + # Disable caching for a particular source + $ sq cache disable @sakila`, } markCmdRequiresConfigLock(cmd) diff --git a/cli/logging.go b/cli/logging.go index d544a1aa0..bee8bbb4d 100644 --- a/cli/logging.go +++ b/cli/logging.go @@ -5,11 +5,14 @@ import ( "fmt" "io" "log/slog" + "net/http" "os" "path/filepath" "strconv" "strings" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" + "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/config" @@ -149,6 +152,7 @@ func newJSONHandler(w io.Writer, lvl slog.Leveler) slog.Handler { func slogReplaceAttrs(groups []string, a slog.Attr) slog.Attr { a = slogReplaceSource(groups, a) a = slogReplaceDuration(groups, a) + a = slogReplaceHTTPResponse(groups, a) return a } @@ -175,6 +179,18 @@ func slogReplaceDuration(_ []string, a slog.Attr) slog.Attr { return a } +// slogReplaceDuration prints the friendly version of duration. 
+func slogReplaceHTTPResponse(_ []string, a slog.Attr) slog.Attr { + resp, ok := a.Value.Any().(*http.Response) + if !ok { + return a + } + + v := httpz.ResponseLogValue(resp) + a.Value = v + return a +} + // logFrom is a convenience function for getting a *slog.Logger from a // *cobra.Command or context.Context. // If no logger present, lg.Discard() is returned. diff --git a/libsq/core/ioz/download/cache.go b/libsq/core/ioz/download/cache.go index 46287a26c..1d5560cb8 100644 --- a/libsq/core/ioz/download/cache.go +++ b/libsq/core/ioz/download/cache.go @@ -14,7 +14,6 @@ import ( "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/ioz/contextio" - "github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" @@ -265,7 +264,7 @@ func (c *cache) write(ctx context.Context, resp *http.Response, return 0, err } - log.Debug("Writing HTTP response to cache", lga.Dir, c.dir, lga.Resp, httpz.ResponseLogValue(resp)) + log.Debug("Writing HTTP response to cache", lga.Dir, c.dir, lga.Resp, resp) fpHeader, fpBody, _ := c.paths(resp.Request) headerBytes, err := httputil.DumpResponse(resp, false) diff --git a/libsq/core/ioz/download/http.go b/libsq/core/ioz/download/http.go index 1b815fc8c..562daebbd 100644 --- a/libsq/core/ioz/download/http.go +++ b/libsq/core/ioz/download/http.go @@ -8,7 +8,6 @@ import ( "strings" "time" - "github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" ) @@ -291,11 +290,13 @@ func varyMatches(cachedResp *http.Response, req *http.Request) bool { func logResp(resp *http.Response, elapsed time.Duration, err error) { ctx := resp.Request.Context() - log := lg.FromContext(ctx).With("response_time", elapsed) + log := lg.FromContext(ctx). 
+ With("response_time", elapsed, lga.Method, resp.Request.Method, lga.URL, resp.Request.URL.String()) if err != nil { log.Warn("HTTP request error", lga.Err, err) return } - log.Info("HTTP request completed", lga.Resp, httpz.ResponseLogValue(resp)) + log.Info("HTTP request completed", lga.Resp, resp) + log.Warn("this is a warning") // FIXME: delete } diff --git a/libsq/core/lg/devlog/devlog.go b/libsq/core/lg/devlog/devlog.go index 4940a41ed..dd49185d9 100644 --- a/libsq/core/lg/devlog/devlog.go +++ b/libsq/core/lg/devlog/devlog.go @@ -20,8 +20,6 @@ func NewHandler(w io.Writer, lvl slog.Leveler) slog.Handler { AddSource: true, ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { switch a.Key { - // case "pid": - // return slog.Attr{} case "error": a.Key = "err" return a diff --git a/libsq/core/lg/devlog/tint/handler.go b/libsq/core/lg/devlog/tint/handler.go index 8c5d2bedc..63250b85f 100644 --- a/libsq/core/lg/devlog/tint/handler.go +++ b/libsq/core/lg/devlog/tint/handler.go @@ -53,11 +53,14 @@ Color support on Windows can be added by using e.g. the [go-colorable] package. 
package tint import ( + "bytes" "context" "encoding" "fmt" "io" "log/slog" + "net/http" + "net/http/httputil" "path/filepath" "runtime" "strconv" @@ -72,7 +75,6 @@ import ( // ANSI modes // See: https://gist.github.com/JBlond/2fea43a3049b38287e5e9cefc87b2124 const ( - ansiAttr = "\033[36;2m" ansiBlue = "\033[34m" ansiBrightBlue = "\033[94m" ansiBrightGreen = "\033[92m" @@ -83,12 +85,21 @@ const ( ansiBrightRedFaint = "\033[91;2m" ansiBrightYellow = "\033[93m" ansiFaint = "\033[2m" - ansiReset = "\033[0m" - ansiResetFaint = "\033[22m" - ansiStack = "\033[0;35m" ansiYellowBold = "\033[1;33m" - ansiStackErr = ansiYellowBold - ansiStackErrType = ansiBrightGreenFaint + ansiYellow = "\033[33m" + ansiPurpleBold = "\033[1;35m" + + ansiReset = "\033[0m" + ansiResetFaint = "\033[22m" + + ansiAttr = "\033[36;2m" + ansiStack = "\033[0;35m" + ansiStackErr = ansiYellowBold + ansiStackErrType = ansiBrightGreenFaint + ansiDebug = ansiBrightGreen + ansiInfo = ansiYellow + ansiWarn = ansiPurpleBold + ansiError = ansiBrightRedBold ) const errKey = "err" @@ -234,13 +245,13 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { msgColor := ansiBrightGreen switch r.Level { case slog.LevelDebug: - msgColor = ansiBrightGreen + msgColor = ansiDebug case slog.LevelWarn: - msgColor = ansiBrightYellow + msgColor = ansiWarn case slog.LevelError: msgColor = ansiBrightRedBold case slog.LevelInfo: - msgColor = ansiBrightGreenBold + msgColor = ansiInfo } // write message if rep == nil { @@ -262,6 +273,7 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { const keyStack = "stack" var stackAttrs []slog.Attr + var resps []*http.Response // write attributes r.Attrs(func(attr slog.Attr) bool { @@ -270,10 +282,19 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { stackAttrs = append(stackAttrs, attr) return true } + + if resp, ok := attr.Value.Any().(*http.Response); ok { + // Special handling for http responses + resps = append(resps, resp) + return 
true + } + h.appendAttr(buf, attr, h.groupPrefix, h.groups) return true }) + h.handleHTTPResponse(buf, resps) + if len(*buf) == 0 { return nil } @@ -288,6 +309,25 @@ func (h *handler) Handle(_ context.Context, r slog.Record) error { return err } +func (h *handler) handleHTTPResponse(buf *buffer, resps []*http.Response) { + for _, resp := range resps { + if resp == nil { + continue + } + b, _ := httputil.DumpResponse(resp, false) + b = bytes.TrimSpace(b) + if len(b) == 0 { + return + } + + buf.WriteByte('\n') + buf.WriteStringIf(!h.noColor, ansiAttr) + _, _ = buf.Write(b) + buf.WriteStringIf(!h.noColor, ansiReset) + buf.WriteByte('\n') + } +} + func (h *handler) handleStackAttrs(buf *buffer, attrs []slog.Attr) { if len(attrs) == 0 { return @@ -381,20 +421,22 @@ func (h *handler) appendTime(buf *buffer, t time.Time) { func (h *handler) appendLevel(buf *buffer, level slog.Level) { switch { case level < slog.LevelInfo: + buf.WriteStringIf(!h.noColor, ansiDebug) buf.WriteString("DBG") appendLevelDelta(buf, level-slog.LevelDebug) + buf.WriteStringIf(!h.noColor, ansiReset) case level < slog.LevelWarn: - buf.WriteStringIf(!h.noColor, ansiBrightGreen) + buf.WriteStringIf(!h.noColor, ansiInfo) buf.WriteString("INF") appendLevelDelta(buf, level-slog.LevelInfo) buf.WriteStringIf(!h.noColor, ansiReset) case level < slog.LevelError: - buf.WriteStringIf(!h.noColor, ansiBrightYellow) + buf.WriteStringIf(!h.noColor, ansiWarn) buf.WriteString("WRN") appendLevelDelta(buf, level-slog.LevelWarn) buf.WriteStringIf(!h.noColor, ansiReset) default: - buf.WriteStringIf(!h.noColor, ansiBrightRedBold) + buf.WriteStringIf(!h.noColor, ansiError) buf.WriteString("ERR") appendLevelDelta(buf, level-slog.LevelError) buf.WriteStringIf(!h.noColor, ansiReset) From d7e11c8763253f2fbf9a8865113e917364da4c83 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 13 Jan 2024 07:32:00 -0700 Subject: [PATCH 174/195] Improve Files.Close --- cli/run.go | 1 - cli/testrun/testrun.go | 1 - 
libsq/core/ioz/download/http.go | 1 - libsq/source/cache.go | 7 ++++-- libsq/source/files.go | 40 ++++++++++++--------------------- libsq/source/files_test.go | 27 +++------------------- testh/testh.go | 1 - 7 files changed, 22 insertions(+), 56 deletions(-) diff --git a/cli/run.go b/cli/run.go index 925b5a9bc..0a9029037 100644 --- a/cli/run.go +++ b/cli/run.go @@ -235,7 +235,6 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { cfgLockFunc, filepath.Join(source.DefaultTempDir(), sum), filepath.Join(source.DefaultCacheDir(), sum), - true, ) if err != nil { lg.WarnIfFuncError(log, lga.Cleanup, ru.Cleanup.Run) diff --git a/cli/testrun/testrun.go b/cli/testrun/testrun.go index 7ad3e851a..465933587 100644 --- a/cli/testrun/testrun.go +++ b/cli/testrun/testrun.go @@ -121,7 +121,6 @@ func newRun(ctx context.Context, t testing.TB, cfgStore config.Store) (ru *run.R testh.TempLockFunc(t), tu.TempDir(t, false), tu.CacheDir(t, false), - true, ) require.NoError(t, err) diff --git a/libsq/core/ioz/download/http.go b/libsq/core/ioz/download/http.go index 562daebbd..98d219eed 100644 --- a/libsq/core/ioz/download/http.go +++ b/libsq/core/ioz/download/http.go @@ -298,5 +298,4 @@ func logResp(resp *http.Response, elapsed time.Duration, err error) { } log.Info("HTTP request completed", lga.Resp, resp) - log.Warn("this is a warning") // FIXME: delete } diff --git a/libsq/source/cache.go b/libsq/source/cache.go index c070fed19..6d2794357 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -385,8 +385,11 @@ func (fs *Files) doCacheClearAll(ctx context.Context) error { // // REVISIT: This doesn't really do as much as desired. It should // also be able to detect orphaned src cache dirs and delete those. 
-func (fs *Files) doCacheSweep(ctx context.Context) { - log := lg.FromContext(ctx).With(lga.Dir, fs.cacheDir) +func (fs *Files) doCacheSweep() { + ctx, cancelFn := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancelFn() + + log := fs.log.With(lga.Dir, fs.cacheDir) log.Debug("Sweep cache dir: acquiring config lock") if unlock, err := fs.cfgLockFn(ctx); err != nil { diff --git a/libsq/source/files.go b/libsq/source/files.go index d24724b3e..9e6ad1262 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -64,6 +64,8 @@ type Files struct { // to multiple readers via Files.newReader. fscache *fscache.FSCache + fscacheDir string + // fscacheEntryMetas contains metadata about fscache entries. // Entries are added by Files.addStdin, and consumed by // Files.Filesize. @@ -76,11 +78,11 @@ type Files struct { // NewFiles returns a new Files instance. If cleanFscache is true, the fscache // is cleaned on Files.Close. -func NewFiles(ctx context.Context, +func NewFiles( + ctx context.Context, optReg *options.Registry, cfgLock lockfile.LockFunc, tmpDir, cacheDir string, - cleanFscache bool, ) (*Files, error) { log := lg.FromContext(ctx) log.Debug("Creating new Files instance", "tmp_dir", tmpDir, "cache_dir", cacheDir) @@ -105,33 +107,16 @@ func NewFiles(ctx context.Context, // on cleanup (unless something bad happens and sq doesn't // get a chance to clean up). But, why take the chance; we'll just give // fcache a unique dir each time. 
- fscacheTmpDir := filepath.Join( - cacheDir, - "fscache", - strconv.Itoa(os.Getpid())+"_"+checksum.Rand(), - ) + fs.fscacheDir = filepath.Join(cacheDir, "fscache", strconv.Itoa(os.Getpid())+"_"+checksum.Rand()) - if err := ioz.RequireDir(fscacheTmpDir); err != nil { + if err := ioz.RequireDir(fs.fscacheDir); err != nil { return nil, errz.Err(err) } - if cleanFscache { - fs.clnup.AddE(func() error { - return errz.Wrap(os.RemoveAll(fscacheTmpDir), "remove fscache dir") - }) - } - var err error - if fs.fscache, err = fscache.New(fscacheTmpDir, os.ModePerm, time.Hour); err != nil { + if fs.fscache, err = fscache.New(fs.fscacheDir, os.ModePerm, time.Hour); err != nil { return nil, errz.Err(err) } - fs.clnup.AddE(fs.fscache.Clean) - - fs.clnup.AddE(func() error { - return errz.Wrap(os.RemoveAll(fs.tempDir), "remove files temp dir") - }) - - fs.clnup.Add(func() { fs.doCacheSweep(ctx) }) return fs, nil } @@ -448,11 +433,14 @@ func (fs *Files) Close() error { fs.log.Debug("Files.Close: waiting for goroutines to complete") fs.fillerWgs.Wait() - // TODO: Should delete the tmp dir - // TODO: Should sweep the cache + fs.log.Debug("Files.Close: executing cleanup", lga.Count, fs.clnup.Len()) + err := fs.clnup.Run() + err = errz.Append(err, fs.fscache.Clean()) + err = errz.Append(err, errz.Wrap(os.RemoveAll(fs.fscacheDir), "remove fscache dir")) + err = errz.Append(err, errz.Wrap(os.RemoveAll(fs.tempDir), "remove files temp dir")) + fs.doCacheSweep() - fs.log.Debug("Files.Close: executing clean funcs", lga.Count, fs.clnup.Len()) - return fs.clnup.Run() + return err } // CleanupE adds fn to the cleanup sequence invoked by fs.Close. 
diff --git a/libsq/source/files_test.go b/libsq/source/files_test.go index 729ca2324..e358e2144 100644 --- a/libsq/source/files_test.go +++ b/libsq/source/files_test.go @@ -56,14 +56,7 @@ func TestFiles_Type(t *testing.T) { t.Run(tu.Name(source.RedactLocation(tc.loc)), func(t *testing.T) { ctx := lg.NewContext(context.Background(), lgt.New(t)) - fs, err := source.NewFiles( - ctx, - nil, - testh.TempLockFunc(t), - tu.TempDir(t, true), - tu.CacheDir(t, true), - true, - ) + fs, err := source.NewFiles(ctx, nil, testh.TempLockFunc(t), tu.TempDir(t, true), tu.CacheDir(t, true)) require.NoError(t, err) fs.AddDriverDetectors(testh.DriverDetectors()...) @@ -105,14 +98,7 @@ func TestFiles_DetectType(t *testing.T) { t.Run(filepath.Base(tc.loc), func(t *testing.T) { ctx := lg.NewContext(context.Background(), lgt.New(t)) - fs, err := source.NewFiles( - ctx, - nil, - testh.TempLockFunc(t), - tu.TempDir(t, true), - tu.CacheDir(t, true), - true, - ) + fs, err := source.NewFiles(ctx, nil, testh.TempLockFunc(t), tu.TempDir(t, true), tu.CacheDir(t, true)) require.NoError(t, err) fs.AddDriverDetectors(testh.DriverDetectors()...) 
@@ -172,14 +158,7 @@ func TestFiles_NewReader(t *testing.T) { Location: proj.Abs(fpath), } - fs, err := source.NewFiles( - ctx, - nil, - testh.TempLockFunc(t), - tu.TempDir(t, true), - tu.CacheDir(t, true), - true, - ) + fs, err := source.NewFiles(ctx, nil, testh.TempLockFunc(t), tu.TempDir(t, true), tu.CacheDir(t, true)) require.NoError(t, err) g := &errgroup.Group{} diff --git a/testh/testh.go b/testh/testh.go index 8a999211c..aa67f0483 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -171,7 +171,6 @@ func (h *Helper) init() { TempLockFunc(h.T), tu.TempDir(h.T, false), tu.TempDir(h.T, false), - true, ) require.NoError(h.T, err) From 0e5ce23a90ccececcb67be755466446e436ce720 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 13 Jan 2024 08:11:34 -0700 Subject: [PATCH 175/195] Minor test cleanup --- CHANGELOG.md | 6 ++- libsq/source/files_test.go | 95 +++++++++++++++++++++----------------- testh/testh.go | 7 +-- 3 files changed, 59 insertions(+), 49 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c141d2de..884f7d301 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ Breaking changes are annotated with ☢️, and alpha/beta features with 🐥. ## Upcoming +This is a significant release, focused on improving i/o, responsiveness, +and performance. The headline feature is caching of ingested data for +document sources such as CSV or Excel. + ### Added - Long-running operations (such as data ingestion, or file download) now result @@ -21,7 +25,7 @@ Breaking changes are annotated with ☢️, and alpha/beta features with 🐥. now make use of an ingest cache DB. Previously, ingestion of document source data occurred on each `sq` command. It is now a one-time cost; subsequent use of the document source utilizes the cache DB. If the source document changes, the ingest cache DB is invalidated and - ingested again. This is a massively improved experience for large document sources. + ingested again. 
This is a significantly improved experience for large document sources. - There's several new commands to interact with the cache. - [`sq cache enable`](https://sq.io/docs/cmd/cache_enable) and [`sq cache disable`](https://sq.io/docs/cmd/cache_disable) control cache usage. diff --git a/libsq/source/files_test.go b/libsq/source/files_test.go index e358e2144..246242af1 100644 --- a/libsq/source/files_test.go +++ b/libsq/source/files_test.go @@ -7,6 +7,8 @@ import ( "path/filepath" "testing" + "github.com/neilotoole/sq/drivers/json" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" @@ -29,49 +31,6 @@ import ( "github.com/neilotoole/sq/testh/tu" ) -func TestFiles_Type(t *testing.T) { - testCases := []struct { - loc string - wantType drivertype.Type - wantErr bool - }{ - {loc: proj.Expand("sqlite3://${SQ_ROOT}/drivers/sqlite3/testdata/sakila.db"), wantType: sqlite3.Type}, - {loc: proj.Abs(sakila.PathSL3), wantType: sqlite3.Type}, - {loc: proj.Abs("drivers/sqlite3/testdata/sakila_db"), wantType: sqlite3.Type}, - {loc: "sqlserver://sakila:p_ssW0rd@localhost?database=sakila", wantType: sqlserver.Type}, - {loc: "postgres://sakila:p_ssW0rd@localhost/sakila?sslmode=disable", wantType: postgres.Type}, - {loc: "mysql://sakila:p_ssW0rd@localhost/sakila", wantType: mysql.Type}, - {loc: proj.Abs(testsrc.PathXLSXTestHeader), wantType: xlsx.Type}, - {loc: proj.Abs("drivers/xlsx/testdata/test_header_xlsx"), wantType: xlsx.Type}, - {loc: sakila.URLSubsetXLSX, wantType: xlsx.Type}, - {loc: proj.Abs(sakila.PathCSVActor), wantType: csv.TypeCSV}, - {loc: proj.Abs("drivers/csv/testdata/person_csv"), wantType: csv.TypeCSV}, - {loc: sakila.URLActorCSV, wantType: csv.TypeCSV}, - {loc: proj.Abs("drivers/csv/testdata/person_tsv"), wantType: csv.TypeTSV}, - {loc: proj.Abs(sakila.PathTSVActor), wantType: csv.TypeTSV}, - } - - for _, tc := range testCases { - tc := tc - t.Run(tu.Name(source.RedactLocation(tc.loc)), func(t 
*testing.T) { - ctx := lg.NewContext(context.Background(), lgt.New(t)) - - fs, err := source.NewFiles(ctx, nil, testh.TempLockFunc(t), tu.TempDir(t, true), tu.CacheDir(t, true)) - require.NoError(t, err) - fs.AddDriverDetectors(testh.DriverDetectors()...) - - gotType, gotErr := fs.DriverType(context.Background(), "@test_"+stringz.Uniq8(), tc.loc) - if tc.wantErr { - require.Error(t, gotErr) - return - } - - require.NoError(t, gotErr) - require.Equal(t, tc.wantType, gotType) - }) - } -} - func TestFiles_DetectType(t *testing.T) { testCases := []struct { loc string @@ -90,6 +49,10 @@ func TestFiles_DetectType(t *testing.T) { {loc: proj.Abs("drivers/csv/testdata/person.tsv"), wantType: csv.TypeTSV, wantOK: true}, {loc: proj.Abs("drivers/csv/testdata/person_noheader.tsv"), wantType: csv.TypeTSV, wantOK: true}, {loc: proj.Abs("drivers/csv/testdata/person_tsv"), wantType: csv.TypeTSV, wantOK: true}, + {loc: proj.Abs("drivers/csv/testdata/person_tsv"), wantType: csv.TypeTSV, wantOK: true}, + {loc: proj.Abs("drivers/json/testdata/actor.json"), wantType: json.TypeJSON, wantOK: true}, + {loc: proj.Abs("drivers/json/testdata/actor.jsona"), wantType: json.TypeJSONA, wantOK: true}, + {loc: proj.Abs("drivers/json/testdata/actor.jsonl"), wantType: json.TypeJSONL, wantOK: true}, {loc: proj.Abs("README.md"), wantType: drivertype.None, wantOK: false}, } @@ -115,6 +78,52 @@ func TestFiles_DetectType(t *testing.T) { } } +func TestFiles_DriverType(t *testing.T) { + testCases := []struct { + loc string + wantType drivertype.Type + wantErr bool + }{ + {loc: proj.Expand("sqlite3://${SQ_ROOT}/drivers/sqlite3/testdata/sakila.db"), wantType: sqlite3.Type}, + {loc: proj.Abs(sakila.PathSL3), wantType: sqlite3.Type}, + {loc: proj.Abs("drivers/sqlite3/testdata/sakila_db"), wantType: sqlite3.Type}, + {loc: "sqlserver://sakila:p_ssW0rd@localhost?database=sakila", wantType: sqlserver.Type}, + {loc: "postgres://sakila:p_ssW0rd@localhost/sakila", wantType: postgres.Type}, + {loc: 
"mysql://sakila:p_ssW0rd@localhost/sakila", wantType: mysql.Type}, + {loc: proj.Abs(testsrc.PathXLSXTestHeader), wantType: xlsx.Type}, + {loc: proj.Abs("drivers/xlsx/testdata/test_header_xlsx"), wantType: xlsx.Type}, + {loc: sakila.URLSubsetXLSX, wantType: xlsx.Type}, + {loc: proj.Abs(sakila.PathCSVActor), wantType: csv.TypeCSV}, + {loc: proj.Abs("drivers/csv/testdata/person_csv"), wantType: csv.TypeCSV}, + {loc: sakila.URLActorCSV, wantType: csv.TypeCSV}, + {loc: proj.Abs(sakila.PathTSVActor), wantType: csv.TypeTSV}, + {loc: proj.Abs("drivers/csv/testdata/person_tsv"), wantType: csv.TypeTSV}, + {loc: proj.Abs("drivers/json/testdata/actor.json"), wantType: json.TypeJSON}, + {loc: proj.Abs("drivers/json/testdata/actor.jsona"), wantType: json.TypeJSONA}, + {loc: proj.Abs("drivers/json/testdata/actor.jsonl"), wantType: json.TypeJSONL}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tu.Name(source.RedactLocation(tc.loc)), func(t *testing.T) { + ctx := lg.NewContext(context.Background(), lgt.New(t)) + + fs, err := source.NewFiles(ctx, nil, testh.TempLockFunc(t), tu.TempDir(t, true), tu.CacheDir(t, true)) + require.NoError(t, err) + fs.AddDriverDetectors(testh.DriverDetectors()...) 
+ + gotType, gotErr := fs.DriverType(context.Background(), "@test_"+stringz.Uniq8(), tc.loc) + if tc.wantErr { + require.Error(t, gotErr) + return + } + + require.NoError(t, gotErr) + require.Equal(t, tc.wantType, gotType) + }) + } +} + func TestDetectMagicNumber(t *testing.T) { testCases := []struct { loc string diff --git a/testh/testh.go b/testh/testh.go index aa67f0483..39aeca9bb 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -181,10 +181,7 @@ func (h *Helper) init() { }) h.grips = driver.NewGrips(h.registry, h.files, sqlite3.NewScratchSource) - // h.Cleanup.AddC(h.grips) - h.Cleanup.AddE(func() error { - return h.grips.Close() - }) + h.Cleanup.AddC(h.grips) h.registry.AddProvider(sqlite3.Type, &sqlite3.Provider{Log: log}) h.registry.AddProvider(postgres.Type, &postgres.Provider{Log: log}) @@ -878,7 +875,7 @@ func DriverDetectors() []source.DriverDetectFunc { xlsx.DetectXLSX, csv.DetectCSV, csv.DetectTSV, - // json.DetectJSON(1000), // FIXME: enable DetectJSON when it's ready + json.DetectJSON(1000), json.DetectJSONA(1000), json.DetectJSONL(1000), } From 1c3f6a89431dca5315f211b9f392c1135b97ba49 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 13 Jan 2024 08:20:13 -0700 Subject: [PATCH 176/195] Grips.OpenIngest now more resilient to corrupt cache --- libsq/driver/grips.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 187181592..9237be198 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -178,6 +178,10 @@ func (gs *Grips) openNewCacheGrip(ctx context.Context, src *source.Source) (grip const msgRemoveScratch = "Remove cache db" log := lg.FromContext(ctx) + if err = gs.files.CacheClearSource(ctx, src, false); err != nil { + return nil, nil, err + } + cacheDir, srcCacheDBFilepath, _, err := gs.files.CachePaths(src) if err != nil { return nil, nil, err From a8d3f499f79f31056839b435d86fba7be6af70d6 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 13 Jan 2024 09:09:51 -0700 
Subject: [PATCH 177/195] Fixed issues with cmd_add_test.go --- cli/cmd_add_test.go | 145 +++++++++++++++++------------- cli/config/yamlstore/upgrade.go | 2 +- cli/config/yamlstore/yamlstore.go | 7 +- 3 files changed, 89 insertions(+), 65 deletions(-) diff --git a/cli/cmd_add_test.go b/cli/cmd_add_test.go index ce2998dd8..c07a0b461 100644 --- a/cli/cmd_add_test.go +++ b/cli/cmd_add_test.go @@ -5,18 +5,20 @@ import ( "path/filepath" "testing" - "github.com/stretchr/testify/require" - - "github.com/neilotoole/sq/cli/testrun" "github.com/neilotoole/sq/drivers/csv" "github.com/neilotoole/sq/drivers/mysql" "github.com/neilotoole/sq/drivers/postgres" "github.com/neilotoole/sq/drivers/sqlite3" "github.com/neilotoole/sq/drivers/sqlserver" - "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/drivers/xlsx" "github.com/neilotoole/sq/libsq/source" - "github.com/neilotoole/sq/libsq/source/drivertype" "github.com/neilotoole/sq/testh" + + "github.com/stretchr/testify/require" + + "github.com/neilotoole/sq/cli/testrun" + "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/source/drivertype" "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/sakila" "github.com/neilotoole/sq/testh/tu" @@ -35,54 +37,62 @@ func TestCmdAdd(t *testing.T) { wantRows: sakila.TblActorCount, wantCols: len(sakila.TblActorCols()), } - - th := testh.New(t) + _ = actorDataQuery testCases := []struct { - loc string // first arg to "add" cmd - driver string // --driver flag - handle string // --handle flag - wantHandle string - wantType drivertype.Type - wantOptions options.Options - wantErr bool - query *query + // Set only one of loc, or locFromHandle, to create + // the first arg to "add" cmd. + // + // loc, when set, will be used directly. + loc string + // locFromHandle, when set, gets the location from the + // config source with the given handle. 
+ locFromHandle string + + driver string // --driver flag + handle string // --handle flag + wantHandle string + wantType drivertype.Type + wantOptions options.Options + wantAddErr bool + wantQueryErr bool + query *query }{ { - loc: "", - wantErr: true, + loc: "", + wantAddErr: true, }, { - loc: " ", - wantErr: true, + loc: " ", + wantAddErr: true, }, { - loc: "/", - wantErr: true, + loc: "/", + wantAddErr: true, }, { - loc: "../../", - wantErr: true, + loc: "../../", + wantAddErr: true, }, { - loc: "does/not/exist", - wantErr: true, + loc: "does/not/exist", + wantAddErr: true, }, { - loc: "_", - wantErr: true, + loc: "_", + wantAddErr: true, }, { - loc: ".", - wantErr: true, + loc: ".", + wantAddErr: true, }, { - loc: "/", - wantErr: true, + loc: "/", + wantAddErr: true, }, { - loc: "../does/not/exist.csv", - wantErr: true, + loc: "../does/not/exist.csv", + wantAddErr: true, }, { loc: proj.Rel(sakila.PathCSVActor), @@ -109,9 +119,14 @@ func TestCmdAdd(t *testing.T) { wantType: csv.TypeCSV, }, { - loc: proj.Abs(sakila.PathCSVActor), - driver: "xlsx", - wantErr: true, + loc: proj.Abs(sakila.PathCSVActor), + driver: "xlsx", + wantHandle: "@actor", + wantType: xlsx.Type, + // It's legal to add a CSV file with the xlsx driver. + wantAddErr: false, + // But it should fail when we try to query it. 
+ wantQueryErr: true, }, { loc: proj.Rel(sakila.PathTSVActor), @@ -148,49 +163,53 @@ func TestCmdAdd(t *testing.T) { wantType: sqlite3.Type, }, { - // without scheme, relative path - loc: th.Source(sakila.Pg).Location, - wantHandle: "@sakila", - wantType: postgres.Type, + locFromHandle: sakila.Pg, + wantHandle: "@sakila", + wantType: postgres.Type, }, { - loc: th.Source(sakila.MS).Location, - wantHandle: "@sakila", - wantType: sqlserver.Type, + locFromHandle: sakila.MS, + wantHandle: "@sakila", + wantType: sqlserver.Type, }, { - loc: th.Source(sakila.My).Location, - wantHandle: "@sakila", - wantType: mysql.Type, + locFromHandle: sakila.My, + wantHandle: "@sakila", + wantType: mysql.Type, }, { - loc: proj.Abs(sakila.PathCSVActor), - handle: source.StdinHandle, // reserved handle - wantErr: true, + loc: proj.Abs(sakila.PathCSVActor), + handle: source.StdinHandle, // reserved handle + wantAddErr: true, }, { - loc: proj.Abs(sakila.PathCSVActor), - handle: source.ActiveHandle, // reserved handle - wantErr: true, + loc: proj.Abs(sakila.PathCSVActor), + handle: source.ActiveHandle, // reserved handle + wantAddErr: true, }, { - loc: proj.Abs(sakila.PathCSVActor), - handle: source.ScratchHandle, // reserved handle - wantErr: true, + loc: proj.Abs(sakila.PathCSVActor), + handle: source.ScratchHandle, // reserved handle + wantAddErr: true, }, { - loc: proj.Abs(sakila.PathCSVActor), - handle: source.JoinHandle, // reserved handle - wantErr: true, + loc: proj.Abs(sakila.PathCSVActor), + handle: source.JoinHandle, // reserved handle + wantAddErr: true, }, } for i, tc := range testCases { tc := tc - t.Run(tu.Name(i, tc.wantHandle, tc.loc, tc.driver), func(t *testing.T) { + t.Run(tu.Name(i, tc.wantHandle, tc.loc, tc.locFromHandle, tc.driver), func(t *testing.T) { + if tc.locFromHandle != "" { + th := testh.New(t) + tc.loc = th.Source(tc.locFromHandle).Location + } + args := []string{"add", tc.loc} if tc.handle != "" { args = append(args, "--handle="+tc.handle) @@ -199,9 +218,9 @@ 
func TestCmdAdd(t *testing.T) { args = append(args, "--driver="+tc.driver) } - tr := testrun.New(th.Context, t, nil) + tr := testrun.New(context.Background(), t, nil) err := tr.Exec(args...) - if tc.wantErr { + if tc.wantAddErr { require.Error(t, err) return } @@ -220,6 +239,10 @@ func TestCmdAdd(t *testing.T) { } err = tr.Reset().Exec(tc.query.q, "--json") + if tc.wantQueryErr { + require.Error(t, err) + return + } require.NoError(t, err) var results []map[string]any tr.Bind(&results) diff --git a/cli/config/yamlstore/upgrade.go b/cli/config/yamlstore/upgrade.go index f955903d8..d492700f6 100644 --- a/cli/config/yamlstore/upgrade.go +++ b/cli/config/yamlstore/upgrade.go @@ -60,7 +60,7 @@ func (fs *Store) doUpgrade(ctx context.Context, startVersion, targetVersion stri log.Debug("Config upgrade step successful") } - if err = fs.write(data); err != nil { + if err = fs.write(ctx, data); err != nil { return nil, err } diff --git a/cli/config/yamlstore/yamlstore.go b/cli/config/yamlstore/yamlstore.go index 40b371364..5a192aa99 100644 --- a/cli/config/yamlstore/yamlstore.go +++ b/cli/config/yamlstore/yamlstore.go @@ -178,7 +178,7 @@ func (fs *Store) doLoad(ctx context.Context) (*config.Config, error) { } // Save writes config to disk. It implements Store. -func (fs *Store) Save(_ context.Context, cfg *config.Config) error { +func (fs *Store) Save(ctx context.Context, cfg *config.Config) error { if fs == nil { return errz.New("config file store is nil") } @@ -192,11 +192,11 @@ func (fs *Store) Save(_ context.Context, cfg *config.Config) error { return err } - return fs.write(data) + return fs.write(ctx, data) } // Write writes the config bytes to disk. -func (fs *Store) write(data []byte) error { +func (fs *Store) write(ctx context.Context, data []byte) error { // It's possible that the parent dir of fs.Path doesn't exist. 
if err := ioz.RequireDir(filepath.Dir(fs.Path)); err != nil { return errz.Wrapf(err, "failed to make parent dir of config file: %s", filepath.Dir(fs.Path)) @@ -206,6 +206,7 @@ func (fs *Store) write(data []byte) error { return errz.Wrap(err, "failed to save config file") } + lg.FromContext(ctx).Info("Wrote config file", lga.Path, fs.Path) return nil } From e037541223a8380a9657cac44e92f939141a812f Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 13 Jan 2024 10:12:50 -0700 Subject: [PATCH 178/195] Test refactoring --- libsq/ast/handle.go | 31 ++++++++++++++++++++++++++++ libsq/ast/handle_test.go | 32 +++++++++++++++++++++++++++++ libsq/query_test.go | 19 ++++++++++++++++- testh/testh.go | 44 +++++++++++++++++++++++++++++++++++++++- 4 files changed, 124 insertions(+), 2 deletions(-) create mode 100644 libsq/ast/handle_test.go diff --git a/libsq/ast/handle.go b/libsq/ast/handle.go index 8d3804bd4..77e9744a8 100644 --- a/libsq/ast/handle.go +++ b/libsq/ast/handle.go @@ -1,8 +1,11 @@ package ast import ( + "slices" + "github.com/neilotoole/sq/libsq/ast/internal/slq" "github.com/neilotoole/sq/libsq/core/tablefq" + "github.com/samber/lo" ) // HandleNode models a source handle such as "@sakila". @@ -47,3 +50,31 @@ func (v *parseTreeVisitor) VisitHandleTable(ctx *slq.HandleTableContext) any { return v.cur.AddChild(node) } + +// ExtractHandles returns a sorted slice of all handles mentioned +// in the AST. Duplicate mentions are removed. 
+func ExtractHandles(ast *AST) []string { + var handles []string + handleNodes := FindNodes[*HandleNode](ast) + for _, n := range handleNodes { + handles = append(handles, n.Handle()) + } + + joinNodes := FindNodes[*JoinNode](ast) + for _, n := range joinNodes { + if n != nil && n.Table().Handle() != "" { + handles = append(handles, n.Table().Handle()) + } + } + + tblSelNodes := FindNodes[*TblSelectorNode](ast) + for _, n := range tblSelNodes { + if n.Handle() != "" { + handles = append(handles, n.Handle()) + } + } + + handles = lo.Uniq(handles) + slices.Sort(handles) + return handles +} diff --git a/libsq/ast/handle_test.go b/libsq/ast/handle_test.go new file mode 100644 index 000000000..871c8bb47 --- /dev/null +++ b/libsq/ast/handle_test.go @@ -0,0 +1,32 @@ +package ast + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestExtractHandles(t *testing.T) { + testCases := []struct { + input string + want []string + }{ + { + input: "@sakila | .actor", + want: []string{"@sakila"}, + }, + { + input: "@sakila_pg | .actor | join(@sakila_ms.film_actor, .actor_id)", + want: []string{"@sakila_ms", "@sakila_pg"}, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + a := mustParse(t, tc.input) + got := ExtractHandles(a) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/libsq/query_test.go b/libsq/query_test.go index 1a4dbb0d0..026a79e8a 100644 --- a/libsq/query_test.go +++ b/libsq/query_test.go @@ -141,7 +141,17 @@ func execQueryTestCase(t *testing.T, tc queryTestCase) { func doExecQueryTestCase(t *testing.T, tc queryTestCase) { t.Helper() - coll := testh.New(t).NewCollection(sakila.SQLLatest()...) + th := testh.New(t) + var handles []string + for _, handle := range sakila.SQLLatest() { + if th.SourceConfigured(handle) { + handles = append(handles, handle) + } else { + t.Logf("Skipping because source %s is not configured", handle) + } + } + + coll := th.NewCollection(handles...) 
for _, src := range coll.Sources() { src := src @@ -153,6 +163,13 @@ func doExecQueryTestCase(t *testing.T, tc queryTestCase) { t.Helper() in := strings.Replace(tc.in, "@sakila", src.Handle, 1) + for _, handle := range testh.ExtractHandlesFromQuery(t, tc.in, false) { + if !testh.New(t).SourceConfigured(handle) { + t.Skipf("Skipping because source %s is not configured", handle) + return + } + } + t.Logf("QUERY:\n\n%s\n\n", in) want := tc.wantSQL if overrideWant, ok := tc.override[src.Type]; ok { diff --git a/testh/testh.go b/testh/testh.go index 39aeca9bb..5c1da9b14 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -359,6 +359,29 @@ func (h *Helper) Source(handle string) *source.Source { return src } +// SourceConfigured returns true if the source is configured. Note +// that Helper.Source skips the test if the source is not configured: that +// is to say, if the source location requires population via an envar, and +// the envar is not set. For example, for the PostgreSQL source @sakila_pg12, +// the envar SQ_TEST_SRC__SAKILA_PG12 is required. SourceConfigured tests +// if that envar is set. +func (h *Helper) SourceConfigured(handle string) bool { + h.mu.Lock() + defer h.mu.Unlock() + + if !stringz.InSlice(sakila.SQLAllExternal(), handle) { + // Non-SQL and SQLite sources are always available. + return true + } + + handleEnvar := "SQ_TEST_SRC__" + strings.ToUpper(strings.TrimPrefix(handle, "@")) + if envar, ok := os.LookupEnv(handleEnvar); !ok || strings.TrimSpace(envar) == "" { + return false + } + + return true +} + // NewCollection is a convenience function for building a // new *source.Collection incorporating the supplied handles. See // Helper.Source for more on the behavior. @@ -892,13 +915,15 @@ func SetBuildVersion(t testing.TB, vers string) { }) } +// TempLockfile returns a lockfile.Lockfile that uses a temp file. 
func TempLockfile(t testing.TB) lockfile.Lockfile { return lockfile.Lockfile(tu.TempFile(t, "pid.lock", false)) } +// TempLockFunc returns a lockfile.LockFunc that uses a temp file. func TempLockFunc(t testing.TB) lockfile.LockFunc { return func(ctx context.Context) (unlock func(), err error) { - lock := lockfile.Lockfile(tu.TempFile(t, "pid.lock", false)) + lock := TempLockfile(t) timeout := config.OptConfigLockTimeout.Default() if err = lock.Lock(ctx, timeout); err != nil { return nil, err @@ -911,3 +936,20 @@ func TempLockFunc(t testing.TB) lockfile.LockFunc { }, nil } } + +// ExtractHandlesFromQuery returns all handles mentioned in the query. +// If failOnErr is true, the test will fail on any parse error; otherwise, +// the test will log the error and return an empty slice. +func ExtractHandlesFromQuery(t testing.TB, query string, failOnErr bool) []string { + a, err := ast.Parse(lg.Discard(), query) + if err != nil { + if failOnErr { + require.NoError(t, err) + return nil + } + t.Logf("Failed to parse query: >> %s << : %v", query, err) + return []string{} + } + + return ast.ExtractHandles(a) +} From 1f8172317131ed16a61a733e3c3d693432c318ca Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 13 Jan 2024 10:40:19 -0700 Subject: [PATCH 179/195] more test cleanup --- cli/cli_test.go | 37 +++++++++++++++++++++---------------- drivers/xlsx/ingest.go | 2 +- testh/testh.go | 2 +- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/cli/cli_test.go b/cli/cli_test.go index 543471a17..b084db27d 100644 --- a/cli/cli_test.go +++ b/cli/cli_test.go @@ -79,22 +79,27 @@ func TestSmoke(t *testing.T) { } } -func TestCreateTblTestBytes(t *testing.T) { - th, src, _, _, _ := testh.NewWith(t, sakila.Pg) - th.DiffDB(src) - - tblDef := sqlmodel.NewTableDef( - stringz.UniqTableName("test_bytes"), - []string{"col_name", "col_bytes"}, - []kind.Kind{kind.Text, kind.Bytes}, - ) - - fBytes := proj.ReadFile(fixt.GopherPath) - data := []any{fixt.GopherFilename, fBytes} - - 
require.Equal(t, int64(1), th.CreateTable(true, src, tblDef, data)) - t.Logf(src.Location) - th.DropTable(src, tablefq.From(tblDef.Name)) +func TestCreateTable_bytes(t *testing.T) { + for _, handle := range sakila.SQLLatest() { + handle := handle + t.Run(handle, func(t *testing.T) { + th, src, _, _, _ := testh.NewWith(t, handle) + th.DiffDB(src) + + tblDef := sqlmodel.NewTableDef( + stringz.UniqTableName("test_bytes"), + []string{"col_name", "col_bytes"}, + []kind.Kind{kind.Text, kind.Bytes}, + ) + + fBytes := proj.ReadFile(fixt.GopherPath) + data := []any{fixt.GopherFilename, fBytes} + + require.Equal(t, int64(1), th.CreateTable(true, src, tblDef, data)) + t.Logf(src.Location) + th.DropTable(src, tablefq.From(tblDef.Name)) + }) + } } // TestOutputRaw verifies that the raw output format works. diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index 21e8ec06b..a7c3d474e 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -285,7 +285,7 @@ func buildSheetTables(ctx context.Context, srcIngestHeader *bool, sheets []*xShe sheetTbl, err := buildSheetTable(gCtx, srcIngestHeader, sheets[i]) if err != nil { - if errz.Has[*driver.EmptyDataError](err) { + if errz.Has[driver.EmptyDataError](err) { // If the sheet has no data, we log it and skip it. lg.FromContext(ctx).Warn("Excel sheet has no data", laSheet, sheets[i].name, diff --git a/testh/testh.go b/testh/testh.go index 5c1da9b14..4b66ebcbe 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -235,7 +235,7 @@ func (h *Helper) Add(src *source.Source) *source.Source { // This is a bit of a hack to ensure that internals are loaded: we // load a known source. The loading mechanism should be refactored // to not require this. 
- _ = h.Source(sakila.Pg) + _ = h.Source(sakila.SL3) h.mu.Lock() defer h.mu.Unlock() From 8d41f2329358dbaec46c1e722d6c95aedeec3355 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 13 Jan 2024 11:12:37 -0700 Subject: [PATCH 180/195] Speeding up tests --- cli/cmd_inspect_test.go | 14 ++++++++++---- cli/testrun/testrun.go | 15 ++++++++++++--- testh/testh.go | 15 +++++++++------ 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/cli/cmd_inspect_test.go b/cli/cmd_inspect_test.go index bbe7b5796..9fe2840f2 100644 --- a/cli/cmd_inspect_test.go +++ b/cli/cmd_inspect_test.go @@ -7,6 +7,9 @@ import ( "os" "testing" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgt" + "github.com/samber/lo" "github.com/stretchr/testify/require" @@ -28,7 +31,7 @@ import ( // TestCmdInspect_json_yaml tests "sq inspect" for // the JSON and YAML formats. -func TestCmdInspect_json_yaml(t *testing.T) { +func TestCmdInspect_json_yaml(t *testing.T) { //nolint:tparallel tu.SkipShort(t, true) possibleTbls := append(sakila.AllTbls(), source.MonotableName) @@ -56,10 +59,13 @@ func TestCmdInspect_json_yaml(t *testing.T) { for _, tf := range testFormats { tf := tf t.Run(tf.format.String(), func(t *testing.T) { + t.Parallel() + for _, tc := range testCases { tc := tc t.Run(tc.handle, func(t *testing.T) { + t.Parallel() tu.SkipWindowsIf(t, tc.handle == sakila.XLSX, "XLSX too slow on windows workflow") th := testh.New(t) @@ -91,7 +97,7 @@ func TestCmdInspect_json_yaml(t *testing.T) { tblName := tblName t.Run(tblName, func(t *testing.T) { tu.SkipShort(t, true) - tr2 := testrun.New(th.Context, t, tr) + tr2 := testrun.New(lg.NewContext(th.Context, lgt.New(t)), t, tr) err := tr2.Exec("inspect", "."+tblName, fmt.Sprintf("--%s", tf.format)) require.NoError(t, err) tblMeta := &metadata.Table{} @@ -104,7 +110,7 @@ func TestCmdInspect_json_yaml(t *testing.T) { t.Run("inspect_overview", func(t *testing.T) { t.Logf("Test: sq inspect @src --overview") - 
tr2 := testrun.New(th.Context, t, tr) + tr2 := testrun.New(lg.NewContext(th.Context, lgt.New(t)), t, tr) err := tr2.Exec( "inspect", tc.handle, @@ -131,7 +137,7 @@ func TestCmdInspect_json_yaml(t *testing.T) { t.Run("inspect_dbprops", func(t *testing.T) { t.Logf("Test: sq inspect @src --dbprops") - tr2 := testrun.New(th.Context, t, tr) + tr2 := testrun.New(lg.NewContext(th.Context, lgt.New(t)), t, tr) err := tr2.Exec( "inspect", tc.handle, diff --git a/cli/testrun/testrun.go b/cli/testrun/testrun.go index 465933587..eea227d60 100644 --- a/cli/testrun/testrun.go +++ b/cli/testrun/testrun.go @@ -57,18 +57,21 @@ func New(ctx context.Context, t testing.TB, from *TestRun) *TestRun { } if !lg.InContext(ctx) { + // FIXME: Get rid of this abomination ctx = lg.NewContext(ctx, lgt.New(t)) } tr := &TestRun{T: t, Context: ctx, mu: &sync.Mutex{}} var cfgStore config.Store + var cacheDir string if from != nil { cfgStore = from.Run.ConfigStore + cacheDir = from.Run.Files.CacheDir() tr.hushOutput = from.hushOutput } - tr.Run, tr.Out, tr.ErrOut = newRun(ctx, t, cfgStore) + tr.Run, tr.Out, tr.ErrOut = newRun(ctx, t, cfgStore, cacheDir) tr.Context = options.NewContext(ctx, tr.Run.Config.Options) return tr } @@ -79,7 +82,9 @@ func New(ctx context.Context, t testing.TB, from *TestRun) *TestRun { // these buffers can be written to t.Log() if desired. // // If cfgStore is nil, a new one is created in a temp dir. 
-func newRun(ctx context.Context, t testing.TB, cfgStore config.Store) (ru *run.Run, out, errOut *bytes.Buffer) { +func newRun(ctx context.Context, t testing.TB, + cfgStore config.Store, cacheDir string, +) (ru *run.Run, out, errOut *bytes.Buffer) { out = &bytes.Buffer{} errOut = &bytes.Buffer{} @@ -115,12 +120,16 @@ func newRun(ctx context.Context, t testing.TB, cfgStore config.Store) (ru *run.R // The Files instance needs unique dirs for temp and cache because // the test runs may execute in parallel inside the same test binary // process, thus breaking the pid-based lockfile mechanism. + if cacheDir == "" { + cacheDir = tu.CacheDir(t, false) + } + ru.Files, err = source.NewFiles( ctx, ru.OptionsRegistry, testh.TempLockFunc(t), tu.TempDir(t, false), - tu.CacheDir(t, false), + cacheDir, ) require.NoError(t, err) diff --git a/testh/testh.go b/testh/testh.go index 4b66ebcbe..4e4b889fb 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -116,10 +116,6 @@ func New(t testing.TB, opts ...Option) *Helper { dbOpenTimeout: defaultDBOpenTimeout, } - for _, opt := range opts { - opt(h) - } - ctx, cancelFn := context.WithCancel(context.Background()) h.cancelFn = cancelFn @@ -132,9 +128,16 @@ func New(t testing.TB, opts ...Option) *Helper { // // REVISIT: The above statement regarding pid-based locking may no longer // be applicable, as a new cache dir is created for each test run. 
- o := options.Options{driver.OptIngestCache.Key(): false} - h.Context = options.NewContext(h.Context, o) + // + // FIXME: Add an option to set config value + // o := options.Options{driver.OptIngestCache.Key(): false} + // h.Context = options.NewContext(h.Context, o) t.Cleanup(h.Close) + + for _, opt := range opts { + opt(h) + } + return h } From a724750bae19d3577b0448d1e9e7acfee99248e4 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 13 Jan 2024 11:47:55 -0700 Subject: [PATCH 181/195] Speeding up tests, again --- cli/cli.go | 2 +- cli/cmd_inspect_test.go | 3 ++- cli/testrun/testrun.go | 3 ++- testh/testh.go | 26 ++++++++++++++------------ 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/cli/cli.go b/cli/cli.go index eb7046d13..8225c555f 100644 --- a/cli/cli.go +++ b/cli/cli.go @@ -85,7 +85,7 @@ func ExecuteWith(ctx context.Context, ru *run.Run, args []string) error { _ = ru.LogCloser() } }() - ctx = options.NewContext(ctx, ru.Config.Options) + ctx = options.NewContext(ctx, options.Merge(options.FromContext(ctx), ru.Config.Options)) log := lg.FromContext(ctx) log.Info("EXECUTE", "args", strings.Join(args, " ")) log.Info("Build info", "build", buildinfo.Get()) diff --git a/cli/cmd_inspect_test.go b/cli/cmd_inspect_test.go index 9fe2840f2..8382bbc87 100644 --- a/cli/cmd_inspect_test.go +++ b/cli/cmd_inspect_test.go @@ -68,7 +68,7 @@ func TestCmdInspect_json_yaml(t *testing.T) { //nolint:tparallel t.Parallel() tu.SkipWindowsIf(t, tc.handle == sakila.XLSX, "XLSX too slow on windows workflow") - th := testh.New(t) + th := testh.New(t, testh.OptCaching(true)) src := th.Source(tc.handle) tr := testrun.New(th.Context, t, nil).Hush().Add(*src) @@ -96,6 +96,7 @@ func TestCmdInspect_json_yaml(t *testing.T) { //nolint:tparallel for _, tblName := range gotTableNames { tblName := tblName t.Run(tblName, func(t *testing.T) { + // t.Parallel() tu.SkipShort(t, true) tr2 := testrun.New(lg.NewContext(th.Context, lgt.New(t)), t, tr) err := tr2.Exec("inspect", 
"."+tblName, fmt.Sprintf("--%s", tf.format)) diff --git a/cli/testrun/testrun.go b/cli/testrun/testrun.go index eea227d60..fd1f58824 100644 --- a/cli/testrun/testrun.go +++ b/cli/testrun/testrun.go @@ -72,7 +72,8 @@ func New(ctx context.Context, t testing.TB, from *TestRun) *TestRun { } tr.Run, tr.Out, tr.ErrOut = newRun(ctx, t, cfgStore, cacheDir) - tr.Context = options.NewContext(ctx, tr.Run.Config.Options) + o := options.Merge(options.FromContext(ctx), tr.Run.Config.Options) + tr.Context = options.NewContext(ctx, o) return tr } diff --git a/testh/testh.go b/testh/testh.go index 4e4b889fb..b1c89bc58 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -81,6 +81,20 @@ func OptLongOpen() Option { } } +// OptCaching enables or disables ingest caching. +func OptCaching(enable bool) Option { + return func(h *Helper) { + o := options.FromContext(h.Context) + if o == nil { + o = options.Options{driver.OptIngestCache.Key(): enable} + h.Context = options.NewContext(h.Context, o) + return + } + + o[driver.OptIngestCache.Key()] = enable + } +} + // Helper encapsulates a test helper session. type Helper struct { mu sync.Mutex @@ -120,18 +134,6 @@ func New(t testing.TB, opts ...Option) *Helper { h.cancelFn = cancelFn h.Context = lg.NewContext(ctx, h.Log) - - // Disable caching in tests, because there's all sorts of confounding - // situations with running tests in parallel with caching enabled, - // due to the fact that caching uses pid-based locking, and parallel tests - // share the same pid. - // - // REVISIT: The above statement regarding pid-based locking may no longer - // be applicable, as a new cache dir is created for each test run. 
- // - // FIXME: Add an option to set config value - // o := options.Options{driver.OptIngestCache.Key(): false} - // h.Context = options.NewContext(h.Context, o) t.Cleanup(h.Close) for _, opt := range opts { From 3147011d2b5fbdddbd64b3960a09dca96d1762ac Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 14 Jan 2024 03:39:24 +0000 Subject: [PATCH 182/195] time testing --- drivers/csv/csv_test.go | 232 ++++++++++++++++++++---- drivers/csv/ingest.go | 1 + drivers/csv/testdata/test_datetime.tsv | 2 - drivers/csv/testdata/test_timestamp.tsv | 2 + 4 files changed, 197 insertions(+), 40 deletions(-) delete mode 100644 drivers/csv/testdata/test_datetime.tsv create mode 100644 drivers/csv/testdata/test_timestamp.tsv diff --git a/drivers/csv/csv_test.go b/drivers/csv/csv_test.go index cbf79da10..eeb9aa0e5 100644 --- a/drivers/csv/csv_test.go +++ b/drivers/csv/csv_test.go @@ -2,7 +2,12 @@ package csv_test import ( "context" + "fmt" + "github.com/neilotoole/sq/testh/fixt" + "golang.org/x/exp/maps" + "os" "path/filepath" + "slices" "testing" "time" @@ -216,40 +221,188 @@ func TestIngestDuplicateColumns(t *testing.T) { require.Equal(t, wantHeaders, data[0]) } -func TestDatetime(t *testing.T) { +func TestGenerateDatetimeVals(t *testing.T) { + canonicalTimeUTC := time.Unix(0, fixt.TimestampUnixNano1989).UTC() + _ = canonicalTimeUTC + + names := maps.Keys(timez.TimestampLayouts) + slices.Sort(names) + + for _, loc := range []*time.Location{time.UTC, timez.LosAngeles} { + fmt.Printf("\n\n%s\n\n", loc.String()) + tm := canonicalTimeUTC.In(loc) + + for _, name := range names { + layout := timez.TimestampLayouts[name] + fmt.Fprintf(os.Stdout, "%32s: %s\n", name, tm.Format(layout)) + } + } + + t.Logf("\n\n") +} + +func TestIngestTimestamp(t *testing.T) { t.Parallel() denver, err := time.LoadLocation("America/Denver") require.NoError(t, err) + _ = denver + lax, err := time.LoadLocation("America/Los_Angeles") + require.NoError(t, err) + _ = lax - wantDtNanoUTC := time.Date(1989, 11, 
9, 15, 17, 59, 123456700, time.UTC) - wantDtMilliUTC := wantDtNanoUTC.Truncate(time.Millisecond) - wantDtSecUTC := wantDtNanoUTC.Truncate(time.Second) - wantDtMinUTC := wantDtNanoUTC.Truncate(time.Minute) - wantDtNanoMST := time.Date(1989, 11, 9, 15, 17, 59, 123456700, denver) - wantDtMilliMST := wantDtNanoMST.Truncate(time.Millisecond) - wantDtSecMST := wantDtNanoMST.Truncate(time.Second) - wantDtMinMST := wantDtNanoMST.Truncate(time.Minute) + _ = denver + + wantNanoUTC := time.Unix(0, fixt.TimestampUnixNano1989).UTC() + wantMilliUTC := wantNanoUTC.Truncate(time.Millisecond) + wantSecUTC := wantNanoUTC.Truncate(time.Second) + wantMinUTC := wantNanoUTC.Truncate(time.Minute) testCases := []struct { file string wantHeaders []string - wantKinds []kind.Kind - wantVals []any + wantVals []time.Time }{ { - file: "test_date", - wantHeaders: []string{"Long", "Short", "d-mmm-yy", "mm-dd-yy", "mmmm d, yyyy"}, - wantKinds: loz.Make(5, kind.Date), - wantVals: lo.ToAnySlice(loz.Make(5, - time.Date(1989, time.November, 9, 0, 0, 0, 0, time.UTC))), - }, - { - file: "test_time", - wantHeaders: []string{"time1", "time2", "time3", "time4", "time5", "time6"}, - wantKinds: loz.Make(6, kind.Time), - wantVals: []any{"15:17:00", "15:17:00", "15:17:00", "15:17:00", "15:17:00", "15:17:59"}, + file: "test_timestamp", + wantHeaders: []string{ + "ANSIC", + "DateHourMinute", + "DateHourMinuteSecond", + "ISO8601", + "ISO8601Z", + "RFC1123", + "RFC1123Z", + "RFC3339", + "RFC3339Nano", + "RFC3339NanoZ", + "RFC3339Z", + "RFC8222", + "RFC8222Z", + "RFC850", + "RubyDate", + "UnixDate", + }, + wantVals: []time.Time{ + wantSecUTC, // ANSIC + wantMinUTC, // DateHourMinute + wantSecUTC, // DateHourMinuteSecond + wantMilliUTC, // ISO8601 + wantMilliUTC, // ISO8601Z + wantSecUTC, // RFC1123 + wantSecUTC, // RFC1123Z + wantSecUTC, // RFC3339 + wantNanoUTC, // RFC3339Nano + wantNanoUTC, // RFC3339NanoZ + wantSecUTC, // RFC3339Z + wantMinUTC, // RFC8222 + wantMinUTC, // RFC8222Z + wantSecUTC, // RFC850 + 
wantSecUTC, // RubyDate + wantSecUTC, // UnixDate + }, }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.file, func(t *testing.T) { + th := testh.New(t, testh.OptLongOpen()) + src := &source.Source{ + Handle: "@tsv/" + tc.file, + Type: csv.TypeTSV, + Location: filepath.Join("testdata", tc.file+".tsv"), + } + src = th.Add(src) + + sink, err := th.QuerySLQ(src.Handle+".data", nil) + require.NoError(t, err) + + assert.Equal(t, tc.wantHeaders, sink.RecMeta.MungedNames()) + require.Len(t, sink.Recs, 1) + t.Log(sink.Recs[0]) + + for i, col := range sink.RecMeta.MungedNames() { + i, col := i, col + t.Run(col, func(t *testing.T) { + t.Logf("[%d] %s", i, col) + assert.Equal(t, kind.Datetime.String(), sink.RecMeta.Kinds()[i].String()) + if gotTime, ok := sink.Recs[0][i].(time.Time); ok { + // REVISIT: If it's a time value, we want to compare UTC times. + // This may actually be a bug. + wantTime := tc.wantVals[i] + t.Logf("wantTime: %s | %s | %d ", wantTime.Format(time.RFC3339Nano), wantTime.Location(), wantTime.Unix()) + t.Logf(" gotTime: %s | %s | %d ", gotTime.Format(time.RFC3339Nano), gotTime.Location(), gotTime.Unix()) + require.True(t, ok) + assert.Equal(t, wantTime.Unix(), gotTime.Unix()) + assert.Equal(t, wantTime.UTC(), gotTime.UTC()) + } else { + assert.EqualValues(t, tc.wantVals[i], sink.Recs[0][i]) + } + }) + } + }) + } +} + +func TestDatetime(t *testing.T) { + t.Parallel() + + denver, err := time.LoadLocation("America/Denver") + require.NoError(t, err) + _ = denver + lax, err := time.LoadLocation("America/Los_Angeles") + require.NoError(t, err) + _ = lax + + _ = denver + // 1989-11-09T15:17:59.1234567Z - RFC3339Nano + //wantDtNanoUTC := time.Date(1989, 11, 9, 15, 17, 59, 123456700, time.UTC) + wantNanoUTC := time.Unix(0, fixt.TimestampUnixNano1989).UTC() + //wantNanoUTC := timez.MustParse(time.RFC3339Nano, "1989-11-09T15:17:59.1234567Z") + + //got := wantDtNanoUTC.Format(time.RFC3339Nano) + //t.Logf(got) + + wantMilliUTC := 
wantNanoUTC.Truncate(time.Millisecond) + wantSecUTC := wantNanoUTC.Truncate(time.Second) + wantMinUTC := wantNanoUTC.Truncate(time.Minute) + + // 1989-11-09T15:17:59.1234567-07:00 - RFC3339Nano + //wantDtNanoMST := timez.MustParse(time.RFC3339Nano, "1989-11-09T15:17:59.1234567-07:00") + //wantDtNanoMST2 := time.Date(1989, 11, 9, 15, 17, 59, 123456700, denver) + // + //require.Equal(t, wantDtNanoMST.Unix(), wantDtNanoMST2.Unix()) + //require.Equal(t, wantDtNanoMST.UTC(), wantDtNanoMST2.UTC()) + ////got = wantDtNanoMST.Format(time.RFC3339Nano) + ////t.Log(got) + //wantDtMilliMST := wantDtNanoMST.Truncate(time.Millisecond) + //wantDtSecMST := wantDtNanoMST.Truncate(time.Second) + //wantDtMinMST := wantDtNanoMST.Truncate(time.Minute) + // + //t.Logf("wantDtSecMST: %s .... %d", wantDtSecMST.Format(time.RFC1123), wantDtSecMST.Unix()) + + // FIXME: repair this stuff + + testCases := []struct { + file string + wantHeaders []string + wantKinds []kind.Kind + wantVals []any + }{ + //{ + // file: "test_date", + // wantHeaders: []string{"Long", "Short", "d-mmm-yy", "mm-dd-yy", "mmmm d, yyyy"}, + // wantKinds: loz.Make(5, kind.Date), + // wantVals: lo.ToAnySlice(loz.Make(5, + // time.Date(1989, time.November, 9, 0, 0, 0, 0, time.UTC))), + //}, + //{ + // file: "test_time", + // wantHeaders: []string{"time1", "time2", "time3", "time4", "time5", "time6"}, + // wantKinds: loz.Make(6, kind.Time), + // wantVals: []any{"15:17:00", "15:17:00", "15:17:00", "15:17:00", "15:17:00", "15:17:59"}, + //}, { file: "test_datetime", wantHeaders: []string{ @@ -272,22 +425,22 @@ func TestDatetime(t *testing.T) { }, wantKinds: loz.Make(20, kind.Datetime), wantVals: lo.ToAnySlice([]time.Time{ - wantDtSecUTC, // ANSIC - wantDtMinUTC, // DateHourMinute - wantDtSecUTC, // DateHourMinuteSecond - wantDtMilliMST, // ISO8601 - wantDtMilliUTC, // ISO8601Z - wantDtSecMST, // RFC1123 - wantDtSecMST, // RFC1123Z - wantDtSecMST, // RFC3339 - wantDtNanoMST, // RFC3339Nano - wantDtNanoUTC, // RFC3339NanoZ - 
wantDtSecUTC, // RFC3339Z - wantDtMinMST, // RFC8222 - wantDtMinMST, // RFC8222Z - wantDtSecMST, // RFC850 - wantDtSecMST, // RubyDate - wantDtSecMST, // UnixDate + wantSecUTC, // ANSIC + wantMinUTC, // DateHourMinute + wantSecUTC, // DateHourMinuteSecond + wantMilliUTC, // ISO8601 + wantMilliUTC, // ISO8601Z + wantSecUTC, // RFC1123 + wantSecUTC, // RFC1123Z + wantSecUTC, // RFC3339 + wantNanoUTC, // RFC3339Nano + wantNanoUTC, // RFC3339NanoZ + wantSecUTC, // RFC3339Z + wantMinUTC, // RFC8222 + wantMinUTC, // RFC8222Z + wantSecUTC, // RFC850 + wantSecUTC, // RubyDate + wantSecUTC, // UnixDate }), }, } @@ -315,11 +468,14 @@ func TestDatetime(t *testing.T) { for i, col := range sink.RecMeta.MungedNames() { i, col := i, col t.Run(col, func(t *testing.T) { + t.Logf("[%d] %s", i, col) assert.Equal(t, tc.wantKinds[i].String(), sink.RecMeta.Kinds()[i].String()) if gotTime, ok := sink.Recs[0][i].(time.Time); ok { // REVISIT: If it's a time value, we want to compare UTC times. // This may actually be a bug. 
wantTime, ok := tc.wantVals[i].(time.Time) + t.Logf("wantTime: %s | %s | %d ", wantTime.Format(time.RFC3339Nano), wantTime.Location(), wantTime.Unix()) + t.Logf(" gotTime: %s | %s | %d ", gotTime.Format(time.RFC3339Nano), gotTime.Location(), gotTime.Unix()) require.True(t, ok) assert.Equal(t, wantTime.Unix(), gotTime.Unix()) assert.Equal(t, wantTime.UTC(), gotTime.UTC()) diff --git a/drivers/csv/ingest.go b/drivers/csv/ingest.go index fd2dcedfc..cd02a7ae6 100644 --- a/drivers/csv/ingest.go +++ b/drivers/csv/ingest.go @@ -69,6 +69,7 @@ func ingestCSV(ctx context.Context, src *source.Source, openFn source.FileOpenFu } cr := newCSVReader(rc, delim) + recs, err := readRecords(cr, driver.OptIngestSampleSize.Get(src.Options)) if err != nil { return err diff --git a/drivers/csv/testdata/test_datetime.tsv b/drivers/csv/testdata/test_datetime.tsv deleted file mode 100644 index 92b693036..000000000 --- a/drivers/csv/testdata/test_datetime.tsv +++ /dev/null @@ -1,2 +0,0 @@ -ANSIC DateHourMinute DateHourMinuteSecond ISO8601 ISO8601Z RFC1123 RFC1123Z RFC3339 RFC3339Nano RFC3339NanoZ RFC3339Z RFC8222 RFC8222Z RFC850 RubyDate UnixDate -Thu Nov 9 15:17:59 1989 1989-11-09 15:17 1989-11-09 15:17:59 1989-11-09T15:17:59.123-07:00 1989-11-09T15:17:59.123Z Thu, 09 Nov 1989 15:17:59 MST Thu, 09 Nov 1989 15:17:59 -0700 1989-11-09T15:17:59-07:00 1989-11-09T15:17:59.1234567-07:00 1989-11-09T15:17:59.1234567Z 1989-11-09T15:17:59Z 09 Nov 89 15:17 MST 09 Nov 89 15:17 -0700 Thursday, 09-Nov-89 15:17:59 MST Thu Nov 09 15:17:59 -0700 1989 Thu Nov 9 15:17:59 MST 1989 diff --git a/drivers/csv/testdata/test_timestamp.tsv b/drivers/csv/testdata/test_timestamp.tsv new file mode 100644 index 000000000..ed4dd1fdf --- /dev/null +++ b/drivers/csv/testdata/test_timestamp.tsv @@ -0,0 +1,2 @@ +ANSIC DateHourMinute DateHourMinuteSecond ISO8601 ISO8601Z RFC1123 RFC1123Z RFC3339 RFC3339Nano RFC3339NanoZ RFC3339Z RFC8222 RFC8222Z RFC850 RubyDate UnixDate +Thu Nov 9 15:17:59 1989 1989-11-09 15:17 1989-11-09 
15:17:59 1989-11-09T15:17:59.123Z 1989-11-09T15:17:59.123Z Thu, 09 Nov 1989 15:17:59 UTC Thu, 09 Nov 1989 15:17:59 +0000 1989-11-09T15:17:59Z 1989-11-09T15:17:59.1234567Z 1989-11-09T15:17:59.1234567Z 1989-11-09T15:17:59Z 09 Nov 89 15:17 UTC 09 Nov 89 15:17 +0000 Thursday, 09-Nov-89 15:17:59 UTC Thu Nov 09 15:17:59 +0000 1989 Thu Nov 9 15:17:59 UTC 1989 From fc23cf0b80062c8ecdf1a5bdb4641105b62cde9d Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 13 Jan 2024 21:16:14 -0700 Subject: [PATCH 183/195] time testing --- drivers/csv/csv_test.go | 254 ++++++++++++++++--------------------- drivers/sqlite3/sqlite3.go | 9 +- libsq/core/kind/detect.go | 7 +- libsq/core/timez/timez.go | 48 +++++++ libsq/driver/grips.go | 4 +- testh/fixt/fixt.go | 3 + 6 files changed, 171 insertions(+), 154 deletions(-) diff --git a/drivers/csv/csv_test.go b/drivers/csv/csv_test.go index eeb9aa0e5..411a92e1a 100644 --- a/drivers/csv/csv_test.go +++ b/drivers/csv/csv_test.go @@ -3,15 +3,15 @@ package csv_test import ( "context" "fmt" - "github.com/neilotoole/sq/testh/fixt" - "golang.org/x/exp/maps" "os" "path/filepath" "slices" "testing" "time" - "github.com/samber/lo" + "github.com/neilotoole/sq/testh/fixt" + "golang.org/x/exp/maps" + "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -183,7 +183,7 @@ func TestEmptyAsNull(t *testing.T) { } } -func TestIngestDuplicateColumns(t *testing.T) { +func TestIngest_DuplicateColumns(t *testing.T) { ctx := context.Background() tr := testrun.New(ctx, t, nil) @@ -221,38 +221,9 @@ func TestIngestDuplicateColumns(t *testing.T) { require.Equal(t, wantHeaders, data[0]) } -func TestGenerateDatetimeVals(t *testing.T) { - canonicalTimeUTC := time.Unix(0, fixt.TimestampUnixNano1989).UTC() - _ = canonicalTimeUTC - - names := maps.Keys(timez.TimestampLayouts) - slices.Sort(names) - - for _, loc := range []*time.Location{time.UTC, timez.LosAngeles} { - fmt.Printf("\n\n%s\n\n", loc.String()) 
- tm := canonicalTimeUTC.In(loc) - - for _, name := range names { - layout := timez.TimestampLayouts[name] - fmt.Fprintf(os.Stdout, "%32s: %s\n", name, tm.Format(layout)) - } - } - - t.Logf("\n\n") -} - -func TestIngestTimestamp(t *testing.T) { +func TestIngest_Kind_Timestamp(t *testing.T) { t.Parallel() - denver, err := time.LoadLocation("America/Denver") - require.NoError(t, err) - _ = denver - lax, err := time.LoadLocation("America/Los_Angeles") - require.NoError(t, err) - _ = lax - - _ = denver - wantNanoUTC := time.Unix(0, fixt.TimestampUnixNano1989).UTC() wantMilliUTC := wantNanoUTC.Truncate(time.Millisecond) wantSecUTC := wantNanoUTC.Truncate(time.Second) @@ -307,6 +278,8 @@ func TestIngestTimestamp(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.file, func(t *testing.T) { + t.Parallel() + th := testh.New(t, testh.OptLongOpen()) src := &source.Source{ Handle: "@tsv/" + tc.file, @@ -327,121 +300,96 @@ func TestIngestTimestamp(t *testing.T) { t.Run(col, func(t *testing.T) { t.Logf("[%d] %s", i, col) assert.Equal(t, kind.Datetime.String(), sink.RecMeta.Kinds()[i].String()) - if gotTime, ok := sink.Recs[0][i].(time.Time); ok { - // REVISIT: If it's a time value, we want to compare UTC times. - // This may actually be a bug. 
- wantTime := tc.wantVals[i] - t.Logf("wantTime: %s | %s | %d ", wantTime.Format(time.RFC3339Nano), wantTime.Location(), wantTime.Unix()) - t.Logf(" gotTime: %s | %s | %d ", gotTime.Format(time.RFC3339Nano), gotTime.Location(), gotTime.Unix()) - require.True(t, ok) - assert.Equal(t, wantTime.Unix(), gotTime.Unix()) - assert.Equal(t, wantTime.UTC(), gotTime.UTC()) - } else { - assert.EqualValues(t, tc.wantVals[i], sink.Recs[0][i]) - } + wantTime := tc.wantVals[i] + gotTime, ok := sink.Recs[0][i].(time.Time) + require.True(t, ok) + t.Logf( + "wantTime: %s | %s | %d ", + wantTime.Format(time.RFC3339Nano), + wantTime.Location(), + wantTime.Unix(), + ) + t.Logf( + " gotTime: %s | %s | %d ", + gotTime.Format(time.RFC3339Nano), + gotTime.Location(), + gotTime.Unix(), + ) + assert.Equal(t, wantTime.Unix(), gotTime.Unix()) + assert.Equal(t, wantTime.UTC(), gotTime.UTC()) }) } }) } } -func TestDatetime(t *testing.T) { +func TestIngest_Kind_Date(t *testing.T) { t.Parallel() - denver, err := time.LoadLocation("America/Denver") - require.NoError(t, err) - _ = denver - lax, err := time.LoadLocation("America/Los_Angeles") - require.NoError(t, err) - _ = lax + wantDate := time.Date(1989, time.November, 9, 0, 0, 0, 0, time.UTC) - _ = denver - // 1989-11-09T15:17:59.1234567Z - RFC3339Nano - //wantDtNanoUTC := time.Date(1989, 11, 9, 15, 17, 59, 123456700, time.UTC) - wantNanoUTC := time.Unix(0, fixt.TimestampUnixNano1989).UTC() - //wantNanoUTC := timez.MustParse(time.RFC3339Nano, "1989-11-09T15:17:59.1234567Z") + testCases := []struct { + file string + wantHeaders []string + wantVals []time.Time + }{ + { + file: "test_date", + wantHeaders: []string{"Long", "Short", "d-mmm-yy", "mm-dd-yy", "mmmm d, yyyy"}, + wantVals: loz.Make(5, wantDate), + }, + } - //got := wantDtNanoUTC.Format(time.RFC3339Nano) - //t.Logf(got) + for _, tc := range testCases { + tc := tc + t.Run(tc.file, func(t *testing.T) { + t.Parallel() - wantMilliUTC := wantNanoUTC.Truncate(time.Millisecond) - wantSecUTC := 
wantNanoUTC.Truncate(time.Second) - wantMinUTC := wantNanoUTC.Truncate(time.Minute) + th := testh.New(t, testh.OptLongOpen()) + src := &source.Source{ + Handle: "@tsv/" + tc.file, + Type: csv.TypeTSV, + Location: filepath.Join("testdata", tc.file+".tsv"), + } + src = th.Add(src) + + sink, err := th.QuerySLQ(src.Handle+".data", nil) + require.NoError(t, err) + + assert.Equal(t, tc.wantHeaders, sink.RecMeta.MungedNames()) + require.Len(t, sink.Recs, 1) + t.Log(sink.Recs[0]) + + for i, col := range sink.RecMeta.MungedNames() { + i, col := i, col + t.Run(col, func(t *testing.T) { + t.Logf("[%d] %s", i, col) + assert.Equal(t, kind.Date.String(), sink.RecMeta.Kinds()[i].String()) + gotTime, ok := sink.Recs[0][i].(time.Time) + require.True(t, ok) + wantTime := tc.wantVals[i] + t.Logf("wantTime: %s | %s | %d ", wantTime.Format(time.RFC3339Nano), wantTime.Location(), wantTime.Unix()) + t.Logf(" gotTime: %s | %s | %d ", gotTime.Format(time.RFC3339Nano), gotTime.Location(), gotTime.Unix()) + assert.Equal(t, wantTime.Unix(), gotTime.Unix()) + assert.Equal(t, wantTime.UTC(), gotTime.UTC()) + }) + } + }) + } +} - // 1989-11-09T15:17:59.1234567-07:00 - RFC3339Nano - //wantDtNanoMST := timez.MustParse(time.RFC3339Nano, "1989-11-09T15:17:59.1234567-07:00") - //wantDtNanoMST2 := time.Date(1989, 11, 9, 15, 17, 59, 123456700, denver) - // - //require.Equal(t, wantDtNanoMST.Unix(), wantDtNanoMST2.Unix()) - //require.Equal(t, wantDtNanoMST.UTC(), wantDtNanoMST2.UTC()) - ////got = wantDtNanoMST.Format(time.RFC3339Nano) - ////t.Log(got) - //wantDtMilliMST := wantDtNanoMST.Truncate(time.Millisecond) - //wantDtSecMST := wantDtNanoMST.Truncate(time.Second) - //wantDtMinMST := wantDtNanoMST.Truncate(time.Minute) - // - //t.Logf("wantDtSecMST: %s .... 
%d", wantDtSecMST.Format(time.RFC1123), wantDtSecMST.Unix()) - - // FIXME: repair this stuff +func TestIngest_Kind_Time(t *testing.T) { + t.Parallel() testCases := []struct { file string wantHeaders []string - wantKinds []kind.Kind - wantVals []any + wantVals []string }{ - //{ - // file: "test_date", - // wantHeaders: []string{"Long", "Short", "d-mmm-yy", "mm-dd-yy", "mmmm d, yyyy"}, - // wantKinds: loz.Make(5, kind.Date), - // wantVals: lo.ToAnySlice(loz.Make(5, - // time.Date(1989, time.November, 9, 0, 0, 0, 0, time.UTC))), - //}, - //{ - // file: "test_time", - // wantHeaders: []string{"time1", "time2", "time3", "time4", "time5", "time6"}, - // wantKinds: loz.Make(6, kind.Time), - // wantVals: []any{"15:17:00", "15:17:00", "15:17:00", "15:17:00", "15:17:00", "15:17:59"}, - //}, { - file: "test_datetime", - wantHeaders: []string{ - "ANSIC", - "DateHourMinute", - "DateHourMinuteSecond", - "ISO8601", - "ISO8601Z", - "RFC1123", - "RFC1123Z", - "RFC3339", - "RFC3339Nano", - "RFC3339NanoZ", - "RFC3339Z", - "RFC8222", - "RFC8222Z", - "RFC850", - "RubyDate", - "UnixDate", - }, - wantKinds: loz.Make(20, kind.Datetime), - wantVals: lo.ToAnySlice([]time.Time{ - wantSecUTC, // ANSIC - wantMinUTC, // DateHourMinute - wantSecUTC, // DateHourMinuteSecond - wantMilliUTC, // ISO8601 - wantMilliUTC, // ISO8601Z - wantSecUTC, // RFC1123 - wantSecUTC, // RFC1123Z - wantSecUTC, // RFC3339 - wantNanoUTC, // RFC3339Nano - wantNanoUTC, // RFC3339NanoZ - wantSecUTC, // RFC3339Z - wantMinUTC, // RFC8222 - wantMinUTC, // RFC8222Z - wantSecUTC, // RFC850 - wantSecUTC, // RubyDate - wantSecUTC, // UnixDate - }), + file: "test_time", + wantHeaders: []string{"time1", "time2", "time3", "time4", "time5", "time6"}, + wantVals: []string{"15:17:00", "15:17:00", "15:17:00", "15:17:00", "15:17:00", "15:17:59"}, }, } @@ -469,21 +417,39 @@ func TestDatetime(t *testing.T) { i, col := i, col t.Run(col, func(t *testing.T) { t.Logf("[%d] %s", i, col) - assert.Equal(t, tc.wantKinds[i].String(), 
sink.RecMeta.Kinds()[i].String()) - if gotTime, ok := sink.Recs[0][i].(time.Time); ok { - // REVISIT: If it's a time value, we want to compare UTC times. - // This may actually be a bug. - wantTime, ok := tc.wantVals[i].(time.Time) - t.Logf("wantTime: %s | %s | %d ", wantTime.Format(time.RFC3339Nano), wantTime.Location(), wantTime.Unix()) - t.Logf(" gotTime: %s | %s | %d ", gotTime.Format(time.RFC3339Nano), gotTime.Location(), gotTime.Unix()) - require.True(t, ok) - assert.Equal(t, wantTime.Unix(), gotTime.Unix()) - assert.Equal(t, wantTime.UTC(), gotTime.UTC()) - } else { - assert.EqualValues(t, tc.wantVals[i], sink.Recs[0][i]) - } + assert.Equal(t, kind.Time.String(), sink.RecMeta.Kinds()[i].String()) + gotTime, ok := sink.Recs[0][i].(string) + require.True(t, ok) + wantTime := tc.wantVals[i] + t.Logf("wantTime: %s", wantTime) + t.Logf(" gotTime: %s", gotTime) + assert.Equal(t, wantTime, gotTime) }) } }) } } + +// TestGenerateTimestampVals is a utility test that prints out a bunch +// of timestamp values in various time formats and locations. +// It was used to generate values for use in testdata/test_timestamp.tsv. +// It can probably be deleted when we're satisfied with date/time testing. 
+func TestGenerateTimestampVals(t *testing.T) { + canonicalTimeUTC := time.Unix(0, fixt.TimestampUnixNano1989).UTC() + _ = canonicalTimeUTC + + names := maps.Keys(timez.TimestampLayouts) + slices.Sort(names) + + for _, loc := range []*time.Location{time.UTC, timez.LosAngeles, timez.Denver} { + fmt.Fprintf(os.Stdout, "\n\n%s\n\n", loc.String()) + tm := canonicalTimeUTC.In(loc) + + for _, name := range names { + layout := timez.TimestampLayouts[name] + fmt.Fprintf(os.Stdout, "%32s: %s\n", name, tm.Format(layout)) + } + } + + t.Logf("\n\n") +} diff --git a/drivers/sqlite3/sqlite3.go b/drivers/sqlite3/sqlite3.go index 2650da3ca..17a70a2a4 100644 --- a/drivers/sqlite3/sqlite3.go +++ b/drivers/sqlite3/sqlite3.go @@ -15,6 +15,8 @@ import ( "sync" "time" + "github.com/neilotoole/sq/libsq/core/ioz" + _ "github.com/mattn/go-sqlite3" // Import for side effect of loading the driver "github.com/shopspring/decimal" @@ -1033,9 +1035,12 @@ func NewScratchSource(ctx context.Context, fpath string) (src *source.Source, cl } clnup = func() error { + if journal := filepath.Join(fpath, ".db-journal"); ioz.FileAccessible(journal) { + lg.WarnIfError(log, "Delete sqlite3 db journal file", os.Remove(journal)) + } + log.Debug("Delete sqlite3 scratchdb file", lga.Src, src, lga.Path, fpath) - lg.WarnIfError(log, "Delete sqlite3 db journal file", os.RemoveAll(filepath.Join(fpath, ".db-journal"))) - if err := os.RemoveAll(fpath); err != nil { + if err := os.Remove(fpath); err != nil { log.Warn("Delete sqlite3 scratchdb file", lga.Err, err) return errz.Err(err) } diff --git a/libsq/core/kind/detect.go b/libsq/core/kind/detect.go index a7b1e64e1..a0d6ff947 100644 --- a/libsq/core/kind/detect.go +++ b/libsq/core/kind/detect.go @@ -242,12 +242,7 @@ func (d *Detector) doSampleString(s string) { return nil, nil //nolint:nilnil } - t, err := time.Parse(format, s) - if err != nil { - return nil, errz.Err(err) - } - - return t, nil + return errz.Return(time.Parse(format, s)) } } } diff --git 
a/libsq/core/timez/timez.go b/libsq/core/timez/timez.go index 4e6842fba..8d49bbf86 100644 --- a/libsq/core/timez/timez.go +++ b/libsq/core/timez/timez.go @@ -143,3 +143,51 @@ func ParseDateOrTimestampUTC(s string) (time.Time, error) { t, err := ParseDateUTC(s) return t.UTC(), err } + +// MustParse is like time.Parse, but panics on error. +func MustParse(layout, value string) time.Time { + t, err := time.Parse(layout, value) + if err != nil { + panic(err) + } + return t +} + +// TimestampLayouts is a map of timestamp layout names to layout string. +var TimestampLayouts = map[string]string{ + "RFC3339": time.RFC3339, + "RFC3339Z": RFC3339Z, + "ISO8601": ISO8601, + "ISO8601Z": ISO8601Z, + "RFC3339Nano": time.RFC3339Nano, + "RFC3339NanoZ": RFC3339NanoZ, + "ANSIC": time.ANSIC, + "UnixDate": time.UnixDate, + "RubyDate": time.RubyDate, + "RFC8222": time.RFC822, + "RFC8222Z": time.RFC822Z, + "RFC850": time.RFC850, + "RFC1123": time.RFC1123, + "RFC1123Z": time.RFC1123Z, + "Stamp": time.Stamp, + "StampMilli": time.StampMilli, + "StampMicro": time.StampMicro, + "StampNano": time.StampNano, + "DateHourMinuteSecond": DateHourMinuteSecond, + "DateHourMinute": DateHourMinute, + "ExcelDatetimeMDYSeconds": ExcelDatetimeMDYSeconds, + "ExcelDatetimeMDYNoSeconds": ExcelDatetimeMDYNoSeconds, +} + +var ( + LosAngeles = mustLoadLocation("America/Los_Angeles") + Denver = mustLoadLocation("America/Denver") +) + +func mustLoadLocation(name string) *time.Location { + loc, err := time.LoadLocation(name) + if err != nil { + panic(err) + } + return loc +} diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 9237be198..813fbc73a 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -295,13 +295,13 @@ func (gs *Grips) openIngestGripCache(ctx context.Context, src *source.Source, return nil, err } if foundCached { - log.Debug("Ingest cache HIT: found cached copy of source", + log.Info("Ingest cache HIT: found cached copy of source", lga.Src, src, "cached", impl.Source(), ) 
return impl, nil } - log.Debug("Ingest cache MISS: no cache for source", lga.Src, src) + log.Info("Ingest cache MISS: no cache for source", lga.Src, src) var cleanFn func() error impl, cleanFn, err = gs.openNewCacheGrip(ctx, src) diff --git a/testh/fixt/fixt.go b/testh/fixt/fixt.go index 841dd4317..a99263ae4 100644 --- a/testh/fixt/fixt.go +++ b/testh/fixt/fixt.go @@ -17,6 +17,9 @@ import ( // These consts are test fixtures for various data types. const ( + // TimestampUnixNano1989 is 1989-11-09T15:17:59.1234567Z. + TimestampUnixNano1989 = 626627879123456700 + Text string = "seven" TextZ string = "" Int int64 = 7 From a450a8d773d9b3b2b6648e56fca8d8c6f03f9461 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sat, 13 Jan 2024 21:31:39 -0700 Subject: [PATCH 184/195] Update xlsx datetime test data --- cli/cmd_x.go | 16 ++++---- drivers/xlsx/testdata/datetime.xlsx | Bin 11454 -> 11469 bytes drivers/xlsx/xlsx_test.go | 57 +++++++++++++--------------- 3 files changed, 33 insertions(+), 40 deletions(-) diff --git a/cli/cmd_x.go b/cli/cmd_x.go index 0ab9fb385..2d4c08357 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -95,7 +95,6 @@ func execXProgress(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() log := lg.FromContext(ctx) ru := run.FromContext(ctx) - _ = ru d := time.Second * 5 pb := progress.FromContext(ctx) @@ -103,22 +102,21 @@ func execXProgress(cmd *cobra.Command, _ []string) error { defer bar.Stop() select { - // case <-pressEnter(): - // bar.Stop() - // pb.Stop() - // fmt.Fprintln(ru.Out, "\nENTER received") + case <-pressEnter(): + bar.Stop() + pb.Stop() + fmt.Fprintln(ru.Out, "\nENTER received") case <-ctx.Done(): - // bar.Stop() - // pb.Stop() + bar.Stop() + pb.Stop() fmt.Fprintln(ru.Out, "Context done") case <-time.After(d + time.Second*5): - // bar.Stop() + bar.Stop() log.Warn("timed out, about to print something") fmt.Fprintln(ru.Out, "Really timed out") log.Warn("done printing") } - // bar.EwmaIncrInt64(rand.Int63n(5)+1, time.Since(start)) 
fmt.Fprintln(ru.Out, "exiting") return ctx.Err() } diff --git a/drivers/xlsx/testdata/datetime.xlsx b/drivers/xlsx/testdata/datetime.xlsx index 8ebddfd4ce40a7ba53810d907e083b36ce9d0743..d150561e043c1de798b8e87f9219e511dfe9498c 100644 GIT binary patch delta 3046 zcmVge^GCvI1qk6>HY)qJGBiCA*f`vA&IWi?y1`A z?mn(uaEeufS=%&J)&BPzlah9mm)`D0#Kz3%1t03%w&Vs?fmB618Sfc4oU^;tMBTJ*L z-nk2a#s|`8@l%2+{V4^*c@Ex&f5fm657O80$jT(Y$SOHvB`;u<{5Bc|X^mRRAbG!t zHal{Q(DHI8mH2XyZos?lZ{GMyDHJGf-KLUGR5_Z%QSZf14@}>!da1xl$$~5pULNIq zZGuYdc(2?I2zi1B$0S)qGw%t)c(0r=Eq9{OtpgoC$Mw6K8?lR8gTjMPf8u*Djey1i zeL<|S@-pMM_L{po=#9z|{%j$5+zP3*(^?7adlK94{;G`hU5y822kapHiaU;9X_lWi z$bV2RFMTR7t86R1&ry;v0#Q9qXO<>#+>h3;vT}8f zSQJOcXMYKKR|DlJVyW*B6M1SmdLlCxt!a^lEM661K^P0ubrFSRf3{t0i}hl?UBx3M z{EE{!TY=lS2kqDoP

EBVdgC$?NkRqrhbTE#|7CKCDy!SHAR>Z*=N=f(s+@h#Qak zm`pnawxr z&1&XL@4q_P|1o>VCEm)W&zWagbK`CdFYn!{MNh7wR z>1J&H@x17TrcrR5u5ip!GB!U*VSbtX{@dt{`O7ZrMJ&mHVo77lKP0JXB+>9@ld>EyW9aug@^NIm=kV*opQoCU4f5@ryizfN2P*eGNvL0Oxd$@uOtPRN6T_7c zcXTy2{|T=yror@T-V6N2pf~gVS?|*6&wE#k%fZw^{maEP{MVd}RupDf^vcM|dTd^z z;jM}8N;EBiVjXi{qB>jtcng^70TMK zV9KlR!K)r*KY!gW7gM=#&oRf7QO@2B*oeWeMTW((1`P$H$fy;xT*+lY;~EX?s-UQN6!sIH>l*=5CK&fh$GH%7<}q?0K-9}bbq>gAe6x{6-@H9BjUhw zLl;?r7lck<{%CPAYeZl8Vd%U4fvchC#|?X0^JRAE`Pw%$Kt?$94g|l)SgD78+qhzqhbys_4AP1Z%vN8f7Buu&N!8v|D_S& zW;JlcX-^}-K!Z-zz_Fw~jQ}^(z)`0?jR4na${Sj_(Pwx=?r^@LX<@{0Q55$pIBw+S z1x)7Q3Mos)nTm0^JE^BFe8wQK#XA-Z+?p}29>5fW_brL=U1qSH!o5v?nHZZH%Xy9| zH;pZn|6))P&oe46nBwybr*JdfKMQA7UsooRkscNa50JVw*#Q6m#*?rf8-Ik9O^=%} z5QguS`X4M`y95J)g$P(xsiLmbq*X~Sy}959OW3B^td@UY!)E(o8`U>Ie&_Yfc;x)y znRVca43Y|7z+i(QAiNe0<<|mU&Id1rKq|}|+zCz!I1vdS%C~QFDHYJMTo$lZs^9s( ztXsmc+=!lVeey#L3@g1Iu7AGl2Z9^f5~5h=hX^G;!<0i%i;=4WZuKoAr$0xszkvjn zxuj*T%1?q{b6=IYKkr{y)rvuxu<;mU-ZYs!2t1T|X=ZXQZW@PPyD-q)fk`hjMo~rr z6Gpq>{>#xtd*(6-T`+RNtqaC3n3(YEVIM`2RrC#iGv$xf=QKf9X@58Q33FjaygyZs zdlOzbeg1$Iv9cMzPfmW=r!jKm9t5G)5uMD8t=ubxS#NqQ@JPFkuAJ0^RnijUsx`>c z%vEhR;8g?WIP-dk*Tz-Vif$Vx_txEWYd?x{I8r{^&Hwkb zBDLVn%KKZ+e+_`^ELIIeYnpY)j4k;}K^{UoWQ~CC<{d!U($P8aPfssrYwFg;?~q<* ztvB0OQ{ZqkB~BlG?I0UG&53&Y>GdD;zciP6wE6D|S&W@?;+iL&siF1tKV1F-v!Nh^ z1_b#LaUZiNCnW)Ye3DU5!ypue-%b1v4euSaUDS}aOVsSiMH9DZd?RqX5eqbccKi3z zcB?g0AK{$)an23-Vz27x6ReX)7buw#6oKYONWCr4X1$DWP~<$*g4IStfezr%qP)Dw zc|*Ce@MKK`)=O|vAZbT=Q=pyq4aL~;9Vq5z!J&g#8LOCo4@bMj4dd@@3pgd@3M=qT zFwbxYjGK{&j!3~rWbna578qx7UMJ?$?1{nK+g7hbDA7#+|YSAH>nU zcXBkgZ98kTK1Rr&;Mdi|b6+)005#9000;O000000000000000#~G91 z85)zWD-atGkh(S50RRBT1pojP0000000000000000IQSPD?tIalN&5L0alY%EF&BF o5pf?+0RRAd0ssIJ000000000000000040-@EF=c2DF6Tf00-g1)Bpeg delta 3041 zcmZ8jX*3j!0v%(?HY!>AB$TlYjj?1)G$z|v8X|kLWh^6%HOrX!qOufb5Q9RpdVO=Z_o+el3~20<(P?j=)!`p) zTg5>SJ5Dg)wk;rq#{9a^o$8+UP*EM6xW34o`<#(g23moPbl>w|#0PDL{P@*3cBT5o z2^E$!U@;;k(C4?9kH<*6-|vgR+PQH(X4ScWI9I#X#L-;JgZgcl*N_&s+!PI(5EjEh 
zu-WVK$gADq=veP1Pk)z~jUn7ovij3R2b$?s+&(~=e< z#UHn6;_t9_aczrVIAT)3+T5+V)WNF%`)eP5d^SKKd3I01wROR4a3)w@DzTYZ`}#LWD+p6AHLGn+c2p3q z(~Jvk;HXGl>_tQU!!wLmkk z?01Sf+zd2zxI~E$Nt?#8kU2NM$nKrK_Z9uUtuq@s5p%L;LV=Fe#TkKSO8JkuzS&gb zDJyLv1;c(4h1DWb;w!hamX4k9Pcvu09R{>z=tBd|8=Wv_;^(CA_-F<86=R{d(4rr% z5jo@6l?4ToD`n^?qnA)g`FvR6`Bo?Bur%Gb2EJ__3P zTb&2AP|6IL15cN?#j+BEO2N4@Yg<>3`mzFVjiWdI0|j1a1M> zIV3RN>Pq4wUzo{+-dYv;5U)pa#b4gd!kDG+6e8z-WArj5LY~^M)>EK@t#Xb%ls6?w zbf&?U5)3=Lx001&^@{dzE5A2MNSOUv> z|Ic@<(OgWrS5gtcOZNpiSjUp-@5O`mmj5KNJ!-ojSI492jJ79Ri+Z@3n>%gEB*e0p z&R9_-?MWnzxlXtt^EC6;nWd=n4f`<#(cVmm{vy`1(Q%`mDiq+o@cQJZoWs`*Jo2&xqDY97dEE1Ca7t|RNPYhZU!2)C^4Iy$|F;? z+v9FBk=kRN%`2I0DxS~pq%k@;*UeF)uG*B+f7TpJ-bnP9N*U_U`BW5HyWq5U_x5wA z4a?oS^DZIZP#jiaD`5T5$Oqq#xn7h6f%r9OqBzfGQidD+51d(IgTu4Ch&s>ZgMxgqyu z4z*sQ6?{3DW11JwUDUPfnk_zAzKicjJQdL~y{1E8^W#CeEG>J*N5P@AtO>9huPytA zpF=O|h+-mfZl##9by6((>TEP72#z$WAPlx-d(d$%PkxKbIDuhemcThY56$*tHcvRj z@=^OZaqR-I{PdE3*i;->poYxxwg_G%O;iu=CSgJF#jh_~Z9MuV(pTcp=Ti}jb@umT zp{nP-R^AZ$6$D6!vH{`aC~*^V90ng|$%9g`g^}uiQ<9TOhRI=VMp>ejFfp5F9j%!R zTqVrGp<|>9j(prz$-^+zM;3Wg^zJ=rqh>NHU}EqHxE={LXjfIqG@>J`7b_j7U53Q1 z1BQ(Q&jL2S&aWnFoeRA`T!0B8D35k>+^MkI!TMUkRMpN? 
zcJ;ft^%?Eyc7_q2*RU2MnRI>d<}xm`ii@zEnO{;?Dbl7>&yMWgm7WR89N@D+`EPr# zDbyfBgfp+6TpymQ9U0W+wtCjTKJlt_J3J@}trndE4-#N|C7}b_buVapCQ8Y6U3JQcAAiGcPd1xnqK?_bhpqga4ik_`ImLw?2B&qp`p)_i33?QlnESDiTM z4e!80gJi6doo?D|{k@Fm(+nno4ppHPAd2oLX z$`DUWl2`|d{S7x>cCA@CZNr>mkDZ}91}3(c7&n<1A8?Tf#4x2pF?uw%pvNPryp%S$ zAHlfea~UgjS7*ID!5r6T^q^7p67{v>9Gz@|5qTX@5x#&6;2He}l8!1E;)i{BguOA9 z#3(dN>eCKg46v1d6uhSxOCS zHLa=u#lMuRABz*=385W#IxHvHxnseGK68do^C=lyrhh9GZKnhFgr?E2?I;{#m!sVEw;##iom)Xt6bBEin2{47p)w?@mY8YMSpyujs z@nK`uQ=7yslXGk7hv0gu*IN1IUe(Nlyefc3C!?rfx|uuL{;BasC@=r=d5YgF|sVBl}oIN2dHES@H5=gnDsv4nRMM~ks>%TtU8Yb`O zb;OEP|75~@-xj_rsbeV_iNtl8?GW})QbPM&u+teXBb*rDT4yQK5m`hIdb|2XxuW0k z_2GdRy9bpMU-OAJzLC4Z$3xe6^Jk;3k^&blJ=@OeHB2n|3-3OH9!Gu>X@Fqj-xo^@ zMk!O1XQPB#EGsm5#n^UQg&f z;K~Ehni0Y9Chjx|+P2GExA)C>^D;M>EdgX`>GGZ_jVTIJd7FEq#7k#gWR8X;r{MH) z+9$>}Id656h^FPvv2yO@jO`eKMH}1;m_vVi4qaxPN*TDc*0anBsFXquvr@Ssnw=`TuL^Z&$nY{s zt@|3vF&k0ea!u1PY?E1&AF2A1ZG|MXj_sX|2;5W$(S8l!`Qsg)t4BE$4>4ba-%SPE zGGN8GXG_yB7l5^?H@{IH~`{UiPg`F7zp^mr3PV5;!L_`De*V!6ubf37qE zV$bD99EM!r`b?BiH9cT=Rn-&!{R>9{fa3@A`v){)>4E?M^ZdupCoZbOxhjYdH3Kel hqMw@liN71XbTH*OAOLU>_dk#zmZ^dH3RV7={sk8Zu+0Df diff --git a/drivers/xlsx/xlsx_test.go b/drivers/xlsx/xlsx_test.go index f5c9b4491..4fe46b835 100644 --- a/drivers/xlsx/xlsx_test.go +++ b/drivers/xlsx/xlsx_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/neilotoole/sq/testh/fixt" + "github.com/samber/lo" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" @@ -441,23 +443,16 @@ func TestDates(t *testing.T) { func TestDatetime(t *testing.T) { t.Parallel() - denver, err := time.LoadLocation("America/Denver") - require.NoError(t, err) - src := &source.Source{ Handle: "@excel/datetime", Type: xlsx.Type, Location: "testdata/datetime.xlsx", } - wantDtNanoUTC := time.Date(1989, 11, 9, 15, 17, 59, 123456700, time.UTC) - wantDtMilliUTC := wantDtNanoUTC.Truncate(time.Millisecond) - wantDtSecUTC := wantDtNanoUTC.Truncate(time.Second) - wantDtMinUTC := 
wantDtNanoUTC.Truncate(time.Minute) - wantDtNanoMST := time.Date(1989, 11, 9, 15, 17, 59, 123456700, denver) - wantDtMilliMST := wantDtNanoMST.Truncate(time.Millisecond) - wantDtSecMST := wantDtNanoMST.Truncate(time.Second) - wantDtMinMST := wantDtNanoMST.Truncate(time.Minute) + wantNanoUTC := time.Unix(0, fixt.TimestampUnixNano1989).UTC() + wantMilliUTC := wantNanoUTC.Truncate(time.Millisecond) + wantSecUTC := wantNanoUTC.Truncate(time.Second) + wantMinUTC := wantNanoUTC.Truncate(time.Minute) testCases := []struct { sheet string @@ -504,26 +499,26 @@ func TestDatetime(t *testing.T) { }, wantKinds: loz.Make(20, kind.Datetime), wantVals: lo.ToAnySlice([]time.Time{ - wantDtSecUTC, // ANSIC - wantDtMinUTC, // DateHourMinute - wantDtMinUTC, // DateHourMinuteSecond - wantDtMilliMST, // ISO8601 - wantDtMilliUTC, // ISO8601Z - wantDtSecMST, // RFC1123 - wantDtSecMST, // RFC1123Z - wantDtSecMST, // RFC3339 - wantDtNanoMST, // RFC3339Nano - wantDtNanoUTC, // RFC3339NanoZ - wantDtSecUTC, // RFC3339Z - wantDtMinMST, // RFC8222 - wantDtMinUTC, // RFC8222Z - wantDtSecMST, // RFC850 - wantDtSecMST, // RubyDate - wantDtMinUTC, // Stamp - wantDtMinUTC, // StampMicro - wantDtMinUTC, // StampMilli - wantDtMinUTC, // StampNano - wantDtSecMST, // UnixDate + wantSecUTC, // ANSIC + wantMinUTC, // DateHourMinute + wantMinUTC, // DateHourMinuteSecond + wantMilliUTC, // ISO8601 + wantMilliUTC, // ISO8601Z + wantSecUTC, // RFC1123 + wantSecUTC, // RFC1123Z + wantSecUTC, // RFC3339 + wantNanoUTC, // RFC3339Nano + wantNanoUTC, // RFC3339NanoZ + wantSecUTC, // RFC3339Z + wantMinUTC, // RFC8222 + wantMinUTC, // RFC8222Z + wantSecUTC, // RFC850 + wantSecUTC, // RubyDate + wantMinUTC, // Stamp + wantMinUTC, // StampMicro + wantMinUTC, // StampMilli + wantMinUTC, // StampNano + wantSecUTC, // UnixDate }), }, } From c78e57d94ea0d9e639e67321c877d84cf83010f5 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 14 Jan 2024 00:20:59 -0700 Subject: [PATCH 185/195] Parallel tests --- 
drivers/json/ingest_test.go | 30 ++++++++++++++++++++++++------ drivers/json/internal_test.go | 16 ++++++++++------ testh/testh.go | 18 ++++++++---------- 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/drivers/json/ingest_test.go b/drivers/json/ingest_test.go index 16b18a6b1..0c4b8833c 100644 --- a/drivers/json/ingest_test.go +++ b/drivers/json/ingest_test.go @@ -20,7 +20,9 @@ import ( "github.com/neilotoole/sq/testh/tu" ) -func TestImportJSONL_Flat(t *testing.T) { +func TestIngestJSONL_Flat(t *testing.T) { + t.Parallel() + // Either fpath (testdata file path) or input should be provided. testCases := []struct { fpath string @@ -75,6 +77,8 @@ func TestImportJSONL_Flat(t *testing.T) { tc := tc t.Run(tu.Name(i, tc.fpath, tc.input), func(t *testing.T) { + t.Parallel() + openFn := func(ctx context.Context) (io.ReadCloser, error) { return io.NopCloser(strings.NewReader(tc.input)), nil } @@ -86,9 +90,9 @@ func TestImportJSONL_Flat(t *testing.T) { } th, src, _, grip, _ := testh.NewWith(t, testsrc.EmptyDB) - job := json.NewImportJob(src, openFn, grip, 0, true) + job := json.NewIngestJob(src, openFn, grip, 0, true) - err := json.ImportJSONL(th.Context, job) + err := json.IngestJSONL(th.Context, job) if tc.wantErr { require.Error(t, err) return @@ -105,15 +109,17 @@ func TestImportJSONL_Flat(t *testing.T) { } } -func TestImportJSON_Flat(t *testing.T) { +func TestIngestJSON_Flat(t *testing.T) { + t.Parallel() + openFn := func(context.Context) (io.ReadCloser, error) { return os.Open("testdata/actor.json") } th, src, _, grip, _ := testh.NewWith(t, testsrc.EmptyDB) - job := json.NewImportJob(src, openFn, grip, 0, true) + job := json.NewIngestJob(src, openFn, grip, 0, true) - err := json.ImportJSON(th.Context, job) + err := json.IngestJSON(th.Context, job) require.NoError(t, err) sink, err := th.QuerySQL(src, nil, "SELECT * FROM data") @@ -122,6 +128,8 @@ func TestImportJSON_Flat(t *testing.T) { } func TestScanObjectsInArray(t *testing.T) { + t.Parallel() + 
var ( m1 = []map[string]any{{"a": float64(1)}} m2 = []map[string]any{{"a": float64(1)}, {"a": float64(2)}} @@ -188,6 +196,8 @@ func TestScanObjectsInArray(t *testing.T) { tc := tc t.Run(tu.Name(i, tc.in), func(t *testing.T) { + t.Parallel() + r := bytes.NewReader([]byte(tc.in)) gotObjs, gotChunks, err := json.ScanObjectsInArray(r) if tc.wantErr { @@ -207,6 +217,8 @@ func TestScanObjectsInArray(t *testing.T) { } func TestScanObjectsInArray_Files(t *testing.T) { + t.Parallel() + testCases := []struct { fname string wantCount int @@ -220,6 +232,8 @@ func TestScanObjectsInArray_Files(t *testing.T) { tc := tc t.Run(tu.Name(tc.fname), func(t *testing.T) { + t.Parallel() + f, err := os.Open(tc.fname) require.NoError(t, err) defer f.Close() @@ -233,6 +247,8 @@ func TestScanObjectsInArray_Files(t *testing.T) { } func TestColumnOrderFlat(t *testing.T) { + t.Parallel() + testCases := []struct { in string want []string @@ -261,6 +277,8 @@ func TestColumnOrderFlat(t *testing.T) { tc := tc t.Run(tu.Name(i, tc.in), func(t *testing.T) { + t.Parallel() + require.True(t, stdj.Valid([]byte(tc.in))) gotCols, err := json.ColumnOrderFlat([]byte(tc.in)) diff --git a/drivers/json/internal_test.go b/drivers/json/internal_test.go index c35e56ac3..9f0e8d868 100644 --- a/drivers/json/internal_test.go +++ b/drivers/json/internal_test.go @@ -17,16 +17,16 @@ import ( // export for testing. var ( - ImportJSON = ingestJSON - ImportJSONA = ingestJSONA - ImportJSONL = ingestJSONL + IngestJSON = ingestJSON + IngestJSONA = ingestJSONA + IngestJSONL = ingestJSONL ColumnOrderFlat = columnOrderFlat - NewImportJob = newImportJob + NewIngestJob = newIngestJob ) -// newImportJob is a constructor for the unexported ingestJob type. +// newIngestJob is a constructor for the unexported ingestJob type. // If sampleSize <= 0, a default value is used. 
-func newImportJob(fromSrc *source.Source, openFn source.FileOpenFunc, destGrip driver.Grip, sampleSize int, +func newIngestJob(fromSrc *source.Source, openFn source.FileOpenFunc, destGrip driver.Grip, sampleSize int, flatten bool, ) ingestJob { if sampleSize <= 0 { @@ -43,6 +43,8 @@ func newImportJob(fromSrc *source.Source, openFn source.FileOpenFunc, destGrip d } func TestDetectColKindsJSONA(t *testing.T) { + t.Parallel() + testCases := []struct { tbl string wantKinds []kind.Kind @@ -56,6 +58,8 @@ func TestDetectColKindsJSONA(t *testing.T) { tc := tc t.Run(tc.tbl, func(t *testing.T) { + t.Parallel() + f, err := os.Open(fmt.Sprintf("testdata/%s.jsona", tc.tbl)) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, f.Close()) }) diff --git a/testh/testh.go b/testh/testh.go index b1c89bc58..08d9e7ccc 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -62,7 +62,7 @@ import ( // defaultDBOpenTimeout is the timeout for tests to open (and ping) their DBs. // This should be a low value, because, well, we can either connect // or not. -const defaultDBOpenTimeout = time.Second * 5 +const defaultDBOpenTimeout = time.Second * 5 //nolint:unused // Option is a functional option type used with New to // configure the helper. @@ -77,7 +77,8 @@ type Option func(h *Helper) // Most tests don't need this. func OptLongOpen() Option { return func(h *Helper) { - h.dbOpenTimeout = time.Second * 180 + // FIXME: Delete OptLongOpen entirely + // h.dbOpenTimeout = time.Second * 180 } } @@ -117,17 +118,16 @@ type Helper struct { Cleanup *cleanup.Cleanup - dbOpenTimeout time.Duration + dbOpenTimeout time.Duration //nolint:unused } // New returns a new Helper. The helper's Close func will be // automatically invoked via t.Cleanup. 
func New(t testing.TB, opts ...Option) *Helper { h := &Helper{ - T: t, - Log: lgt.New(t), - Cleanup: cleanup.New(), - dbOpenTimeout: defaultDBOpenTimeout, + T: t, + Log: lgt.New(t), + Cleanup: cleanup.New(), } ctx, cancelFn := context.WithCancel(context.Background()) @@ -403,9 +403,7 @@ func (h *Helper) NewCollection(handles ...string) *source.Collection { // same driver.Grip instance. The opened driver.Grip will be closed // during h.Close. func (h *Helper) Open(src *source.Source) driver.Grip { - ctx, cancelFn := context.WithTimeout(h.Context, h.dbOpenTimeout) - defer cancelFn() - + ctx := h.Context grip, err := h.Grips().Open(ctx, src) require.NoError(h.T, err) From 7dcad3a0b4d225ace06e83a9a489e450e53a2576 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 14 Jan 2024 07:44:54 -0700 Subject: [PATCH 186/195] Text record writer now shows a progress bar while preparing records --- cli/cmd_inspect_test.go | 2 +- cli/output/adapter.go | 3 +-- cli/output/tablew/metadatawriter.go | 4 +-- cli/output/tablew/recordwriter.go | 38 ++++++++++++++++++++++++--- cli/output/tablew/tablew.go | 4 +-- drivers/csv/csv_test.go | 8 +++--- drivers/userdriver/userdriver_test.go | 2 +- drivers/xlsx/xlsx_test.go | 14 +++++----- libsq/core/ioz/ioz.go | 4 +-- libsq/libsq_test.go | 2 +- testh/testh.go | 22 ---------------- 11 files changed, 56 insertions(+), 47 deletions(-) diff --git a/cli/cmd_inspect_test.go b/cli/cmd_inspect_test.go index 8382bbc87..7164f299a 100644 --- a/cli/cmd_inspect_test.go +++ b/cli/cmd_inspect_test.go @@ -68,7 +68,7 @@ func TestCmdInspect_json_yaml(t *testing.T) { //nolint:tparallel t.Parallel() tu.SkipWindowsIf(t, tc.handle == sakila.XLSX, "XLSX too slow on windows workflow") - th := testh.New(t, testh.OptCaching(true)) + th := testh.New(t) src := th.Source(tc.handle) tr := testrun.New(th.Context, t, nil).Hush().Add(*src) diff --git a/cli/output/adapter.go b/cli/output/adapter.go index fa62f5ebf..39776eb4f 100644 --- a/cli/output/adapter.go +++ 
b/cli/output/adapter.go @@ -59,9 +59,8 @@ func NewRecordWriterAdapter(ctx context.Context, rw RecordWriter) *RecordWriterA func (w *RecordWriterAdapter) Open(ctx context.Context, cancelFn context.CancelFunc, recMeta record.Meta, ) (chan<- record.Record, <-chan error, error) { - w.cancelFn = cancelFn - lg.FromContext(ctx).Debug("Open RecordWriterAdapter", "fields", recMeta) + w.cancelFn = cancelFn err := w.rw.Open(ctx, recMeta) if err != nil { diff --git a/cli/output/tablew/metadatawriter.go b/cli/output/tablew/metadatawriter.go index 46b948e10..34023f9ec 100644 --- a/cli/output/tablew/metadatawriter.go +++ b/cli/output/tablew/metadatawriter.go @@ -128,7 +128,7 @@ func (w *mdWriter) doSourceMetaNoSchema(md *metadata.Source) error { } w.tbl.tblImpl.SetHeader(headers) - return w.tbl.renderRow(context.TODO(), row) + return w.tbl.writeRow(context.TODO(), row) } func (w *mdWriter) printTablesVerbose(tbls []*metadata.Table) error { @@ -261,7 +261,7 @@ func (w *mdWriter) doSourceMetaFull(md *metadata.Source) error { } w.tbl.tblImpl.SetHeader(headers) - if err := w.tbl.renderRow(context.TODO(), row); err != nil { + if err := w.tbl.writeRow(context.TODO(), row); err != nil { return err } diff --git a/cli/output/tablew/recordwriter.go b/cli/output/tablew/recordwriter.go index a468ce509..368dce76d 100644 --- a/cli/output/tablew/recordwriter.go +++ b/cli/output/tablew/recordwriter.go @@ -6,6 +6,8 @@ import ( "sync" "github.com/neilotoole/sq/cli/output" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/record" ) @@ -14,6 +16,7 @@ type recordWriter struct { tbl *table recMeta record.Meta rowCount int + bar *progress.Bar } // NewRecordWriter returns a RecordWriter for text table output. @@ -25,18 +28,45 @@ func NewRecordWriter(out io.Writer, pr *output.Printing) output.RecordWriter { } // Open implements output.RecordWriter. 
-func (w *recordWriter) Open(_ context.Context, recMeta record.Meta) error { +func (w *recordWriter) Open(ctx context.Context, recMeta record.Meta) error { + w.mu.Lock() + defer w.mu.Unlock() + w.recMeta = recMeta + + // We show a progress bar, because this writer batches all records and writes + // them together at the end. A better-behaved writer would stream records + // as they arrive (or at least batch them in smaller chunks). This will + // probably be fixed at some point, but there's a bit of a catch. The table + // determines the width of each column based on the widest value seen for that + // column. So, if we stream records as they arrive, we can't know the maximum + // width of each column until all records have been received. Thus, + // periodically flushing the output may result in inconsistent bar widths for + // subsequent batches. This is probably something that we'll have to live + // with. After all, this writer is intended for human/interactive use, and + // if the number of records is huge (triggering batching), then the user + // really should be using a machine-readable output format instead. + w.bar = progress.FromContext(ctx).NewUnitCounter("Preparing output", "rec") + return nil } -// Flush implements output.RecordWriter. +// Flush implements output.RecordWriter. It's a no-op for this writer. func (w *recordWriter) Flush(context.Context) error { + w.mu.Lock() + defer w.mu.Unlock() + return nil } // Close implements output.RecordWriter. 
func (w *recordWriter) Close(ctx context.Context) error { + w.mu.Lock() + defer w.mu.Unlock() + + if w.bar != nil { + w.bar.Stop() + } if w.rowCount == 0 { // no data to write return nil @@ -46,7 +76,8 @@ func (w *recordWriter) Close(ctx context.Context) error { header := w.recMeta.MungedNames() w.tbl.tblImpl.SetHeader(header) - return w.tbl.renderAll(ctx) + lg.FromContext(ctx).Debug("RecordWriter (text): writing records to output", "recs", w.rowCount) + return w.tbl.writeAll(ctx) } // WriteRecords implements output.RecordWriter. @@ -70,6 +101,7 @@ func (w *recordWriter) WriteRecords(ctx context.Context, recs []record.Record) e tblRows = append(tblRows, tblRow) w.rowCount++ + w.bar.IncrBy(1) } return w.tbl.appendRows(ctx, tblRows) diff --git a/cli/output/tablew/tablew.go b/cli/output/tablew/tablew.go index 237761830..83dcf5b42 100644 --- a/cli/output/tablew/tablew.go +++ b/cli/output/tablew/tablew.go @@ -172,11 +172,11 @@ func (t *table) appendRows(ctx context.Context, rows [][]string) error { return nil } -func (t *table) renderAll(ctx context.Context) error { +func (t *table) writeAll(ctx context.Context) error { return t.tblImpl.RenderAll(ctx) } -func (t *table) renderRow(ctx context.Context, row []string) error { +func (t *table) writeRow(ctx context.Context, row []string) error { t.tblImpl.Append(row) return t.tblImpl.RenderAll(ctx) // Send output } diff --git a/drivers/csv/csv_test.go b/drivers/csv/csv_test.go index 411a92e1a..abae10c2e 100644 --- a/drivers/csv/csv_test.go +++ b/drivers/csv/csv_test.go @@ -107,7 +107,7 @@ func TestSakila_query(t *testing.T) { t.Run(tc.file, func(t *testing.T) { t.Parallel() - th := testh.New(t, testh.OptLongOpen()) + th := testh.New(t) src := th.Add(&source.Source{ Handle: "@" + tc.file, Type: drvr, @@ -280,7 +280,7 @@ func TestIngest_Kind_Timestamp(t *testing.T) { t.Run(tc.file, func(t *testing.T) { t.Parallel() - th := testh.New(t, testh.OptLongOpen()) + th := testh.New(t) src := &source.Source{ Handle: "@tsv/" + tc.file, 
Type: csv.TypeTSV, @@ -345,7 +345,7 @@ func TestIngest_Kind_Date(t *testing.T) { t.Run(tc.file, func(t *testing.T) { t.Parallel() - th := testh.New(t, testh.OptLongOpen()) + th := testh.New(t) src := &source.Source{ Handle: "@tsv/" + tc.file, Type: csv.TypeTSV, @@ -398,7 +398,7 @@ func TestIngest_Kind_Time(t *testing.T) { t.Run(tc.file, func(t *testing.T) { t.Parallel() - th := testh.New(t, testh.OptLongOpen()) + th := testh.New(t) src := &source.Source{ Handle: "@tsv/" + tc.file, Type: csv.TypeTSV, diff --git a/drivers/userdriver/userdriver_test.go b/drivers/userdriver/userdriver_test.go index c04d13e37..86328fe8d 100644 --- a/drivers/userdriver/userdriver_test.go +++ b/drivers/userdriver/userdriver_test.go @@ -31,7 +31,7 @@ func TestDriver(t *testing.T) { t.Run(tc.handle, func(t *testing.T) { t.Parallel() - th := testh.New(t, testh.OptLongOpen()) + th := testh.New(t) src := th.Source(tc.handle) grip := th.Open(src) diff --git a/drivers/xlsx/xlsx_test.go b/drivers/xlsx/xlsx_test.go index 4fe46b835..765d81036 100644 --- a/drivers/xlsx/xlsx_test.go +++ b/drivers/xlsx/xlsx_test.go @@ -55,7 +55,7 @@ func TestSakilaInspectSource(t *testing.T) { tu.SkipWindows(t, "Skipping because of slow workflow perf on windows") tu.SkipShort(t, true) - th := testh.New(t, testh.OptLongOpen()) + th := testh.New(t) src := th.Source(sakila.XLSX) tr := testrun.New(th.Context, t, nil).Hush().Add(*src) @@ -74,7 +74,7 @@ func TestSakilaInspectSheets(t *testing.T) { t.Run(sheet, func(t *testing.T) { t.Parallel() - th := testh.New(t, testh.OptLongOpen()) + th := testh.New(t) src := th.Source(sakila.XLSX) tr := testrun.New(th.Context, t, nil).Hush().Add(*src) @@ -94,7 +94,7 @@ func BenchmarkInspectSheets(b *testing.B) { b.Run(sheet, func(b *testing.B) { for n := 0; n < b.N; n++ { - th := testh.New(b, testh.OptLongOpen()) + th := testh.New(b) src := th.Source(sakila.XLSX) tr := testrun.New(th.Context, b, nil).Hush().Add(*src) @@ -118,7 +118,7 @@ func TestSakila_query_cmd(t *testing.T) { 
t.Run(sheet, func(t *testing.T) { t.Parallel() - th := testh.New(t, testh.OptLongOpen()) + th := testh.New(t) src := th.Source(sakila.XLSX) tr := testrun.New(th.Context, t, nil).Hush().Add(*src) @@ -157,7 +157,7 @@ func TestOpenFileFormats(t *testing.T) { t.Run(tc.filename, func(t *testing.T) { t.Parallel() - th := testh.New(t, testh.OptLongOpen()) + th := testh.New(t) src := th.Add(&source.Source{ Handle: "@excel", Type: xlsx.Type, @@ -240,7 +240,7 @@ func TestSakila_query(t *testing.T) { t.Run(tc.sheet, func(t *testing.T) { t.Parallel() - th := testh.New(t, testh.OptLongOpen()) + th := testh.New(t) src := th.Source(sakila.XLSX) sink, err := th.QuerySQL(src, nil, "SELECT * FROM "+tc.sheet) @@ -528,7 +528,7 @@ func TestDatetime(t *testing.T) { t.Run(tc.sheet, func(t *testing.T) { t.Parallel() - th := testh.New(t, testh.OptLongOpen()) + th := testh.New(t) src = th.Add(src) sink, err := th.QuerySLQ(src.Handle+"."+tc.sheet, nil) diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index e1a503355..2d88a4db0 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -286,11 +286,11 @@ type notifyOnceWriter struct { // Write implements [io.Writer]. On the first invocation of this // method, the notify function is invoked, blocking until it returns. -// Subsequent invocations of Write do trigger the notify function. +// Subsequent invocations of Write don't trigger the notify function. 
func (w *notifyOnceWriter) Write(p []byte) (n int, err error) { w.notifyOnce.Do(func() { - close(w.doneCh) w.fn() + close(w.doneCh) }) <-w.doneCh diff --git a/libsq/libsq_test.go b/libsq/libsq_test.go index 508d4b753..6cd1709ce 100644 --- a/libsq/libsq_test.go +++ b/libsq/libsq_test.go @@ -63,7 +63,7 @@ func TestQuerySQL_Smoke(t *testing.T) { tu.SkipShort(t, tc.handle == sakila.XLSX) t.Parallel() - th := testh.New(t, testh.OptLongOpen()) + th := testh.New(t) src := th.Source(tc.handle) tblName := sakila.TblActor diff --git a/testh/testh.go b/testh/testh.go index 08d9e7ccc..53afef8bf 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -12,7 +12,6 @@ import ( "strings" "sync" "testing" - "time" "github.com/neilotoole/sq/libsq/core/ioz/lockfile" @@ -59,29 +58,10 @@ import ( "github.com/neilotoole/sq/testh/tu" ) -// defaultDBOpenTimeout is the timeout for tests to open (and ping) their DBs. -// This should be a low value, because, well, we can either connect -// or not. -const defaultDBOpenTimeout = time.Second * 5 //nolint:unused - // Option is a functional option type used with New to // configure the helper. type Option func(h *Helper) -// OptLongOpen allows a longer DB open timeout, which is necessary -// for some tests. Note that DB open performs an import for file-based -// sources, so it can take some time. Usage: -// -// testh.New(t, testh.OptLongOpen()) -// -// Most tests don't need this. -func OptLongOpen() Option { - return func(h *Helper) { - // FIXME: Delete OptLongOpen entirely - // h.dbOpenTimeout = time.Second * 180 - } -} - // OptCaching enables or disables ingest caching. func OptCaching(enable bool) Option { return func(h *Helper) { @@ -117,8 +97,6 @@ type Helper struct { cancelFn context.CancelFunc Cleanup *cleanup.Cleanup - - dbOpenTimeout time.Duration //nolint:unused } // New returns a new Helper. 
The helper's Close func will be From 9d00e152b8fe6eb2ec8d8a18cc5f88a496ba8849 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 14 Jan 2024 09:08:46 -0700 Subject: [PATCH 187/195] Refactored pkg progress; added mem usage; tablew shows progress --- cli/output/tablew/recordwriter.go | 2 +- libsq/core/progress/bars.go | 74 ++++++++++++++++++++----------- libsq/core/progress/progress.go | 55 +++++++++++++++-------- libsq/core/progress/style.go | 53 +++++++++++++++------- 4 files changed, 123 insertions(+), 61 deletions(-) diff --git a/cli/output/tablew/recordwriter.go b/cli/output/tablew/recordwriter.go index 368dce76d..415773075 100644 --- a/cli/output/tablew/recordwriter.go +++ b/cli/output/tablew/recordwriter.go @@ -46,7 +46,7 @@ func (w *recordWriter) Open(ctx context.Context, recMeta record.Meta) error { // with. After all, this writer is intended for human/interactive use, and // if the number of records is huge (triggering batching), then the user // really should be using a machine-readable output format instead. - w.bar = progress.FromContext(ctx).NewUnitCounter("Preparing output", "rec") + w.bar = progress.FromContext(ctx).NewUnitCounter("Preparing output", "rec", progress.OptMemUsage) return nil } diff --git a/libsq/core/progress/bars.go b/libsq/core/progress/bars.go index 44501ecce..7ba2e0d39 100644 --- a/libsq/core/progress/bars.go +++ b/libsq/core/progress/bars.go @@ -5,14 +5,13 @@ import ( humanize "github.com/dustin/go-humanize" "github.com/dustin/go-humanize/english" - mpb "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" ) // NewByteCounter returns a new determinate bar whose label // metric is the size in bytes of the data being processed. The caller is // ultimately responsible for calling [Bar.Stop] on the returned Bar. 
-func (p *Progress) NewByteCounter(msg string, size int64) *Bar { +func (p *Progress) NewByteCounter(msg string, size int64, opts ...Opt) *Bar { if p == nil { return nil } @@ -20,21 +19,23 @@ func (p *Progress) NewByteCounter(msg string, size int64) *Bar { p.mu.Lock() defer p.mu.Unlock() - var style mpb.BarFillerBuilder + cfg := &barConfig{msg: msg, total: size} + var counter decor.Decorator var percent decor.Decorator if size < 0 { - style = spinnerStyle(p.colors.Filler) + cfg.style = spinnerStyle(p.colors.Filler) counter = decor.Current(decor.SizeB1024(0), "% .1f") } else { - style = barStyle(p.colors.Filler) + cfg.style = barStyle(p.colors.Filler) counter = decor.Counters(decor.SizeB1024(0), "% .1f / % .1f") percent = decor.NewPercentage(" %.1f", decor.WCSyncSpace) percent = colorize(percent, p.colors.Percent) } counter = colorize(counter, p.colors.Size) + cfg.decorators = []decor.Decorator{counter, percent} - return p.newBar(msg, size, style, counter, percent) + return p.newBar(cfg, opts) } // NewUnitCounter returns a new indeterminate bar whose label @@ -54,7 +55,7 @@ func (p *Progress) NewByteCounter(msg string, size int64) *Bar { // Ingesting records ∙∙● 87 recs // // Note that the unit arg is automatically pluralized. 
-func (p *Progress) NewUnitCounter(msg, unit string) *Bar { +func (p *Progress) NewUnitCounter(msg, unit string, opts ...Opt) *Bar { if p == nil { return nil } @@ -62,18 +63,22 @@ func (p *Progress) NewUnitCounter(msg, unit string) *Bar { p.mu.Lock() defer p.mu.Unlock() - decorator := decor.Any(func(statistics decor.Statistics) string { + cfg := &barConfig{ + msg: msg, + total: -1, + style: spinnerStyle(p.colors.Filler), + } + + d := decor.Any(func(statistics decor.Statistics) string { s := humanize.Comma(statistics.Current) if unit != "" { s += " " + english.PluralWord(int(statistics.Current), unit, "") } return s }) - decorator = colorize(decorator, p.colors.Size) + cfg.decorators = []decor.Decorator{colorize(d, p.colors.Size)} - style := spinnerStyle(p.colors.Filler) - - return p.newBar(msg, -1, style, decorator) + return p.newBar(cfg, opts) } // NewWaiter returns a generic indeterminate spinner. If arg clock @@ -83,7 +88,7 @@ func (p *Progress) NewUnitCounter(msg, unit string) *Bar { // // The caller is ultimately responsible for calling [Bar.Stop] on the // returned Bar. -func (p *Progress) NewWaiter(msg string, clock bool) *Bar { +func (p *Progress) NewWaiter(msg string, clock bool, opts ...Opt) *Bar { if p == nil { return nil } @@ -91,12 +96,17 @@ func (p *Progress) NewWaiter(msg string, clock bool) *Bar { p.mu.Lock() defer p.mu.Unlock() - var d []decor.Decorator + cfg := &barConfig{ + msg: msg, + total: -1, + style: spinnerStyle(p.colors.Filler), + } + if clock { - d = append(d, newElapsedSeconds(p.colors.Size, time.Now(), decor.WCSyncSpace)) + d := newElapsedSeconds(p.colors.Size, time.Now(), decor.WCSyncSpace) + cfg.decorators = []decor.Decorator{d} } - style := spinnerStyle(p.colors.Filler) - return p.newBar(msg, -1, style, d...) 
+ return p.newBar(cfg, opts) } // NewUnitTotalCounter returns a new determinate bar whose label @@ -108,7 +118,7 @@ func (p *Progress) NewWaiter(msg string, clock bool) *Bar { // Ingesting sheets ∙∙∙∙∙● 4 / 16 sheets // // Note that the unit arg is automatically pluralized. -func (p *Progress) NewUnitTotalCounter(msg, unit string, total int64) *Bar { +func (p *Progress) NewUnitTotalCounter(msg, unit string, total int64, opts ...Opt) *Bar { if p == nil { return nil } @@ -120,16 +130,21 @@ func (p *Progress) NewUnitTotalCounter(msg, unit string, total int64) *Bar { p.mu.Lock() defer p.mu.Unlock() - style := barStyle(p.colors.Filler) - decorator := decor.Any(func(statistics decor.Statistics) string { + cfg := &barConfig{ + msg: msg, + total: total, + style: barStyle(p.colors.Filler), + } + + d := decor.Any(func(statistics decor.Statistics) string { s := humanize.Comma(statistics.Current) + " / " + humanize.Comma(statistics.Total) if unit != "" { s += " " + english.PluralWord(int(statistics.Current), unit, "") } return s }) - decorator = colorize(decorator, p.colors.Size) - return p.newBar(msg, total, style, decorator) + cfg.decorators = []decor.Decorator{colorize(d, p.colors.Size)} + return p.newBar(cfg, opts) } // NewTimeoutWaiter returns a new indeterminate bar whose label is the @@ -140,7 +155,7 @@ func (p *Progress) NewUnitTotalCounter(msg, unit string, total int64) *Bar { // The caller is ultimately responsible for calling [Bar.Stop] on // the returned bar, although the bar will also be stopped when the // parent Progress stops. 
-func (p *Progress) NewTimeoutWaiter(msg string, expires time.Time) *Bar { +func (p *Progress) NewTimeoutWaiter(msg string, expires time.Time, opts ...Opt) *Bar { if p == nil { return nil } @@ -148,8 +163,12 @@ func (p *Progress) NewTimeoutWaiter(msg string, expires time.Time) *Bar { p.mu.Lock() defer p.mu.Unlock() - style := spinnerStyle(p.colors.Waiting) - decorator := decor.Any(func(statistics decor.Statistics) string { + cfg := &barConfig{ + msg: msg, + style: spinnerStyle(p.colors.Waiting), + } + + d := decor.Any(func(statistics decor.Statistics) string { remaining := time.Until(expires) switch { case remaining > 0: @@ -165,6 +184,7 @@ func (p *Progress) NewTimeoutWaiter(msg string, expires time.Time) *Bar { } }) - total := time.Until(expires) - return p.newBar(msg, int64(total), style, decorator) + cfg.decorators = []decor.Decorator{d} + cfg.total = int64(time.Until(expires)) + return p.newBar(cfg, opts) } diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index ff5a4f62b..2e66c1400 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -22,6 +22,8 @@ import ( "sync/atomic" "time" + "github.com/samber/lo" + mpb "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" @@ -220,14 +222,12 @@ func (p *Progress) doStop() { defer lg.FromContext(p.ctx).Debug("Stopped progress widget") if p.pc == nil { close(p.stoppedCh) - // close(p.refreshCh) p.cancelFn() return } if len(p.bars) == 0 { close(p.stoppedCh) - // close(p.refreshCh) p.cancelFn() return } @@ -250,7 +250,6 @@ func (p *Progress) doStop() { p.refreshCh <- time.Now() close(p.stoppedCh) - // close(p.refreshCh) p.pc.Wait() // Important: we must call cancelFn after pc.Wait() or the bars // may not be removed from the terminal. @@ -261,15 +260,28 @@ func (p *Progress) doStop() { <-p.ctx.Done() } +// Opt is a functional option for Bar creation. 
+type Opt interface { + apply(*Progress, *barConfig) +} + +// barConfig is passed to Progress.newBar. +type barConfig struct { + msg string + total int64 + style mpb.BarFillerBuilder + decorators []decor.Decorator +} + // newBar returns a new Bar. This function must only be called from -// inside the mutex. -func (p *Progress) newBar(msg string, total int64, - style mpb.BarFillerBuilder, decorators ...decor.Decorator, -) *Bar { +// inside the Progress mutex. +func (p *Progress) newBar(cfg *barConfig, opts []Opt) *Bar { if p == nil { return nil } + cfg.decorators = lo.WithoutEmpty(cfg.decorators) + select { case <-p.stoppedCh: return nil @@ -278,22 +290,22 @@ func (p *Progress) newBar(msg string, total int64, default: } - lg.FromContext(p.ctx).Debug("New bar", "msg", msg, "total", total) + lg.FromContext(p.ctx).Debug("New bar", "msg", cfg.msg, "total", cfg.total) if p.pc == nil { p.pcInitFn() } - if total < 0 { - total = 0 + if cfg.total < 0 { + cfg.total = 0 } // We want the bar message to be a consistent width. switch { - case len(msg) < msgLength: - msg += strings.Repeat(" ", msgLength-len(msg)) - case len(msg) > msgLength: - msg = stringz.Ellipsify(msg, msgLength) + case len(cfg.msg) < msgLength: + cfg.msg += strings.Repeat(" ", msgLength-len(cfg.msg)) + case len(cfg.msg) > msgLength: + cfg.msg = stringz.Ellipsify(cfg.msg, msgLength) } b := &Bar{ @@ -317,18 +329,25 @@ func (p *Progress) newBar(msg string, total int64, default: } + for _, opt := range opts { + if opt != nil { + opt.apply(p, cfg) + } + } + // REVISIT: It shouldn't be the case that it's possible that the // progress has already been stopped. If it is stopped, the call // below will panic. Maybe consider wrapping the call in a recover? 
- b.bar = p.pc.New(total, - style, + b.bar = p.pc.New(cfg.total, + cfg.style, mpb.BarWidth(barWidth), mpb.PrependDecorators( - colorize(decor.Name(msg, decor.WCSyncWidthR), p.colors.Message), + colorize(decor.Name(cfg.msg, decor.WCSyncWidthR), p.colors.Message), ), - mpb.AppendDecorators(decorators...), + mpb.AppendDecorators(cfg.decorators...), mpb.BarRemoveOnComplete(), ) + b.bar.IncrBy(int(b.incrStash.Load())) b.incrStash = nil } diff --git a/libsq/core/progress/style.go b/libsq/core/progress/style.go index 5cd6fbf3b..8b042b1aa 100644 --- a/libsq/core/progress/style.go +++ b/libsq/core/progress/style.go @@ -1,6 +1,8 @@ package progress import ( + "fmt" + "runtime" "time" "github.com/fatih/color" @@ -21,25 +23,27 @@ const ( // DefaultColors returns the default colors used for the progress bars. func DefaultColors() *Colors { return &Colors{ - Error: color.New(color.FgRed, color.Bold), - Filler: color.New(color.FgGreen, color.Bold, color.Faint), - Message: color.New(color.Faint), - Percent: color.New(color.FgCyan, color.Faint), - Size: color.New(color.Faint), - Waiting: color.New(color.FgYellow, color.Faint), - Warning: color.New(color.FgYellow), + Error: color.New(color.FgRed, color.Bold), + Filler: color.New(color.FgGreen, color.Bold, color.Faint), + MemUsage: color.New(color.FgGreen, color.Faint), + Message: color.New(color.Faint), + Percent: color.New(color.FgCyan, color.Faint), + Size: color.New(color.Faint), + Waiting: color.New(color.FgYellow, color.Faint), + Warning: color.New(color.FgYellow), } } // Colors is the set of colors used for the progress bars. 
type Colors struct { - Error *color.Color - Filler *color.Color - Message *color.Color - Percent *color.Color - Size *color.Color - Waiting *color.Color - Warning *color.Color + Error *color.Color + Filler *color.Color + MemUsage *color.Color + Message *color.Color + Percent *color.Color + Size *color.Color + Waiting *color.Color + Warning *color.Color } // EnableColor enables or disables color for the progress bars. @@ -51,17 +55,18 @@ func (c *Colors) EnableColor(enable bool) { if enable { c.Error.EnableColor() c.Filler.EnableColor() + c.MemUsage.EnableColor() c.Message.EnableColor() c.Percent.EnableColor() c.Size.EnableColor() c.Waiting.EnableColor() c.Warning.EnableColor() - return } c.Error.DisableColor() c.Filler.DisableColor() + c.MemUsage.DisableColor() c.Message.DisableColor() c.Percent.DisableColor() c.Size.DisableColor() @@ -114,3 +119,21 @@ func newElapsedSeconds(c *color.Color, startTime time.Time, wcc ...decor.WC) dec } return decor.Any(fn, wcc...) } + +// OptMemUsage is an Opt that causes the bar to display program +// memory usage. 
+var OptMemUsage = optMemUsage{} + +var _ Opt = optMemUsage{} + +type optMemUsage struct{} + +func (optMemUsage) apply(p *Progress, cfg *barConfig) { + fn := func(s decor.Statistics) string { + stats := &runtime.MemStats{} + runtime.ReadMemStats(stats) + msg := fmt.Sprintf(" (% .1f)", decor.SizeB1024(stats.Sys)) + return p.colors.MemUsage.Sprint(msg) + } + cfg.decorators = append(cfg.decorators, decor.Any(fn, decor.WCSyncSpace)) +} From d10d038850d676140f949bbde4fd44031f196791 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 14 Jan 2024 11:14:42 -0700 Subject: [PATCH 188/195] Diff cleanup --- .golangci.yml | 2 +- cli/diff/data_naive.go | 31 +++++------- cli/diff/diff.go | 64 +++++++++++++++++++++++- cli/diff/diff_test.go | 3 +- cli/diff/internal/go-udiff/myers/diff.go | 7 +-- cli/diff/record.go | 19 +++---- cli/diff/source.go | 40 +++++---------- cli/diff/table.go | 27 +++++----- cli/output/format/opt.go | 2 +- libsq/core/options/opt.go | 4 +- libsq/core/progress/style.go | 2 +- libsq/driver/record.go | 2 +- 12 files changed, 118 insertions(+), 85 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 3d6fcc2c9..c569e3ca5 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -205,7 +205,7 @@ linters-settings: # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#argument-limit - name: argument-limit disabled: false - arguments: [ 6 ] + arguments: [ 7 ] # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#atomic - name: atomic disabled: false diff --git a/cli/diff/data_naive.go b/cli/diff/data_naive.go index 3c2f8ef56..324010145 100644 --- a/cli/diff/data_naive.go +++ b/cli/diff/data_naive.go @@ -7,11 +7,11 @@ import ( "slices" "time" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/samber/lo" "golang.org/x/sync/errgroup" - udiff "github.com/neilotoole/sq/cli/diff/internal/go-udiff" - "github.com/neilotoole/sq/cli/diff/internal/go-udiff/myers" "github.com/neilotoole/sq/cli/output" 
"github.com/neilotoole/sq/cli/run" "github.com/neilotoole/sq/libsq" @@ -34,6 +34,7 @@ import ( // raw record.Record values, and only generate the diff text if there // are differences, and even then, to only selectively generate the // needed text. +// See: https://github.com/neilotoole/sq/issues/353. func buildTableDataDiff(ctx context.Context, ru *run.Run, cfg *Config, td1, td2 *tableData, ) (*tableDataDiff, error) { @@ -49,6 +50,7 @@ func buildTableDataDiff(ctx context.Context, ru *run.Run, cfg *Config, w1, w2 := cfg.RecordWriterFn(buf1, pr), cfg.RecordWriterFn(buf2, pr) recw1, recw2 := output.NewRecordWriterAdapter(ctx, w1), output.NewRecordWriterAdapter(ctx, w2) + bar := progress.FromContext(ctx).NewWaiter("Retrieving diff data", true, progress.OptMemUsage) g, gCtx := errgroup.WithContext(ctx) g.Go(func() error { if err := libsq.ExecuteSLQ(gCtx, qc, query1, recw1); err != nil { @@ -73,25 +75,18 @@ func buildTableDataDiff(ctx context.Context, ru *run.Run, cfg *Config, _, err := recw2.Wait() return err }) - if err := g.Wait(); err != nil { + err := g.Wait() + bar.Stop() + if err != nil { return nil, err } - var ( - body1, body2 = buf1.String(), buf2.String() - err error - ) - - edits := myers.ComputeEdits(body1, body2) - unified, err := udiff.ToUnified( - query1, - query2, - body1, - edits, - cfg.Lines, - ) + body1, body2 := buf1.String(), buf2.String() + + msg := fmt.Sprintf("table {%s}", td1.tblName) + unified, err := computeUnified(ctx, msg, query1, query2, cfg.Lines, body1, body2) if err != nil { - return nil, errz.Err(err) + return nil, err } return &tableDataDiff{ @@ -182,7 +177,7 @@ func execSourceDataDiff(ctx context.Context, ru *run.Run, cfg *Config, sd1, sd2 } tblDataDiff = diffs[printIndex] - if err := Print(ru.Out, ru.Writers.Printing, tblDataDiff.header, tblDataDiff.diff); err != nil { + if err := Print(ctx, ru.Out, ru.Writers.Printing, tblDataDiff.header, tblDataDiff.diff); err != nil { printErrCh <- err return } diff --git 
a/cli/diff/diff.go b/cli/diff/diff.go index 1872757bf..291e95f34 100644 --- a/cli/diff/diff.go +++ b/cli/diff/diff.go @@ -7,10 +7,15 @@ package diff import ( + "context" "fmt" "io" "strings" + udiff "github.com/neilotoole/sq/cli/diff/internal/go-udiff" + "github.com/neilotoole/sq/cli/diff/internal/go-udiff/myers" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/neilotoole/sq/cli/output" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/stringz" @@ -117,7 +122,7 @@ type tableDataDiff struct { } // Print prints dif to w. If pr is nil, printing is in monochrome. -func Print(w io.Writer, pr *output.Printing, header, dif string) error { +func Print(ctx context.Context, w io.Writer, pr *output.Printing, header, dif string) error { if dif == "" { return nil } @@ -130,6 +135,9 @@ func Print(w io.Writer, pr *output.Printing, header, dif string) error { return errz.Err(err) } + bar := progress.FromContext(ctx). + NewUnitCounter("Preparing diff output", "line", progress.OptMemUsage) + after := stringz.VisitLines(dif, func(i int, line string) string { if i == 0 && strings.HasPrefix(line, "---") { return pr.DiffHeader.Sprint(line) @@ -150,6 +158,7 @@ func Print(w io.Writer, pr *output.Printing, header, dif string) error { return pr.DiffPlus.Sprint(line) } + bar.IncrBy(1) return pr.DiffNormal.Sprint(line) }) @@ -157,6 +166,59 @@ func Print(w io.Writer, pr *output.Printing, header, dif string) error { after = pr.DiffHeader.Sprint(header) + "\n" + after } + bar.Stop() _, err := fmt.Fprintln(w, after) return errz.Err(err) } + +// computeUnified encapsulates computing a unified diff. 
+func computeUnified(ctx context.Context, msg, oldLabel, newLabel string, lines int, + before, after string, +) (string, error) { + if msg == "" { + msg = "Diffing" + } else { + msg = fmt.Sprintf("Diffing (%s)", msg) + } + + bar := progress.FromContext(ctx).NewWaiter(msg, true, progress.OptMemUsage) + defer bar.Stop() + + var ( + unified string + err error + done = make(chan struct{}) + ) + + // We compute the diff on a goroutine because the underlying diff + // library functions aren't context-aware. + go func() { + defer close(done) + + edits := myers.ComputeEdits(before, after) + // After edits are computed, if the context is done, + // there's no point continuing. + select { + case <-ctx.Done(): + err = errz.Err(ctx.Err()) + return + default: + } + + unified, err = udiff.ToUnified( + oldLabel, + newLabel, + before, + edits, + lines, + ) + }() + + select { + case <-ctx.Done(): + return "", errz.Err(ctx.Err()) + case <-done: + } + + return unified, err +} diff --git a/cli/diff/diff_test.go b/cli/diff/diff_test.go index caeef90e6..c268260b2 100644 --- a/cli/diff/diff_test.go +++ b/cli/diff/diff_test.go @@ -2,6 +2,7 @@ package diff import ( "bytes" + "context" "fmt" "testing" @@ -37,7 +38,7 @@ func TestMyers(t *testing.T) { require.NoError(t, err) buf := &bytes.Buffer{} - err = Print(buf, output.NewPrinting(), "diff before after", result) + err = Print(context.Background(), buf, output.NewPrinting(), "diff before after", result) require.NoError(t, err) t.Logf("\n" + buf.String()) diff --git a/cli/diff/internal/go-udiff/myers/diff.go b/cli/diff/internal/go-udiff/myers/diff.go index 028e63a0e..5ec254461 100644 --- a/cli/diff/internal/go-udiff/myers/diff.go +++ b/cli/diff/internal/go-udiff/myers/diff.go @@ -11,10 +11,11 @@ import ( diff "github.com/neilotoole/sq/cli/diff/internal/go-udiff" ) -// Grips: -// https://blog.jcoglan.com/2017/02/17/the-myers-diff-algorithm-part-3/ -// 
https://www.codeproject.com/Articles/42279/%2FArticles%2F42279%2FInvestigating-Myers-diff-algorithm-Part-1-of-2 +// Reference: +// - https://blog.jcoglan.com/2017/02/17/the-myers-diff-algorithm-part-3/ +// - https://www.codeproject.com/Articles/42279/%2FArticles%2F42279%2FInvestigating-Myers-diff-algorithm-Part-1-of-2 +// ComputeEdits computes the diff edits. func ComputeEdits(before, after string) []diff.Edit { beforeLines := splitLines(before) ops := operations(beforeLines, splitLines(after)) diff --git a/cli/diff/record.go b/cli/diff/record.go index 1a88ee5c1..6737630ba 100644 --- a/cli/diff/record.go +++ b/cli/diff/record.go @@ -5,8 +5,6 @@ import ( "context" "fmt" - udiff "github.com/neilotoole/sq/cli/diff/internal/go-udiff" - "github.com/neilotoole/sq/cli/diff/internal/go-udiff/myers" "github.com/neilotoole/sq/cli/output" "github.com/neilotoole/sq/cli/output/yamlw" "github.com/neilotoole/sq/cli/run" @@ -41,6 +39,8 @@ type recordDiff struct { // possibly better path. The code is left here as a guilty reminder // to tackle this issue. 
// +// See:https://github.com/neilotoole/sq/issues/353 +// //nolint:unused func findRecordDiff(ctx context.Context, ru *run.Run, lines int, td1, td2 *tableData, @@ -185,20 +185,13 @@ func populateRecordDiff(ctx context.Context, lines int, pr *output.Printing, rec return err } - edits := myers.ComputeEdits(body1, body2) - recDiff.diff, err = udiff.ToUnified( - handleTbl1, - handleTbl2, - body1, - edits, - lines, - ) + msg := fmt.Sprintf("table {%s}", recDiff.td1.tblName) + recDiff.diff, err = computeUnified(ctx, msg, handleTbl1, handleTbl2, lines, body1, body2) if err != nil { - return errz.Err(err) + return err } - recDiff.header = fmt.Sprintf("sq diff %s %s | .[%d]", - handleTbl1, handleTbl2, recDiff.row) + recDiff.header = fmt.Sprintf("sq diff %s %s | .[%d]", handleTbl1, handleTbl2, recDiff.row) return nil } diff --git a/cli/diff/source.go b/cli/diff/source.go index f05a5d25b..6b52530bf 100644 --- a/cli/diff/source.go +++ b/cli/diff/source.go @@ -8,8 +8,6 @@ import ( "github.com/samber/lo" "golang.org/x/sync/errgroup" - udiff "github.com/neilotoole/sq/cli/diff/internal/go-udiff" - "github.com/neilotoole/sq/cli/diff/internal/go-udiff/myers" "github.com/neilotoole/sq/cli/run" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/source" @@ -44,22 +42,22 @@ func ExecSourceDiff(ctx context.Context, ru *run.Run, cfg *Config, } if elems.Overview { - srcDiff, err := buildSourceOverviewDiff(cfg, sd1, sd2) + srcDiff, err := buildSourceOverviewDiff(ctx, cfg, sd1, sd2) if err != nil { return err } - if err = Print(ru.Out, ru.Writers.Printing, srcDiff.header, srcDiff.diff); err != nil { + if err = Print(ctx, ru.Out, ru.Writers.Printing, srcDiff.header, srcDiff.diff); err != nil { return err } } if elems.DBProperties { - propsDiff, err := buildDBPropsDiff(cfg, sd1, sd2) + propsDiff, err := buildDBPropsDiff(ctx, cfg, sd1, sd2) if err != nil { return err } - if err = Print(ru.Out, ru.Writers.Printing, propsDiff.header, propsDiff.diff); 
err != nil { + if err = Print(ctx, ru.Out, ru.Writers.Printing, propsDiff.header, propsDiff.diff); err != nil { return err } } @@ -70,7 +68,7 @@ func ExecSourceDiff(ctx context.Context, ru *run.Run, cfg *Config, return err } for _, tblDiff := range tblDiffs { - if err := Print(ru.Out, ru.Writers.Printing, tblDiff.header, tblDiff.diff); err != nil { + if err = Print(ctx, ru.Out, ru.Writers.Printing, tblDiff.header, tblDiff.diff); err != nil { return err } } @@ -112,7 +110,7 @@ func buildSourceTableDiffs(ctx context.Context, cfg *Config, showRowCounts bool, srcMeta: sd2.srcMeta, } - dff, err := buildTableStructureDiff(cfg, showRowCounts, td1, td2) + dff, err := buildTableStructureDiff(ctx, cfg, showRowCounts, td1, td2) if err != nil { return nil, err } @@ -123,7 +121,7 @@ func buildSourceTableDiffs(ctx context.Context, cfg *Config, showRowCounts bool, return diffs, nil } -func buildSourceOverviewDiff(cfg *Config, sd1, sd2 *sourceData) (*sourceOverviewDiff, error) { +func buildSourceOverviewDiff(ctx context.Context, cfg *Config, sd1, sd2 *sourceData) (*sourceOverviewDiff, error) { var ( body1, body2 string err error @@ -136,16 +134,9 @@ func buildSourceOverviewDiff(cfg *Config, sd1, sd2 *sourceData) (*sourceOverview return nil, err } - edits := myers.ComputeEdits(body1, body2) - unified, err := udiff.ToUnified( - sd1.handle, - sd2.handle, - body1, - edits, - cfg.Lines, - ) + unified, err := computeUnified(ctx, "overview", sd1.handle, sd2.handle, cfg.Lines, body1, body2) if err != nil { - return nil, errz.Err(err) + return nil, err } diff := &sourceOverviewDiff{ @@ -158,7 +149,7 @@ func buildSourceOverviewDiff(cfg *Config, sd1, sd2 *sourceData) (*sourceOverview return diff, nil } -func buildDBPropsDiff(cfg *Config, sd1, sd2 *sourceData) (*dbPropsDiff, error) { +func buildDBPropsDiff(ctx context.Context, cfg *Config, sd1, sd2 *sourceData) (*dbPropsDiff, error) { var ( body1, body2 string err error @@ -171,16 +162,9 @@ func buildDBPropsDiff(cfg *Config, sd1, sd2 
*sourceData) (*dbPropsDiff, error) { return nil, err } - edits := myers.ComputeEdits(body1, body2) - unified, err := udiff.ToUnified( - sd1.handle, - sd2.handle, - body1, - edits, - cfg.Lines, - ) + unified, err := computeUnified(ctx, "dbprops", sd1.handle, sd2.handle, cfg.Lines, body1, body2) if err != nil { - return nil, errz.Err(err) + return nil, err } return &dbPropsDiff{ diff --git a/cli/diff/table.go b/cli/diff/table.go index 57aa2447c..50dbbfc55 100644 --- a/cli/diff/table.go +++ b/cli/diff/table.go @@ -6,8 +6,6 @@ import ( "golang.org/x/sync/errgroup" - udiff "github.com/neilotoole/sq/cli/diff/internal/go-udiff" - "github.com/neilotoole/sq/cli/diff/internal/go-udiff/myers" "github.com/neilotoole/sq/cli/run" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/driver" @@ -48,12 +46,12 @@ func ExecTableDiff(ctx context.Context, ru *run.Run, cfg *Config, elems *Element } var tblDiff *tableDiff - tblDiff, err = buildTableStructureDiff(cfg, elems.RowCount, td1, td2) + tblDiff, err = buildTableStructureDiff(ctx, cfg, elems.RowCount, td1, td2) if err != nil { return err } - if err = Print(ru.Out, ru.Writers.Printing, tblDiff.header, tblDiff.diff); err != nil { + if err = Print(ctx, ru.Out, ru.Writers.Printing, tblDiff.header, tblDiff.diff); err != nil { return err } } @@ -71,10 +69,12 @@ func ExecTableDiff(ctx context.Context, ru *run.Run, cfg *Config, elems *Element return nil } - return Print(ru.Out, ru.Writers.Printing, tblDataDiff.header, tblDataDiff.diff) + return Print(ctx, ru.Out, ru.Writers.Printing, tblDataDiff.header, tblDataDiff.diff) } -func buildTableStructureDiff(cfg *Config, showRowCounts bool, td1, td2 *tableData) (*tableDiff, error) { +func buildTableStructureDiff(ctx context.Context, cfg *Config, showRowCounts bool, + td1, td2 *tableData, +) (*tableDiff, error) { var ( body1, body2 string err error @@ -87,16 +87,13 @@ func buildTableStructureDiff(cfg *Config, showRowCounts bool, td1, td2 *tableDat return nil, 
err } - edits := myers.ComputeEdits(body1, body2) - unified, err := udiff.ToUnified( - td1.src.Handle+"."+td1.tblName, - td2.src.Handle+"."+td2.tblName, - body1, - edits, - cfg.Lines, - ) + handle1 := td1.src.Handle + "." + td1.tblName + handle2 := td2.src.Handle + "." + td2.tblName + + msg := fmt.Sprintf("table schema {%s}", td1.tblName) + unified, err := computeUnified(ctx, msg, handle1, handle2, cfg.Lines, body1, body2) if err != nil { - return nil, errz.Err(err) + return nil, err } tblDiff := &tableDiff{ diff --git a/cli/output/format/opt.go b/cli/output/format/opt.go index b7d6765de..9161744e2 100644 --- a/cli/output/format/opt.go +++ b/cli/output/format/opt.go @@ -9,7 +9,7 @@ var _ options.Opt = Opt{} // NewOpt returns a new format.Opt instance. If validFn is non-nil, it // is executed against possible values. -func NewOpt(key, flag string, short rune, defaultVal Format, //nolint:revive +func NewOpt(key, flag string, short rune, defaultVal Format, validFn func(Format) error, usage, help string, ) Opt { opt := options.NewBaseOpt(key, flag, short, usage, help, options.TagOutput) diff --git a/libsq/core/options/opt.go b/libsq/core/options/opt.go index 552ad67da..a7c3148ff 100644 --- a/libsq/core/options/opt.go +++ b/libsq/core/options/opt.go @@ -277,7 +277,7 @@ var _ Opt = Int{} // NewInt returns an options.Int instance. If flag is empty, the // value of key is used. -func NewInt(key, flag string, short rune, defaultVal int, usage, help string, tags ...string) Int { //nolint:revive +func NewInt(key, flag string, short rune, defaultVal int, usage, help string, tags ...string) Int { return Int{ BaseOpt: NewBaseOpt(key, flag, short, usage, help, tags...), defaultVal: defaultVal, @@ -521,7 +521,7 @@ var _ Opt = Duration{} // NewDuration returns an options.Duration instance. If flag is empty, the // value of key is used. 
-func NewDuration(key, flag string, short rune, defaultVal time.Duration, //nolint:revive +func NewDuration(key, flag string, short rune, defaultVal time.Duration, usage, help string, tags ...string, ) Duration { return Duration{ diff --git a/libsq/core/progress/style.go b/libsq/core/progress/style.go index 8b042b1aa..5b48061eb 100644 --- a/libsq/core/progress/style.go +++ b/libsq/core/progress/style.go @@ -132,7 +132,7 @@ func (optMemUsage) apply(p *Progress, cfg *barConfig) { fn := func(s decor.Statistics) string { stats := &runtime.MemStats{} runtime.ReadMemStats(stats) - msg := fmt.Sprintf(" (% .1f)", decor.SizeB1024(stats.Sys)) + msg := fmt.Sprintf(" (% .1f)", decor.SizeB1024(stats.Sys)) return p.colors.MemUsage.Sprint(msg) } cfg.decorators = append(cfg.decorators, decor.Any(fn, decor.WCSyncSpace)) diff --git a/libsq/driver/record.go b/libsq/driver/record.go index 08c511b5b..4af2e58cf 100644 --- a/libsq/driver/record.go +++ b/libsq/driver/record.go @@ -382,7 +382,7 @@ func (bi *BatchInsert) Munge(rec []any) error { // it must be a sql.Conn or sql.Tx. 
// //nolint:gocognit -func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, //nolint:revive +func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, destTbl string, destColNames []string, batchSize int, ) (*BatchInsert, error) { log := lg.FromContext(ctx) From e4935364c43a39786eb647d21472902367070b57 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 14 Jan 2024 12:54:29 -0700 Subject: [PATCH 189/195] Implemented sqlite metadata progress update (awkwardly) --- cli/diff/diff.go | 2 +- cli/output/tablew/recordwriter.go | 2 +- drivers/sqlite3/grip.go | 131 ++++++++++++++++++++++++++++++ drivers/sqlite3/metadata.go | 53 ++++++++---- drivers/sqlite3/metadata_test.go | 10 ++- drivers/sqlite3/pragma.go | 11 ++- drivers/sqlite3/sqlite3.go | 110 +------------------------ drivers/xlsx/ingest.go | 4 +- libsq/core/progress/bars.go | 2 +- libsq/core/progress/io.go | 4 +- libsq/core/progress/progress.go | 43 ++++++++-- libsq/driver/record.go | 4 +- 12 files changed, 229 insertions(+), 147 deletions(-) create mode 100644 drivers/sqlite3/grip.go diff --git a/cli/diff/diff.go b/cli/diff/diff.go index 291e95f34..92e26f585 100644 --- a/cli/diff/diff.go +++ b/cli/diff/diff.go @@ -158,7 +158,7 @@ func Print(ctx context.Context, w io.Writer, pr *output.Printing, header, dif st return pr.DiffPlus.Sprint(line) } - bar.IncrBy(1) + bar.Incr(1) return pr.DiffNormal.Sprint(line) }) diff --git a/cli/output/tablew/recordwriter.go b/cli/output/tablew/recordwriter.go index 415773075..d2f3c0736 100644 --- a/cli/output/tablew/recordwriter.go +++ b/cli/output/tablew/recordwriter.go @@ -101,7 +101,7 @@ func (w *recordWriter) WriteRecords(ctx context.Context, recs []record.Record) e tblRows = append(tblRows, tblRow) w.rowCount++ - w.bar.IncrBy(1) + w.bar.Incr(1) } return w.tbl.appendRows(ctx, tblRows) diff --git a/drivers/sqlite3/grip.go b/drivers/sqlite3/grip.go new file mode 100644 index 000000000..497904358 --- /dev/null +++ 
b/drivers/sqlite3/grip.go @@ -0,0 +1,131 @@ +package sqlite3 + +import ( + "context" + "database/sql" + "log/slog" + "os" + "sync" + + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/neilotoole/sq/libsq/core/sqlz" + "github.com/neilotoole/sq/libsq/driver" + "github.com/neilotoole/sq/libsq/source" + "github.com/neilotoole/sq/libsq/source/metadata" +) + +// grip implements driver.Grip. +type grip struct { + log *slog.Logger + db *sql.DB + src *source.Source + drvr *driveri + + // closeOnce and closeErr are used to ensure that Close is only called once. + // This is particularly relevant to sqlite, as calling Close multiple times + // can cause problems on Windows. + closeOnce sync.Once + closeErr error +} + +// DB implements driver.Grip. +func (g *grip) DB(context.Context) (*sql.DB, error) { + return g.db, nil +} + +// SQLDriver implements driver.Grip. +func (g *grip) SQLDriver() driver.SQLDriver { + return g.drvr +} + +// Source implements driver.Grip. +func (g *grip) Source() *source.Source { + return g.src +} + +// TableMetadata implements driver.Grip. +func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { + db, err := g.DB(ctx) + if err != nil { + return nil, err + } + + bar := progress.FromContext(ctx).NewUnitCounter(g.Source().Handle+"."+tblName+": reading schema", "row") + defer bar.Stop() + + return getTableMetadata(ctx, db, bar.Incr, tblName) +} + +// SourceMetadata implements driver.Grip. 
+func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { + // https://stackoverflow.com/questions/9646353/how-to-find-sqlite-database-file-version + + bar := progress.FromContext(ctx).NewUnitCounter(g.Source().Handle+": reading schema", "row") + defer bar.Stop() + + md := &metadata.Source{Handle: g.src.Handle, Driver: Type, DBDriver: dbDrvr} + + dsn, err := PathFromLocation(g.src) + if err != nil { + return nil, err + } + + const q = "SELECT sqlite_version(), (SELECT name FROM pragma_database_list ORDER BY seq limit 1);" + + err = g.db.QueryRowContext(ctx, q).Scan(&md.DBVersion, &md.Schema) + if err != nil { + return nil, errw(err) + } + bar.Incr(1) + + md.DBProduct = "SQLite3 v" + md.DBVersion + + fi, err := os.Stat(dsn) + if err != nil { + return nil, errw(err) + } + + md.Size = fi.Size() + md.Name = fi.Name() + md.FQName = fi.Name() + "." + md.Schema + // SQLite doesn't support catalog, but we conventionally set it to "default" + md.Catalog = "default" + md.Location = g.src.Location + + md.DBProperties, err = getDBProperties(ctx, g.db, bar.Incr) + if err != nil { + return nil, err + } + + if noSchema { + return md, nil + } + + md.Tables, err = getAllTableMetadata(ctx, g.db, bar.Incr, md.Schema) + if err != nil { + return nil, err + } + + for _, tbl := range md.Tables { + if tbl.TableType == sqlz.TableTypeTable { + md.TableCount++ + } else if tbl.TableType == sqlz.TableTypeView { + md.ViewCount++ + } + } + + return md, nil +} + +// Close implements driver.Grip. Subsequent calls to Close are no-op and +// return the same error. 
+func (g *grip) Close() error { + g.closeOnce.Do(func() { + g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) + g.closeErr = errw(g.db.Close()) + }) + + return g.closeErr +} diff --git a/drivers/sqlite3/metadata.go b/drivers/sqlite3/metadata.go index a518f91e9..ccb895cfb 100644 --- a/drivers/sqlite3/metadata.go +++ b/drivers/sqlite3/metadata.go @@ -9,6 +9,8 @@ import ( "reflect" "strings" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/neilotoole/sq/libsq/core/kind" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" @@ -20,6 +22,11 @@ import ( "github.com/neilotoole/sq/libsq/source/metadata" ) +// IncrFunc is a function that increments a counter. +// It is used to gather statistics, in particular to update +// progress bars. +type IncrFunc func(int) + // recordMetaFromColumnTypes returns record.Meta for colTypes. func recordMetaFromColumnTypes(ctx context.Context, colTypes []*sql.ColumnType, ) (record.Meta, error) { @@ -262,7 +269,7 @@ func DBTypeForKind(knd kind.Kind) string { } // getTableMetadata returns metadata for tblName in db. 
-func getTableMetadata(ctx context.Context, db sqlz.DB, tblName string) (*metadata.Table, error) { +func getTableMetadata(ctx context.Context, db sqlz.DB, incr IncrFunc, tblName string) (*metadata.Table, error) { log := lg.FromContext(ctx) tblMeta := &metadata.Table{Name: tblName} // Note that there's no easy way of getting the physical size of @@ -282,6 +289,7 @@ func getTableMetadata(ctx context.Context, db sqlz.DB, tblName string) (*metadat if err != nil { return nil, errw(err) } + incr(1) switch { case isVirtualTbl.Valid && isVirtualTbl.Bool: @@ -307,6 +315,9 @@ func getTableMetadata(ctx context.Context, db sqlz.DB, tblName string) (*metadat defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows) for rows.Next() { + incr(1) + progress.DebugDelay() + col := &metadata.Column{} var notnull int64 defaultValue := &sql.NullString{} @@ -324,6 +335,7 @@ func getTableMetadata(ctx context.Context, db sqlz.DB, tblName string) (*metadat if col.BaseType, err = getTypeOfColumn(ctx, db, tblMeta.Name, col.Name); err != nil { return nil, err } + incr(1) } col.PrimaryKey = pkValue.Int64 > 0 // pkVal can be 0,1,2 etc @@ -347,7 +359,8 @@ func getTableMetadata(ctx context.Context, db sqlz.DB, tblName string) (*metadat // non-system tables in db's schema. Arg schemaName is used to // set Table.FQName; it is not used to select which schema // to introspect. -func getAllTableMetadata(ctx context.Context, db sqlz.DB, schemaName string) ([]*metadata.Table, error) { +// The supplied incr func should be invoked for each row read from the DB. +func getAllTableMetadata(ctx context.Context, db sqlz.DB, incr IncrFunc, schemaName string) ([]*metadata.Table, error) { log := lg.FromContext(ctx) // This query returns a row for each column of each table, // order by table name then col id (ordinal). 
@@ -372,12 +385,14 @@ FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p ORDER BY m.name, p.cid ` - var tblMetas []*metadata.Table - var tblNames []string - var curTblName string - var curTblType string - var curTblIsVirtual bool - var curTblMeta *metadata.Table + var ( + tblMetas []*metadata.Table + tblNames []string + curTblName string + curTblType string + curTblIsVirtual bool + curTblMeta *metadata.Table + ) rows, err := db.QueryContext(ctx, query) if err != nil { @@ -386,6 +401,8 @@ ORDER BY m.name, p.cid defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows) for rows.Next() { + incr(1) + progress.DebugDelay() select { case <-ctx.Done(): return nil, ctx.Err() @@ -425,6 +442,7 @@ ORDER BY m.name, p.cid if col.BaseType, err = getTypeOfColumn(ctx, db, curTblName, col.Name); err != nil { return nil, err } + incr(1) } if curTblMeta == nil || curTblMeta.Name != curTblName { @@ -466,7 +484,7 @@ ORDER BY m.name, p.cid // Separately, we need to get the row counts for the tables var rowCounts []int64 - rowCounts, err = getTblRowCounts(ctx, db, tblNames) + rowCounts, err = getTblRowCounts(ctx, db, incr, tblNames) if err != nil { return nil, errw(err) } @@ -479,7 +497,7 @@ ORDER BY m.name, p.cid } // getTblRowCounts returns the number of rows in each table. -func getTblRowCounts(ctx context.Context, db sqlz.DB, tblNames []string) ([]int64, error) { +func getTblRowCounts(ctx context.Context, db sqlz.DB, incr IncrFunc, tblNames []string) ([]int64, error) { log := lg.FromContext(ctx) // See: https://stackoverflow.com/questions/7524612/how-to-count-rows-from-multiple-tables-in-sqlite @@ -505,12 +523,13 @@ func getTblRowCounts(ctx context.Context, db sqlz.DB, tblNames []string) ([]int6 // Thus if len(tblNames) > 500, we need to execute multiple queries. 
const maxCompoundSelect = 500 - tblCounts := make([]int64, len(tblNames)) - - var sb strings.Builder - var query string - var terms int - var j int + var ( + tblCounts = make([]int64, len(tblNames)) + sb strings.Builder + query string + terms int + j int + ) for i := 0; i < len(tblNames); i++ { if terms > 0 { @@ -537,6 +556,8 @@ func getTblRowCounts(ctx context.Context, db sqlz.DB, tblNames []string) ([]int6 return nil, errw(err) } j++ + incr(1) + progress.DebugDelay() } if err = rows.Err(); err != nil { diff --git a/drivers/sqlite3/metadata_test.go b/drivers/sqlite3/metadata_test.go index 209bd121b..9fc691079 100644 --- a/drivers/sqlite3/metadata_test.go +++ b/drivers/sqlite3/metadata_test.go @@ -304,7 +304,7 @@ func TestGetTblRowCounts(t *testing.T) { tblNames := createTypeTestTbls(th, src, numTables, true) - counts, err := sqlite3.GetTblRowCounts(th.Context, db, tblNames) + counts, err := sqlite3.GetTblRowCounts(th.Context, db, func(i int) {}, tblNames) require.NoError(t, err) require.Equal(t, len(tblNames), len(counts)) } @@ -318,7 +318,7 @@ func BenchmarkGetTblRowCounts(b *testing.B) { testCases := []struct { name string - fn func(ctx context.Context, db sqlz.DB, tblNames []string) ([]int64, error) + fn func(ctx context.Context, db sqlz.DB, incr sqlite3.IncrFunc, tblNames []string) ([]int64, error) }{ {name: "benchGetTblRowCountsBaseline", fn: benchGetTblRowCountsBaseline}, {name: "getTblRowCounts", fn: sqlite3.GetTblRowCounts}, @@ -329,7 +329,7 @@ func BenchmarkGetTblRowCounts(b *testing.B) { b.Run(tc.name, func(b *testing.B) { for n := 0; n < b.N; n++ { - counts, err := tc.fn(th.Context, db, tblNames) + counts, err := tc.fn(th.Context, db, func(int) {}, tblNames) require.NoError(b, err) require.Len(b, counts, len(tblNames)) } @@ -343,7 +343,9 @@ func BenchmarkGetTblRowCounts(b *testing.B) { // benchGetTblRowCountsBaseline is a baseline impl of getTblRowCounts // for benchmark comparison. 
-func benchGetTblRowCountsBaseline(ctx context.Context, db sqlz.DB, tblNames []string) ([]int64, error) { +func benchGetTblRowCountsBaseline(ctx context.Context, db sqlz.DB, + _ sqlite3.IncrFunc, tblNames []string, +) ([]int64, error) { tblCounts := make([]int64, len(tblNames)) for i := range tblNames { diff --git a/drivers/sqlite3/pragma.go b/drivers/sqlite3/pragma.go index 1007fe646..0302b75f8 100644 --- a/drivers/sqlite3/pragma.go +++ b/drivers/sqlite3/pragma.go @@ -6,6 +6,8 @@ import ( "fmt" "strings" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lgm" @@ -13,9 +15,11 @@ import ( ) // getDBProperties returns a map of the DB's settings, as exposed -// via SQLite's pragma mechanism. +// via SQLite's pragma mechanism. The supplied incr func should +// be invoked for each row read from the DB. +// // See: https://www.sqlite.org/pragma.html -func getDBProperties(ctx context.Context, db sqlz.DB) (map[string]any, error) { +func getDBProperties(ctx context.Context, db sqlz.DB, incr IncrFunc) (map[string]any, error) { pragmas, err := listPragmaNames(ctx, db) if err != nil { return nil, err @@ -29,6 +33,9 @@ func getDBProperties(ctx context.Context, db sqlz.DB) (map[string]any, error) { return nil, errz.Wrapf(errw(err), "read pragma: %s", pragma) } + incr(1) + progress.DebugDelay() + if val != nil { m[pragma] = val } diff --git a/drivers/sqlite3/sqlite3.go b/drivers/sqlite3/sqlite3.go index 17a70a2a4..a9456eed5 100644 --- a/drivers/sqlite3/sqlite3.go +++ b/drivers/sqlite3/sqlite3.go @@ -12,7 +12,6 @@ import ( "os" "path/filepath" "strings" - "sync" "time" "github.com/neilotoole/sq/libsq/core/ioz" @@ -113,7 +112,7 @@ func (d *driveri) ErrWrapFunc() func(error) error { // DBProperties implements driver.SQLDriver. 
func (d *driveri) DBProperties(ctx context.Context, db sqlz.DB) (map[string]any, error) { - return getDBProperties(ctx, db) + return getDBProperties(ctx, db, nil) } // DriverMetadata implements driver.Driver. @@ -913,113 +912,6 @@ func (d *driveri) getTableRecordMeta(ctx context.Context, db sqlz.DB, tblName st return destCols, nil } -// grip implements driver.Grip. -type grip struct { - log *slog.Logger - db *sql.DB - src *source.Source - drvr *driveri - - // closeOnce and closeErr are used to ensure that Close is only called once. - // This is particularly relevant to sqlite, as calling Close multiple times - // can cause problems on Windows. - closeOnce sync.Once - closeErr error -} - -// DB implements driver.Grip. -func (g *grip) DB(context.Context) (*sql.DB, error) { - return g.db, nil -} - -// SQLDriver implements driver.Grip. -func (g *grip) SQLDriver() driver.SQLDriver { - return g.drvr -} - -// Source implements driver.Grip. -func (g *grip) Source() *source.Source { - return g.src -} - -// TableMetadata implements driver.Grip. -func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { - db, err := g.DB(ctx) - if err != nil { - return nil, err - } - - return getTableMetadata(ctx, db, tblName) -} - -// SourceMetadata implements driver.Grip. 
-func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { - // https://stackoverflow.com/questions/9646353/how-to-find-sqlite-database-file-version - - md := &metadata.Source{Handle: g.src.Handle, Driver: Type, DBDriver: dbDrvr} - - dsn, err := PathFromLocation(g.src) - if err != nil { - return nil, err - } - - const q = "SELECT sqlite_version(), (SELECT name FROM pragma_database_list ORDER BY seq limit 1);" - - err = g.db.QueryRowContext(ctx, q).Scan(&md.DBVersion, &md.Schema) - if err != nil { - return nil, errw(err) - } - - md.DBProduct = "SQLite3 v" + md.DBVersion - - fi, err := os.Stat(dsn) - if err != nil { - return nil, errw(err) - } - - md.Size = fi.Size() - md.Name = fi.Name() - md.FQName = fi.Name() + "." + md.Schema - // SQLite doesn't support catalog, but we conventionally set it to "default" - md.Catalog = "default" - md.Location = g.src.Location - - md.DBProperties, err = getDBProperties(ctx, g.db) - if err != nil { - return nil, err - } - - if noSchema { - return md, nil - } - - md.Tables, err = getAllTableMetadata(ctx, g.db, md.Schema) - if err != nil { - return nil, err - } - - for _, tbl := range md.Tables { - if tbl.TableType == sqlz.TableTypeTable { - md.TableCount++ - } else if tbl.TableType == sqlz.TableTypeView { - md.ViewCount++ - } - } - - return md, nil -} - -// Close implements driver.Grip. Subsequent calls to Close are no-op and -// return the same error. -func (g *grip) Close() error { - g.closeOnce.Do(func() { - g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) - g.closeErr = errw(g.db.Close()) - }) - - return g.closeErr -} - var _ driver.ScratchSrcFunc = NewScratchSource // NewScratchSource returns a new scratch src. 
The supplied fpath diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index a7c3d474e..42c411959 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -138,7 +138,7 @@ func ingestXLSX(ctx context.Context, src *source.Source, destGrip driver.Grip, x if sheetTbls[i] == nil { // tblDef can be nil if its sheet is empty (has no data). skipped++ - bar.IncrBy(1) + bar.Incr(1) continue } @@ -146,7 +146,7 @@ func ingestXLSX(ctx context.Context, src *source.Source, destGrip driver.Grip, x return err } ingestCount++ - bar.IncrBy(1) + bar.Incr(1) } log.Debug("Sheets ingested", diff --git a/libsq/core/progress/bars.go b/libsq/core/progress/bars.go index 7ba2e0d39..c0b2da961 100644 --- a/libsq/core/progress/bars.go +++ b/libsq/core/progress/bars.go @@ -46,7 +46,7 @@ func (p *Progress) NewByteCounter(msg string, size int64, opts ...Opt) *Bar { // defer bar.Stop() // // for i := 0; i < 100; i++ { -// bar.IncrBy(1) +// bar.Incr(1) // time.Sleep(100 * time.Millisecond) // } // diff --git a/libsq/core/progress/io.go b/libsq/core/progress/io.go index 256625519..a5c14a9c1 100644 --- a/libsq/core/progress/io.go +++ b/libsq/core/progress/io.go @@ -76,7 +76,7 @@ func (w *progWriter) Write(p []byte) (n int, err error) { } n, err = w.w.Write(p) - w.b.IncrBy(n) + w.b.Incr(n) if err != nil { w.b.Stop() } @@ -172,7 +172,7 @@ func (r *progReader) Read(p []byte) (n int, err error) { } n, err = r.r.Read(p) - r.b.IncrBy(n) + r.b.Incr(n) if err != nil { r.b.Stop() } diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 2e66c1400..2d73226ea 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -2,7 +2,7 @@ // Use progress.New to create a new progress widget container. // That widget should be added to a context using progress.NewContext, // and retrieved via progress.FromContext. Invoke one of the Progress.NewX -// methods to create a new progress.Bar. 
Invoke Bar.IncrBy to increment +// methods to create a new progress.Bar. Invoke Bar.Incr to increment // the bar's progress, and invoke Bar.Stop to stop the bar. Be sure // to invoke Progress.Stop when the progress widget is no longer needed. // @@ -42,7 +42,7 @@ func DebugDelay() { } } -type ctxKey struct{} +type progCtxKey struct{} // NewContext returns ctx with p added as a value. func NewContext(ctx context.Context, p *Progress) context.Context { @@ -50,7 +50,7 @@ func NewContext(ctx context.Context, p *Progress) context.Context { ctx = context.Background() } - return context.WithValue(ctx, ctxKey{}, p) + return context.WithValue(ctx, progCtxKey{}, p) } // FromContext returns the [Progress] added to ctx via NewContext, @@ -61,7 +61,7 @@ func FromContext(ctx context.Context) *Progress { return nil } - val := ctx.Value(ctxKey{}) + val := ctx.Value(progCtxKey{}) if val == nil { return nil } @@ -73,6 +73,35 @@ func FromContext(ctx context.Context) *Progress { return nil } +type barCtxKey struct{} + +// NewBarContext returns ctx with bar added as a value. This +// context can be used in conjunction with progress.Incr to increment bar. +func NewBarContext(ctx context.Context, bar *Bar) context.Context { + if ctx == nil { + ctx = context.Background() + } + + return context.WithValue(ctx, barCtxKey{}, bar) +} + +// Incr increments the progress of the outermost bar in ctx by amount n. +// Use in conjunction with a context returned from NewBarContext. +func Incr(ctx context.Context, n int) { + if ctx == nil { + return + } + + val := ctx.Value(barCtxKey{}) + if val == nil { + return + } + + if b, ok := val.(*Bar); ok { + b.Incr(n) + } +} + // New returns a new Progress instance, which is a container for progress bars. // The returned Progress instance is safe for concurrent use, and all of its // public methods can be safely invoked on a nil Progress. 
The caller is @@ -359,7 +388,7 @@ func (p *Progress) newBar(cfg *barConfig, opts []Opt) *Bar { } // Bar represents a single progress bar. The caller should invoke -// [Bar.IncrBy] as necessary to increment the bar's progress. When +// [Bar.Incr] as necessary to increment the bar's progress. When // the bar is complete, the caller should invoke [Bar.Stop]. All // methods are safe to call on a nil Bar. type Bar struct { @@ -389,9 +418,9 @@ type Bar struct { incrStash *atomic.Int64 } -// IncrBy increments progress by amount of n. It is safe to +// Incr increments progress by amount n. It is safe to // call IncrBy on a nil Bar. -func (b *Bar) IncrBy(n int) { +func (b *Bar) Incr(n int) { if b == nil { return } diff --git a/libsq/driver/record.go b/libsq/driver/record.go index 4af2e58cf..5dd235caf 100644 --- a/libsq/driver/record.go +++ b/libsq/driver/record.go @@ -467,7 +467,7 @@ func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, } bi.written.Add(affected) - pbar.IncrBy(int(affected)) + pbar.Incr(int(affected)) progress.DebugDelay() if rec == nil { @@ -511,7 +511,7 @@ func NewBatchInsert(ctx context.Context, msg string, drvr SQLDriver, db sqlz.DB, } bi.written.Add(affected) - pbar.IncrBy(int(affected)) + pbar.Incr(int(affected)) progress.DebugDelay() // We're done From 90521e6ff1effc57ff2f5ca0dfdee6f866f5cb23 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 14 Jan 2024 13:24:26 -0700 Subject: [PATCH 190/195] postgres metadata progress bar --- drivers/postgres/grip.go | 76 ++++++++++++++++++++++++++++++++ drivers/postgres/metadata.go | 14 ++++++ drivers/postgres/postgres.go | 53 ---------------------- drivers/sqlite3/grip.go | 14 +++--- drivers/sqlite3/metadata.go | 25 +++++------ drivers/sqlite3/metadata_test.go | 9 ++-- drivers/sqlite3/pragma.go | 4 +- drivers/sqlite3/sqlite3.go | 2 +- libsq/core/progress/progress.go | 3 +- 9 files changed, 117 insertions(+), 83 deletions(-) create mode 100644 drivers/postgres/grip.go diff --git 
a/drivers/postgres/grip.go b/drivers/postgres/grip.go new file mode 100644 index 000000000..a2f50d2f2 --- /dev/null +++ b/drivers/postgres/grip.go @@ -0,0 +1,76 @@ +package postgres + +import ( + "context" + "database/sql" + "log/slog" + + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/neilotoole/sq/libsq/driver" + "github.com/neilotoole/sq/libsq/source" + "github.com/neilotoole/sq/libsq/source/metadata" +) + +// grip is the postgres implementation of driver.Grip. +type grip struct { + log *slog.Logger + drvr *driveri + db *sql.DB + src *source.Source +} + +// DB implements driver.Grip. +func (g *grip) DB(context.Context) (*sql.DB, error) { + return g.db, nil +} + +// SQLDriver implements driver.Grip. +func (g *grip) SQLDriver() driver.SQLDriver { + return g.drvr +} + +// Source implements driver.Grip. +func (g *grip) Source() *source.Source { + return g.src +} + +// TableMetadata implements driver.Grip. +func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { + db, err := g.DB(ctx) + if err != nil { + return nil, err + } + + bar := progress.FromContext(ctx).NewUnitCounter(g.Source().Handle+"."+tblName+": read schema", "item") + defer bar.Stop() + ctx = progress.NewBarContext(ctx, bar) + + return getTableMetadata(ctx, db, tblName) +} + +// SourceMetadata implements driver.Grip. +func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { + db, err := g.DB(ctx) + if err != nil { + return nil, err + } + + bar := progress.FromContext(ctx).NewUnitCounter(g.Source().Handle+": read schema", "item") + defer bar.Stop() + ctx = progress.NewBarContext(ctx, bar) + + return getSourceMetadata(ctx, g.src, db, noSchema) +} + +// Close implements driver.Grip. 
+func (g *grip) Close() error { + g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) + + err := g.db.Close() + if err != nil { + return errw(err) + } + return nil +} diff --git a/drivers/postgres/metadata.go b/drivers/postgres/metadata.go index e8cf964f8..d82b12f2b 100644 --- a/drivers/postgres/metadata.go +++ b/drivers/postgres/metadata.go @@ -9,6 +9,8 @@ import ( "strconv" "strings" + "github.com/neilotoole/sq/libsq/core/progress" + "golang.org/x/sync/errgroup" "github.com/neilotoole/sq/libsq/core/errz" @@ -199,6 +201,8 @@ current_setting('server_version'), version(), "current_user"()` if err != nil { return nil, errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() if !schema.Valid { return nil, errz.New("NULL value for current_schema(): check privileges and search_path") @@ -305,6 +309,8 @@ func getPgSettings(ctx context.Context, db sqlz.DB) (map[string]any, error) { if err = rows.Scan(&name, &setting, &typ); err != nil { return nil, errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() // Narrow the setting value bool, int, etc. 
val = setting @@ -362,6 +368,8 @@ ORDER BY table_name` return nil, errw(err) } tblNames = append(tblNames, s) + progress.Incr(ctx, 1) + progress.DebugDelay() } err = closeRows(rows) @@ -393,6 +401,8 @@ AND table_name = $1` if err != nil { return nil, errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() tblMeta := tblMetaFromPgTable(pgTbl) if tblMeta.Name != tblName { @@ -552,6 +562,8 @@ ORDER BY cols.table_catalog, cols.table_schema, cols.table_name, cols.ordinal_po return nil, err } + progress.Incr(ctx, 1) + progress.DebugDelay() cols = append(cols, col) } err = closeRows(rows) @@ -641,6 +653,8 @@ WHERE kcu.table_catalog = current_catalog AND kcu.table_schema = current_schema( return nil, errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() constraints = append(constraints, pgc) } err = closeRows(rows) diff --git a/drivers/postgres/postgres.go b/drivers/postgres/postgres.go index 05fe37bb2..916a5ffcd 100644 --- a/drivers/postgres/postgres.go +++ b/drivers/postgres/postgres.go @@ -743,59 +743,6 @@ func (d *driveri) RecordMeta(ctx context.Context, colTypes []*sql.ColumnType) ( return recMeta, mungeFn, nil } -// grip is the postgres implementation of driver.Grip. -type grip struct { - log *slog.Logger - drvr *driveri - db *sql.DB - src *source.Source -} - -// DB implements driver.Grip. -func (g *grip) DB(context.Context) (*sql.DB, error) { - return g.db, nil -} - -// SQLDriver implements driver.Grip. -func (g *grip) SQLDriver() driver.SQLDriver { - return g.drvr -} - -// Source implements driver.Grip. -func (g *grip) Source() *source.Source { - return g.src -} - -// TableMetadata implements driver.Grip. -func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { - db, err := g.DB(ctx) - if err != nil { - return nil, err - } - - return getTableMetadata(ctx, db, tblName) -} - -// SourceMetadata implements driver.Grip. 
-func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { - db, err := g.DB(ctx) - if err != nil { - return nil, err - } - return getSourceMetadata(ctx, g.src, db, noSchema) -} - -// Close implements driver.Grip. -func (g *grip) Close() error { - g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) - - err := g.db.Close() - if err != nil { - return errw(err) - } - return nil -} - // doRetry executes fn with retry on isErrTooManyConnections. func doRetry(ctx context.Context, fn func() error) error { maxRetryInterval := driver.OptMaxRetryInterval.Get(options.FromContext(ctx)) diff --git a/drivers/sqlite3/grip.go b/drivers/sqlite3/grip.go index 497904358..8a8d5d045 100644 --- a/drivers/sqlite3/grip.go +++ b/drivers/sqlite3/grip.go @@ -52,18 +52,20 @@ func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Tab return nil, err } - bar := progress.FromContext(ctx).NewUnitCounter(g.Source().Handle+"."+tblName+": reading schema", "row") + bar := progress.FromContext(ctx).NewUnitCounter(g.Source().Handle+"."+tblName+": read schema", "item") defer bar.Stop() + ctx = progress.NewBarContext(ctx, bar) - return getTableMetadata(ctx, db, bar.Incr, tblName) + return getTableMetadata(ctx, db, tblName) } // SourceMetadata implements driver.Grip. 
func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { // https://stackoverflow.com/questions/9646353/how-to-find-sqlite-database-file-version - bar := progress.FromContext(ctx).NewUnitCounter(g.Source().Handle+": reading schema", "row") + bar := progress.FromContext(ctx).NewUnitCounter(g.Source().Handle+": read schema", "item") defer bar.Stop() + ctx = progress.NewBarContext(ctx, bar) md := &metadata.Source{Handle: g.src.Handle, Driver: Type, DBDriver: dbDrvr} @@ -78,7 +80,7 @@ func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou if err != nil { return nil, errw(err) } - bar.Incr(1) + progress.Incr(ctx, 1) md.DBProduct = "SQLite3 v" + md.DBVersion @@ -94,7 +96,7 @@ func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou md.Catalog = "default" md.Location = g.src.Location - md.DBProperties, err = getDBProperties(ctx, g.db, bar.Incr) + md.DBProperties, err = getDBProperties(ctx, g.db) if err != nil { return nil, err } @@ -103,7 +105,7 @@ func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Sou return md, nil } - md.Tables, err = getAllTableMetadata(ctx, g.db, bar.Incr, md.Schema) + md.Tables, err = getAllTableMetadata(ctx, g.db, md.Schema) if err != nil { return nil, err } diff --git a/drivers/sqlite3/metadata.go b/drivers/sqlite3/metadata.go index ccb895cfb..601dcf9da 100644 --- a/drivers/sqlite3/metadata.go +++ b/drivers/sqlite3/metadata.go @@ -22,11 +22,6 @@ import ( "github.com/neilotoole/sq/libsq/source/metadata" ) -// IncrFunc is a function that increments a counter. -// It is used to gather statistics, in particular to update -// progress bars. -type IncrFunc func(int) - // recordMetaFromColumnTypes returns record.Meta for colTypes. 
func recordMetaFromColumnTypes(ctx context.Context, colTypes []*sql.ColumnType, ) (record.Meta, error) { @@ -269,7 +264,7 @@ func DBTypeForKind(knd kind.Kind) string { } // getTableMetadata returns metadata for tblName in db. -func getTableMetadata(ctx context.Context, db sqlz.DB, incr IncrFunc, tblName string) (*metadata.Table, error) { +func getTableMetadata(ctx context.Context, db sqlz.DB, tblName string) (*metadata.Table, error) { log := lg.FromContext(ctx) tblMeta := &metadata.Table{Name: tblName} // Note that there's no easy way of getting the physical size of @@ -289,7 +284,7 @@ func getTableMetadata(ctx context.Context, db sqlz.DB, incr IncrFunc, tblName st if err != nil { return nil, errw(err) } - incr(1) + progress.Incr(ctx, 1) switch { case isVirtualTbl.Valid && isVirtualTbl.Bool: @@ -315,7 +310,7 @@ func getTableMetadata(ctx context.Context, db sqlz.DB, incr IncrFunc, tblName st defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows) for rows.Next() { - incr(1) + progress.Incr(ctx, 1) progress.DebugDelay() col := &metadata.Column{} @@ -335,7 +330,7 @@ func getTableMetadata(ctx context.Context, db sqlz.DB, incr IncrFunc, tblName st if col.BaseType, err = getTypeOfColumn(ctx, db, tblMeta.Name, col.Name); err != nil { return nil, err } - incr(1) + progress.Incr(ctx, 1) } col.PrimaryKey = pkValue.Int64 > 0 // pkVal can be 0,1,2 etc @@ -360,7 +355,7 @@ func getTableMetadata(ctx context.Context, db sqlz.DB, incr IncrFunc, tblName st // set Table.FQName; it is not used to select which schema // to introspect. // The supplied incr func should be invoked for each row read from the DB. -func getAllTableMetadata(ctx context.Context, db sqlz.DB, incr IncrFunc, schemaName string) ([]*metadata.Table, error) { +func getAllTableMetadata(ctx context.Context, db sqlz.DB, schemaName string) ([]*metadata.Table, error) { log := lg.FromContext(ctx) // This query returns a row for each column of each table, // order by table name then col id (ordinal). 
@@ -401,7 +396,7 @@ ORDER BY m.name, p.cid defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows) for rows.Next() { - incr(1) + progress.Incr(ctx, 1) progress.DebugDelay() select { case <-ctx.Done(): @@ -442,7 +437,7 @@ ORDER BY m.name, p.cid if col.BaseType, err = getTypeOfColumn(ctx, db, curTblName, col.Name); err != nil { return nil, err } - incr(1) + progress.Incr(ctx, 1) } if curTblMeta == nil || curTblMeta.Name != curTblName { @@ -484,7 +479,7 @@ ORDER BY m.name, p.cid // Separately, we need to get the row counts for the tables var rowCounts []int64 - rowCounts, err = getTblRowCounts(ctx, db, incr, tblNames) + rowCounts, err = getTblRowCounts(ctx, db, tblNames) if err != nil { return nil, errw(err) } @@ -497,7 +492,7 @@ ORDER BY m.name, p.cid } // getTblRowCounts returns the number of rows in each table. -func getTblRowCounts(ctx context.Context, db sqlz.DB, incr IncrFunc, tblNames []string) ([]int64, error) { +func getTblRowCounts(ctx context.Context, db sqlz.DB, tblNames []string) ([]int64, error) { log := lg.FromContext(ctx) // See: https://stackoverflow.com/questions/7524612/how-to-count-rows-from-multiple-tables-in-sqlite @@ -556,7 +551,7 @@ func getTblRowCounts(ctx context.Context, db sqlz.DB, incr IncrFunc, tblNames [] return nil, errw(err) } j++ - incr(1) + progress.Incr(ctx, 1) progress.DebugDelay() } diff --git a/drivers/sqlite3/metadata_test.go b/drivers/sqlite3/metadata_test.go index 9fc691079..d41cb3591 100644 --- a/drivers/sqlite3/metadata_test.go +++ b/drivers/sqlite3/metadata_test.go @@ -304,7 +304,7 @@ func TestGetTblRowCounts(t *testing.T) { tblNames := createTypeTestTbls(th, src, numTables, true) - counts, err := sqlite3.GetTblRowCounts(th.Context, db, func(i int) {}, tblNames) + counts, err := sqlite3.GetTblRowCounts(th.Context, db, tblNames) require.NoError(t, err) require.Equal(t, len(tblNames), len(counts)) } @@ -318,7 +318,7 @@ func BenchmarkGetTblRowCounts(b *testing.B) { testCases := []struct { name string - fn func(ctx 
context.Context, db sqlz.DB, incr sqlite3.IncrFunc, tblNames []string) ([]int64, error) + fn func(ctx context.Context, db sqlz.DB, tblNames []string) ([]int64, error) }{ {name: "benchGetTblRowCountsBaseline", fn: benchGetTblRowCountsBaseline}, {name: "getTblRowCounts", fn: sqlite3.GetTblRowCounts}, @@ -329,7 +329,7 @@ func BenchmarkGetTblRowCounts(b *testing.B) { b.Run(tc.name, func(b *testing.B) { for n := 0; n < b.N; n++ { - counts, err := tc.fn(th.Context, db, func(int) {}, tblNames) + counts, err := tc.fn(th.Context, db, tblNames) require.NoError(b, err) require.Len(b, counts, len(tblNames)) } @@ -343,8 +343,7 @@ func BenchmarkGetTblRowCounts(b *testing.B) { // benchGetTblRowCountsBaseline is a baseline impl of getTblRowCounts // for benchmark comparison. -func benchGetTblRowCountsBaseline(ctx context.Context, db sqlz.DB, - _ sqlite3.IncrFunc, tblNames []string, +func benchGetTblRowCountsBaseline(ctx context.Context, db sqlz.DB, tblNames []string, ) ([]int64, error) { tblCounts := make([]int64, len(tblNames)) diff --git a/drivers/sqlite3/pragma.go b/drivers/sqlite3/pragma.go index 0302b75f8..63b99e4c6 100644 --- a/drivers/sqlite3/pragma.go +++ b/drivers/sqlite3/pragma.go @@ -19,7 +19,7 @@ import ( // be invoked for each row read from the DB. 
// // See: https://www.sqlite.org/pragma.html -func getDBProperties(ctx context.Context, db sqlz.DB, incr IncrFunc) (map[string]any, error) { +func getDBProperties(ctx context.Context, db sqlz.DB) (map[string]any, error) { pragmas, err := listPragmaNames(ctx, db) if err != nil { return nil, err @@ -33,7 +33,7 @@ func getDBProperties(ctx context.Context, db sqlz.DB, incr IncrFunc) (map[string return nil, errz.Wrapf(errw(err), "read pragma: %s", pragma) } - incr(1) + progress.Incr(ctx, 1) progress.DebugDelay() if val != nil { diff --git a/drivers/sqlite3/sqlite3.go b/drivers/sqlite3/sqlite3.go index a9456eed5..eb30fccc9 100644 --- a/drivers/sqlite3/sqlite3.go +++ b/drivers/sqlite3/sqlite3.go @@ -112,7 +112,7 @@ func (d *driveri) ErrWrapFunc() func(error) error { // DBProperties implements driver.SQLDriver. func (d *driveri) DBProperties(ctx context.Context, db sqlz.DB) (map[string]any, error) { - return getDBProperties(ctx, db, nil) + return getDBProperties(ctx, db) } // DriverMetadata implements driver.Driver. diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 2d73226ea..3610c7eb7 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -76,7 +76,8 @@ func FromContext(ctx context.Context) *Progress { type barCtxKey struct{} // NewBarContext returns ctx with bar added as a value. This -// context can be used in conjunction with progress.Incr to increment bar. +// context can be used in conjunction with progress.Incr to increment +// the progress bar. 
func NewBarContext(ctx context.Context, bar *Bar) context.Context { if ctx == nil { ctx = context.Background() From 2280437ed7f0bf128f6879b9a83f52c2d292643b Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 14 Jan 2024 15:38:45 -0700 Subject: [PATCH 191/195] testh .gitignore --- testh/.gitignore | 1 + 1 file changed, 1 insertion(+) create mode 100644 testh/.gitignore diff --git a/testh/.gitignore b/testh/.gitignore new file mode 100644 index 000000000..4f1dd02f2 --- /dev/null +++ b/testh/.gitignore @@ -0,0 +1 @@ +./progress-remove.test.sh From 03f54b05687aeea6fc13ff9bdf74503724e0f50a Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 14 Jan 2024 15:43:53 -0700 Subject: [PATCH 192/195] Grip.SourceMetadata and TableMetadata now have progress bars --- .gitignore | 1 + .golangci.yml | 2 +- cli/output.go | 1 + drivers/mysql/grip.go | 61 ++++++++++++++++++++++++ drivers/mysql/metadata.go | 16 +++++++ drivers/mysql/mysql.go | 39 --------------- drivers/postgres/grip.go | 14 +----- drivers/sqlserver/grip.go | 84 +++++++++++++++++++++++++++++++++ drivers/sqlserver/metadata.go | 20 ++++++-- drivers/sqlserver/properties.go | 6 +++ drivers/sqlserver/sqlserver.go | 54 --------------------- libsq/core/ioz/ioz.go | 2 +- libsq/core/progress/progress.go | 84 ++++++++++++++++++++++----------- testh/.gitignore | 1 - 14 files changed, 247 insertions(+), 138 deletions(-) create mode 100644 drivers/mysql/grip.go create mode 100644 drivers/sqlserver/grip.go delete mode 100644 testh/.gitignore diff --git a/.gitignore b/.gitignore index ef74c33fd..8e95d3946 100644 --- a/.gitignore +++ b/.gitignore @@ -55,3 +55,4 @@ goreleaser-test.sh /cli/test.db /*.db /.CHANGELOG.delta.md +/testh/progress-remove.test.sh diff --git a/.golangci.yml b/.golangci.yml index c569e3ca5..22c6d24bb 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -283,7 +283,7 @@ linters-settings: disabled: false # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#empty-lines - name: empty-lines - 
disabled: false + disabled: true # Covered by "whitespace" linter. # https://github.com/mgechev/revive/blob/master/RULES_DESCRIPTIONS.md#enforce-map-style - name: enforce-map-style disabled: true diff --git a/cli/output.go b/cli/output.go index 129e1ba76..b69928fb9 100644 --- a/cli/output.go +++ b/cli/output.go @@ -470,6 +470,7 @@ func getPrinting(cmd *cobra.Command, clnup *cleanup.Cleanup, opts options.Option out2 = ioz.NotifyOnceWriter(out2, func() { lg.FromContext(ctx).Debug("Output stream is being written to; removing progress widget") pb.Stop() + lg.FromContext(ctx).Debug("Progress widget should be removed now") }) cmd.SetContext(progress.NewContext(ctx, pb)) diff --git a/drivers/mysql/grip.go b/drivers/mysql/grip.go new file mode 100644 index 000000000..4e13c75f8 --- /dev/null +++ b/drivers/mysql/grip.go @@ -0,0 +1,61 @@ +package mysql + +import ( + "context" + "database/sql" + "log/slog" + + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/neilotoole/sq/libsq/driver" + "github.com/neilotoole/sq/libsq/source" + "github.com/neilotoole/sq/libsq/source/metadata" +) + +// grip implements driver.Grip. +type grip struct { + log *slog.Logger + db *sql.DB + src *source.Source + drvr *driveri +} + +// DB implements driver.Grip. +func (g *grip) DB(context.Context) (*sql.DB, error) { + return g.db, nil +} + +// SQLDriver implements driver.Grip. +func (g *grip) SQLDriver() driver.SQLDriver { + return g.drvr +} + +// Source implements driver.Grip. +func (g *grip) Source() *source.Source { + return g.src +} + +// TableMetadata implements driver.Grip. 
+func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { + bar := progress.FromContext(ctx).NewUnitCounter(g.Source().Handle+"."+tblName+": read schema", "item") + defer bar.Stop() + ctx = progress.NewBarContext(ctx, bar) + + return getTableMetadata(ctx, g.db, tblName) +} + +// SourceMetadata implements driver.Grip. +func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { + bar := progress.FromContext(ctx).NewUnitCounter(g.Source().Handle+": read schema", "item") + defer bar.Stop() + ctx = progress.NewBarContext(ctx, bar) + + return getSourceMetadata(ctx, g.src, g.db, noSchema) +} + +// Close implements driver.Grip. +func (g *grip) Close() error { + g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) + return errw(g.db.Close()) +} diff --git a/drivers/mysql/metadata.go b/drivers/mysql/metadata.go index 6379765a8..7b726712d 100644 --- a/drivers/mysql/metadata.go +++ b/drivers/mysql/metadata.go @@ -10,6 +10,8 @@ import ( "strings" "time" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/go-sql-driver/mysql" "github.com/samber/lo" "golang.org/x/sync/errgroup" @@ -186,6 +188,8 @@ WHERE TABLE_SCHEMA = DATABASE() AND TABLE_NAME = ?` if err != nil { return nil, errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() tblMeta.TableType = canonicalTableType(tblMeta.DBTableType) tblMeta.FQName = schema + "." 
+ tblMeta.Name @@ -230,6 +234,8 @@ ORDER BY cols.ordinal_position ASC` if err != nil { return nil, errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() if strings.EqualFold("YES", isNullable) { col.Nullable = true @@ -343,6 +349,8 @@ func setSourceSummaryMeta(ctx context.Context, db sqlz.DB, md *metadata.Source) if err != nil { return errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() md.Name = schema md.Schema = schema @@ -375,6 +383,9 @@ func getDBProperties(ctx context.Context, db sqlz.DB) (map[string]any, error) { return nil, errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() + // Narrow setting to bool or int if possible. var ( v any = val @@ -454,6 +465,9 @@ ORDER BY c.TABLE_NAME ASC, c.ORDINAL_POSITION ASC` return nil, errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() + if !curTblName.Valid || !colName.Valid { // table may have been dropped during metadata collection log.Debug("Table not found during metadata collection") @@ -618,6 +632,8 @@ func getTableRowCounts(ctx context.Context, db sqlz.DB, tblNames []string) (map[ if err = rows.Scan(&tblName, &count); err != nil { return nil, errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() m[tblName] = count } diff --git a/drivers/mysql/mysql.go b/drivers/mysql/mysql.go index 3e4e69e5c..6d772a5b0 100644 --- a/drivers/mysql/mysql.go +++ b/drivers/mysql/mysql.go @@ -573,45 +573,6 @@ func (d *driveri) Truncate(ctx context.Context, src *source.Source, tbl string, return beforeCount, errw(tx.Commit()) } -// grip implements driver.Grip. -type grip struct { - log *slog.Logger - db *sql.DB - src *source.Source - drvr *driveri -} - -// DB implements driver.Grip. -func (g *grip) DB(context.Context) (*sql.DB, error) { - return g.db, nil -} - -// SQLDriver implements driver.Grip. -func (g *grip) SQLDriver() driver.SQLDriver { - return g.drvr -} - -// Source implements driver.Grip. -func (g *grip) Source() *source.Source { - return g.src -} - -// TableMetadata implements driver.Grip. 
-func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { - return getTableMetadata(ctx, g.db, tblName) -} - -// SourceMetadata implements driver.Grip. -func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { - return getSourceMetadata(ctx, g.src, g.db, noSchema) -} - -// Close implements driver.Grip. -func (g *grip) Close() error { - g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) - return errw(g.db.Close()) -} - // dsnFromLocation builds the mysql driver DSN from src.Location. // If parseTime is true, the param "parseTime=true" is added. This // is because of: https://stackoverflow.com/questions/29341590/how-to-parse-time-from-database/29343013#29343013 diff --git a/drivers/postgres/grip.go b/drivers/postgres/grip.go index a2f50d2f2..d7f9282fa 100644 --- a/drivers/postgres/grip.go +++ b/drivers/postgres/grip.go @@ -38,30 +38,20 @@ func (g *grip) Source() *source.Source { // TableMetadata implements driver.Grip. func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { - db, err := g.DB(ctx) - if err != nil { - return nil, err - } - bar := progress.FromContext(ctx).NewUnitCounter(g.Source().Handle+"."+tblName+": read schema", "item") defer bar.Stop() ctx = progress.NewBarContext(ctx, bar) - return getTableMetadata(ctx, db, tblName) + return getTableMetadata(ctx, g.db, tblName) } // SourceMetadata implements driver.Grip. func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { - db, err := g.DB(ctx) - if err != nil { - return nil, err - } - bar := progress.FromContext(ctx).NewUnitCounter(g.Source().Handle+": read schema", "item") defer bar.Stop() ctx = progress.NewBarContext(ctx, bar) - return getSourceMetadata(ctx, g.src, db, noSchema) + return getSourceMetadata(ctx, g.src, g.db, noSchema) } // Close implements driver.Grip. 
diff --git a/drivers/sqlserver/grip.go b/drivers/sqlserver/grip.go new file mode 100644 index 000000000..159034a1d --- /dev/null +++ b/drivers/sqlserver/grip.go @@ -0,0 +1,84 @@ +package sqlserver + +import ( + "context" + "database/sql" + "log/slog" + + "github.com/neilotoole/sq/libsq/core/lg" + + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/neilotoole/sq/libsq/driver" + "github.com/neilotoole/sq/libsq/source" + "github.com/neilotoole/sq/libsq/source/metadata" +) + +// grip implements driver.Grip. +type grip struct { + log *slog.Logger + drvr *driveri + db *sql.DB + src *source.Source +} + +var _ driver.Grip = (*grip)(nil) + +// DB implements driver.Grip. +func (g *grip) DB(context.Context) (*sql.DB, error) { + return g.db, nil +} + +// SQLDriver implements driver.Grip. +func (g *grip) SQLDriver() driver.SQLDriver { + return g.drvr +} + +// Source implements driver.Grip. +func (g *grip) Source() *source.Source { + return g.src +} + +// TableMetadata implements driver.Grip. +func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { + bar := progress.FromContext(ctx).NewUnitCounter(g.Source().Handle+"."+tblName+": read schema", "item") + defer func() { + lg.FromContext(ctx).Warn("Before bar stop") + bar.Stop() + lg.FromContext(ctx).Warn("After bar stop") + }() + ctx = progress.NewBarContext(ctx, bar) + + const query = `SELECT TABLE_CATALOG, TABLE_SCHEMA, TABLE_TYPE +FROM INFORMATION_SCHEMA.TABLES +WHERE TABLE_NAME = @p1` + + var catalog, schema, tblType string + err := g.db.QueryRowContext(ctx, query, tblName).Scan(&catalog, &schema, &tblType) + if err != nil { + return nil, errw(err) + } + progress.Incr(ctx, 1) + progress.DebugDelay() + + // TODO: getTableMetadata can cause deadlock in the DB. Needs further investigation. + // But a quick hack would be to use retry on a deadlock error. 
+ return getTableMetadata(ctx, g.db, catalog, schema, tblName, tblType) +} + +// SourceMetadata implements driver.Grip. +func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { + bar := progress.FromContext(ctx).NewUnitCounter(g.Source().Handle+": read schema", "item") + defer bar.Stop() + ctx = progress.NewBarContext(ctx, bar) + + return getSourceMetadata(ctx, g.src, g.db, noSchema) +} + +// Close implements driver.Grip. +func (g *grip) Close() error { + g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) + + return errw(g.db.Close()) +} diff --git a/drivers/sqlserver/metadata.go b/drivers/sqlserver/metadata.go index 8d81415a9..a11ec97de 100644 --- a/drivers/sqlserver/metadata.go +++ b/drivers/sqlserver/metadata.go @@ -8,6 +8,8 @@ import ( "strconv" "strings" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/c2h5oh/datasize" "golang.org/x/sync/errgroup" @@ -131,6 +133,8 @@ GROUP BY database_id) AS total_size_bytes` if err != nil { return nil, errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() md.Name = catalog md.FQName = catalog + "." 
+ schema @@ -238,6 +242,8 @@ func getTableMetadata(ctx context.Context, db sqlz.DB, tblCatalog, if err != nil { return nil, errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() if rowCount.Valid { tblMeta.RowCount, err = strconv.ParseInt(strings.TrimSpace(rowCount.String), 10, 64) @@ -251,6 +257,8 @@ func getTableMetadata(ctx context.Context, db sqlz.DB, tblCatalog, if err != nil { return nil, errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() } if reserved.Valid { @@ -340,6 +348,8 @@ ORDER BY TABLE_NAME ASC, TABLE_TYPE ASC` if err != nil { return nil, nil, errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() tblNames = append(tblNames, tblName) tblTypes = append(tblTypes, tblType) @@ -387,7 +397,8 @@ func getColumnMeta(ctx context.Context, db sqlz.DB, tblCatalog, tblSchema, tblNa if err != nil { return nil, errw(err) } - + progress.Incr(ctx, 1) + progress.DebugDelay() cols = append(cols, c) } @@ -416,11 +427,12 @@ func getConstraints(ctx context.Context, db sqlz.DB, tblCatalog, tblSchema, tblN if err != nil { return nil, errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() defer lg.WarnIfCloseError(log, lgm.CloseDBRows, rows) var constraints []constraintMeta - for rows.Next() { c := constraintMeta{} err = rows.Scan(&c.TableCatalog, &c.TableSchema, &c.TableName, &c.ConstraintType, &c.ColumnName, @@ -428,6 +440,8 @@ func getConstraints(ctx context.Context, db sqlz.DB, tblCatalog, tblSchema, tblN if err != nil { return nil, errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() constraints = append(constraints, c) } @@ -441,7 +455,7 @@ func getConstraints(ctx context.Context, db sqlz.DB, tblCatalog, tblSchema, tblN // constraintMeta models constraint metadata from information schema. type constraintMeta struct { - TableCatalog string `db:"TABLE_CATALOG"` + TableCatalog string `db:"TABLE_CATALOG"` // REVISIT: why do we have the `db` tag here? 
TableSchema string `db:"TABLE_SCHEMA"` TableName string `db:"TABLE_NAME"` ConstraintType string `db:"CONSTRAINT_TYPE"` diff --git a/drivers/sqlserver/properties.go b/drivers/sqlserver/properties.go index 82cf1c8e0..563ce957a 100644 --- a/drivers/sqlserver/properties.go +++ b/drivers/sqlserver/properties.go @@ -3,6 +3,8 @@ package sqlserver import ( "context" + "github.com/neilotoole/sq/libsq/core/progress" + "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/sqlz" @@ -47,6 +49,8 @@ func getSysConfigurations(ctx context.Context, db sqlz.DB) (map[string]any, erro if err = rows.Scan(&name, &val); err != nil { return nil, errw(err) } + progress.Incr(ctx, 1) + progress.DebugDelay() m[name] = val } @@ -83,6 +87,8 @@ func getServerProperties(ctx context.Context, db sqlz.DB) (map[string]any, error if val == nil { continue } + progress.Incr(ctx, 1) + progress.DebugDelay() m[name] = val } diff --git a/drivers/sqlserver/sqlserver.go b/drivers/sqlserver/sqlserver.go index d4ac77673..3fdd0d12b 100644 --- a/drivers/sqlserver/sqlserver.go +++ b/drivers/sqlserver/sqlserver.go @@ -672,60 +672,6 @@ func (d *driveri) getTableColsMeta(ctx context.Context, db sqlz.DB, tblName stri return destCols, nil } -// grip implements driver.Grip. -type grip struct { - log *slog.Logger - drvr *driveri - db *sql.DB - src *source.Source -} - -var _ driver.Grip = (*grip)(nil) - -// DB implements driver.Grip. -func (g *grip) DB(context.Context) (*sql.DB, error) { - return g.db, nil -} - -// SQLDriver implements driver.Grip. -func (g *grip) SQLDriver() driver.SQLDriver { - return g.drvr -} - -// Source implements driver.Grip. -func (g *grip) Source() *source.Source { - return g.src -} - -// TableMetadata implements driver.Grip. 
-func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { - const query = `SELECT TABLE_CATALOG, TABLE_SCHEMA, TABLE_TYPE -FROM INFORMATION_SCHEMA.TABLES -WHERE TABLE_NAME = @p1` - - var catalog, schema, tblType string - err := g.db.QueryRowContext(ctx, query, tblName).Scan(&catalog, &schema, &tblType) - if err != nil { - return nil, errw(err) - } - - // TODO: getTableMetadata can cause deadlock in the DB. Needs further investigation. - // But a quick hack would be to use retry on a deadlock error. - return getTableMetadata(ctx, g.db, catalog, schema, tblName, tblType) -} - -// SourceMetadata implements driver.Grip. -func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { - return getSourceMetadata(ctx, g.src, g.db, noSchema) -} - -// Close implements driver.Grip. -func (g *grip) Close() error { - g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) - - return errw(g.db.Close()) -} - // newStmtExecFunc returns a StmtExecFunc that has logic to deal with // the "identity insert" error. If the error is encountered, setIdentityInsert // is called and stmt is executed again. diff --git a/libsq/core/ioz/ioz.go b/libsq/core/ioz/ioz.go index 2d88a4db0..4a590d854 100644 --- a/libsq/core/ioz/ioz.go +++ b/libsq/core/ioz/ioz.go @@ -280,7 +280,7 @@ var _ io.Writer = (*notifyOnceWriter)(nil) type notifyOnceWriter struct { w io.Writer fn func() - doneCh chan struct{} + doneCh chan struct{} // REVISIT: Do we need doneCh? notifyOnce sync.Once } diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 3610c7eb7..2775d9c1f 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -86,8 +86,14 @@ func NewBarContext(ctx context.Context, bar *Bar) context.Context { return context.WithValue(ctx, barCtxKey{}, bar) } -// Incr increments the progress of the outermost bar in ctx by amount n. -// Use in conjunction with a context returned from NewBarContext. 
+// Incr increments the progress of the outermost bar (if any) in ctx +// by amount n. Use in conjunction with a context returned from NewBarContext. +// It safe to invoke Incr on a nil context or a context that doesn't +// contain a Bar. +// +// NOTE: This is a bit of an experiment. I'm a bit hesitant in going even +// further with context-based logic, as it's not clear to me that it's +// a good path to be on. So, it's possible this mechanism may be removed. func Incr(ctx context.Context, n int) { if ctx == nil { return @@ -125,7 +131,7 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors delay: delay, stoppedCh: make(chan struct{}), stopOnce: &sync.Once{}, - refreshCh: make(chan any, 100), + // refreshCh: make(chan any, 100), } // Note that p.ctx is not the same as the arg ctx. This is a bit of a hack @@ -147,29 +153,33 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors opts := []mpb.ContainerOption{ mpb.WithOutput(out), mpb.WithWidth(boxWidth), + // FIXME: switch back to auto refresh? // mpb.WithRefreshRate(refreshRate), - mpb.WithManualRefresh(p.refreshCh), - // mpb.WithAutoRefresh(), // Needed for color in Windows, apparently + // mpb.WithManualRefresh(p.refreshCh), + mpb.WithAutoRefresh(), // Needed for color in Windows, apparently } p.pc = mpb.NewWithContext(ctx, opts...) p.pcInitFn = nil - go func() { - for { - select { - case <-p.stoppedCh: - return - case <-p.ctx.Done(): - return - default: - p.refreshCh <- time.Now() - time.Sleep(refreshRate) - } - } - }() + // go func() { + // for { + // select { + // case <-p.stoppedCh: + // return + // case <-p.ctx.Done(): + // return + // case <-p.refreshCh: + // default: + // // p.refreshCh <- time.Now() + // time.Sleep(refreshRate) + // } + // } + // }() } + // REVISIT: The delay init of pc is no longer required. We can just + // directly call it here. 
p.pcInitFn() return p } @@ -205,7 +215,7 @@ type Progress struct { stoppedCh chan struct{} stopOnce *sync.Once - refreshCh chan any + // refreshCh chan any ctx context.Context cancelFn context.CancelFunc @@ -244,21 +254,25 @@ func (p *Progress) Stop() { // doStop is probably needlessly complex, but at the time it was written, // there was a bug in the mpb package (to do with delayed render and abort), -// and so was created an extra-paranoid workaround. +// and so was created an extra-paranoid workaround. It's still not clear +// if all of this works to remove the progress bars before content +// is written to stdout. func (p *Progress) doStop() { p.stopOnce.Do(func() { p.pcInitFn = nil lg.FromContext(p.ctx).Debug("Stopping progress widget") defer lg.FromContext(p.ctx).Debug("Stopped progress widget") if p.pc == nil { - close(p.stoppedCh) p.cancelFn() + <-p.ctx.Done() + close(p.stoppedCh) return } if len(p.bars) == 0 { - close(p.stoppedCh) p.cancelFn() + <-p.ctx.Done() + close(p.stoppedCh) return } @@ -278,15 +292,31 @@ func (p *Progress) doStop() { <-b.barStoppedCh // Wait for bar to stop } - p.refreshCh <- time.Now() - close(p.stoppedCh) + // p.refreshCh <- time.Now() + + // close(p.stoppedCh) + + // So, now we REALLY want to wait for the progress widget + // to finish. Alas, the pc.Wait method doesn't seem to + // always remove the bars from the terminal. So, we do + // some probably useless extra steps to hopefully trigger + // the terminal wipe before we return. p.pc.Wait() // Important: we must call cancelFn after pc.Wait() or the bars // may not be removed from the terminal. p.cancelFn() + <-p.ctx.Done() + // We shouldn't need this extra call to pc.Wait, + // but it shouldn't hurt? + // time.Sleep(time.Millisecond) // FIXME: delete + p.pc.Wait() + + // And a tiny sleep, which again, hopefully can be removed + // at some point. 
+ // time.Sleep(time.Millisecond) // FIXME: delete + close(p.stoppedCh) }) - <-p.stoppedCh <-p.ctx.Done() } @@ -476,9 +506,9 @@ func (b *Bar) doStop() { // We *probably* only need to call b.bar.Abort() here? b.bar.SetTotal(-1, true) b.bar.Abort(true) - b.p.refreshCh <- time.Now() + // b.p.refreshCh <- time.Now() b.bar.Wait() - b.p.refreshCh <- time.Now() + // b.p.refreshCh <- time.Now() close(b.barStoppedCh) lg.FromContext(b.p.ctx).Debug("Stopped progress bar") diff --git a/testh/.gitignore b/testh/.gitignore deleted file mode 100644 index 4f1dd02f2..000000000 --- a/testh/.gitignore +++ /dev/null @@ -1 +0,0 @@ -./progress-remove.test.sh From 8fb90265977ad375768f31f6eab3fc5df3afa393 Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 14 Jan 2024 16:00:08 -0700 Subject: [PATCH 193/195] Update go.mod: mpb --- go.mod | 7 +++---- go.sum | 10 ++++------ 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/go.mod b/go.mod index d36739d6f..4c8f42a48 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/neilotoole/sq go 1.21 // See: https://github.com/vbauerster/mpb/issues/136 -require github.com/vbauerster/mpb/v8 v8.7.1-0.20231206170755-3a4a40c73c35 +require github.com/vbauerster/mpb/v8 v8.7.2 // See: https://github.com/djherbis/fscache/pull/21 require github.com/neilotoole/fscache v0.0.0-20231203162946-c9808f16552e @@ -51,11 +51,10 @@ require ( github.com/xo/dburl v0.19.1 github.com/xuri/excelize/v2 v2.8.0 go.uber.org/atomic v1.11.0 - go.uber.org/multierr v1.11.0 golang.org/x/exp v0.0.0-20231127185646-65229373498e golang.org/x/mod v0.14.0 - golang.org/x/net v0.19.0 golang.org/x/sync v0.5.0 + golang.org/x/sys v0.16.0 golang.org/x/term v0.15.0 golang.org/x/text v0.14.0 ) @@ -95,7 +94,7 @@ require ( github.com/xuri/efp v0.0.0-20230802181842-ad255f2331ca // indirect github.com/xuri/nfp v0.0.0-20230819163627-dc951e3ffe1a // indirect golang.org/x/crypto v0.16.0 // indirect - golang.org/x/sys v0.15.0 // indirect + 
golang.org/x/net v0.19.0 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 8e21d1b2a..ee3e8f942 100644 --- a/go.sum +++ b/go.sum @@ -184,8 +184,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/vbauerster/mpb/v8 v8.7.1-0.20231206170755-3a4a40c73c35 h1:MMBVE5bui8tBLZz7L4K9MX+ZBQ4eMsrX1iMCg0Ex6Lo= -github.com/vbauerster/mpb/v8 v8.7.1-0.20231206170755-3a4a40c73c35/go.mod h1:0RgdqeTpu6cDbdWeSaDvEvfgm9O598rBnRZ09HKaV0k= +github.com/vbauerster/mpb/v8 v8.7.2 h1:SMJtxhNho1MV3OuFgS1DAzhANN1Ejc5Ct+0iSaIkB14= +github.com/vbauerster/mpb/v8 v8.7.2/go.mod h1:ZFnrjzspgDHoxYLGvxIruiNk73GNTPG4YHgVNpR10VY= github.com/xo/dburl v0.19.1 h1:z/K2i8zVf6aRwQ8Szz7MGEUw0VC2472D9SlBqdHDQCU= github.com/xo/dburl v0.19.1/go.mod h1:B7/G9FGungw6ighV8xJNwWYQPMfn3gsi2sn5SE8Bzco= github.com/xuri/efp v0.0.0-20230802181842-ad255f2331ca h1:uvPMDVyP7PXMMioYdyPH+0O+Ta/UO1WFfNYMO3Wz0eg= @@ -197,8 +197,6 @@ github.com/xuri/nfp v0.0.0-20230819163627-dc951e3ffe1a/go.mod h1:WwHg+CVyzlv/TX9 github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= -go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= -go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto 
v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= @@ -240,8 +238,8 @@ golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= From 6f67533441a78db2c2b873f2ba45736542aa585f Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 14 Jan 2024 18:18:17 -0700 Subject: [PATCH 194/195] PR review and cleanup --- README.md | 3 +- cli/cmd_add_test.go | 11 +- cli/cmd_cache.go | 9 +- cli/cmd_diff.go | 2 +- cli/cmd_inspect.go | 3 +- cli/cmd_inspect_test.go | 6 +- cli/cmd_src_test.go | 7 +- cli/cmd_x.go | 1 - cli/config/yamlstore/yamlstore.go | 1 - cli/diff/data_naive.go | 3 +- cli/diff/diff.go | 3 +- cli/logging.go | 5 +- cli/output.go | 7 +- cli/output/jsonw/configwriter.go | 5 +- cli/output/tablew/configwriter.go | 5 +- cli/output/writers.go | 4 +- cli/output/yamlw/configwriter.go | 5 +- cli/run.go | 15 +- cli/terminal_windows.go | 5 +- cli/testrun/testrun.go | 10 +- drivers/csv/csv.go | 2 +- drivers/csv/csv_test.go | 5 +- drivers/json/ingest.go | 2 +- drivers/json/json.go | 2 +- drivers/mysql/metadata.go | 5 
+- drivers/mysql/mysql.go | 2 +- drivers/postgres/metadata.go | 3 +- drivers/postgres/postgres.go | 2 +- drivers/sqlite3/metadata.go | 3 +- drivers/sqlite3/pragma.go | 3 +- drivers/sqlite3/sqlite3.go | 5 +- drivers/sqlserver/grip.go | 8 +- drivers/sqlserver/metadata.go | 3 +- drivers/sqlserver/properties.go | 3 +- drivers/sqlserver/sqlserver.go | 2 +- drivers/userdriver/grip..go | 65 +++++++ drivers/userdriver/userdriver.go | 67 +------- .../xmlud/{xmlimport.go => xmlud.go} | 159 +++++++++--------- .../{xmlimport_test.go => xmlud_test.go} | 7 +- drivers/xlsx/grip.go | 4 +- drivers/xlsx/ingest.go | 2 +- drivers/xlsx/xlsx.go | 2 +- drivers/xlsx/xlsx_test.go | 3 +- libsq/ast/handle.go | 3 +- libsq/core/ioz/checksum/checksum.go | 8 +- libsq/core/ioz/checksum/checksum_test.go | 2 +- libsq/core/ioz/contextio/contextio.go | 1 + libsq/core/ioz/download/download.go | 8 +- libsq/core/ioz/download/handler.go | 2 +- libsq/core/ioz/httpz/httpz.go | 4 +- libsq/core/ioz/httpz/opts.go | 28 +-- libsq/core/ioz/lockfile/lockfile.go | 7 +- libsq/core/loz/loz.go | 2 +- libsq/core/progress/progress.go | 53 ++---- libsq/core/progress/style.go | 3 - libsq/dbwriter.go | 4 +- libsq/driver/grip.go | 8 +- libsq/driver/grips.go | 22 +-- libsq/query_test.go | 4 +- libsq/source/cache.go | 3 +- libsq/source/files.go | 5 +- libsq/source/files_test.go | 3 +- libsq/source/lock.go | 35 ---- testh/testh.go | 23 ++- testh/tu/tu.go | 5 +- 65 files changed, 304 insertions(+), 398 deletions(-) create mode 100644 drivers/userdriver/grip..go rename drivers/userdriver/xmlud/{xmlimport.go => xmlud.go} (78%) rename drivers/userdriver/xmlud/{xmlimport_test.go => xmlud_test.go} (96%) delete mode 100644 libsq/source/lock.go diff --git a/README.md b/README.md index 6a7add9e8..99fc6be15 100644 --- a/README.md +++ b/README.md @@ -324,7 +324,8 @@ See [CHANGELOG.md](./CHANGELOG.md). 
- The [`log.devmode`](https://sq.io/docs/config#logdevmode) log format is derived from [`lmittmann/tint`](https://github.com/lmittmann/tint). - [`djherbis/fscache`](https://github.com/djherbis/fscache) is used for caching. -- A forked version of lockfile +- A forked version of [`nightlyone/lockfile`](https://github.com/nightlyone/lockfile) is incorporated. +- The human-friendly `text` log format handler is a fork of [`lmittmann/tint`](https://github.com/lmittmann/tint). ## Similar, related, or noteworthy projects diff --git a/cli/cmd_add_test.go b/cli/cmd_add_test.go index c07a0b461..682ff73de 100644 --- a/cli/cmd_add_test.go +++ b/cli/cmd_add_test.go @@ -5,20 +5,19 @@ import ( "path/filepath" "testing" + "github.com/stretchr/testify/require" + + "github.com/neilotoole/sq/cli/testrun" "github.com/neilotoole/sq/drivers/csv" "github.com/neilotoole/sq/drivers/mysql" "github.com/neilotoole/sq/drivers/postgres" "github.com/neilotoole/sq/drivers/sqlite3" "github.com/neilotoole/sq/drivers/sqlserver" "github.com/neilotoole/sq/drivers/xlsx" - "github.com/neilotoole/sq/libsq/source" - "github.com/neilotoole/sq/testh" - - "github.com/stretchr/testify/require" - - "github.com/neilotoole/sq/cli/testrun" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" + "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/sakila" "github.com/neilotoole/sq/testh/tu" diff --git a/cli/cmd_cache.go b/cli/cmd_cache.go index 129d725dd..9d61edbaa 100644 --- a/cli/cmd_cache.go +++ b/cli/cmd_cache.go @@ -1,7 +1,6 @@ package cli import ( - "github.com/neilotoole/sq/libsq/core/options" "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/flag" @@ -9,6 +8,7 @@ import ( "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/lg" 
"github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/driver" ) @@ -29,12 +29,11 @@ func newCacheCmd() *cobra.Command { $ sq cache enable - $ sq cache enable @sakila - $ sq cache disable + # Disable cache for an individual source. + $ sq cache disable @sakila $ sq cache clear - $ sq cache clear @sakila # Print tree view of cache dir. @@ -95,7 +94,7 @@ func execCacheStat(cmd *cobra.Command, _ []string) error { } enabled := driver.OptIngestCache.Get(ru.Config.Options) - return ru.Writers.Config.CacheInfo(dir, enabled, size) + return ru.Writers.Config.CacheStat(dir, enabled, size) } func newCacheClearCmd() *cobra.Command { diff --git a/cli/cmd_diff.go b/cli/cmd_diff.go index ee0fe4b04..c1903760c 100644 --- a/cli/cmd_diff.go +++ b/cli/cmd_diff.go @@ -1,7 +1,6 @@ package cli import ( - "github.com/neilotoole/sq/libsq/driver" "github.com/samber/lo" "github.com/spf13/cobra" @@ -13,6 +12,7 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/core/stringz" + "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" ) diff --git a/cli/cmd_inspect.go b/cli/cmd_inspect.go index 712eb0a8a..96124b26b 100644 --- a/cli/cmd_inspect.go +++ b/cli/cmd_inspect.go @@ -5,8 +5,6 @@ import ( "database/sql" "slices" - "github.com/neilotoole/sq/libsq/driver" - "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/flag" @@ -14,6 +12,7 @@ import ( "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/metadata" ) diff --git a/cli/cmd_inspect_test.go b/cli/cmd_inspect_test.go index 7164f299a..d9e17fad7 100644 --- a/cli/cmd_inspect_test.go +++ b/cli/cmd_inspect_test.go @@ 
-7,9 +7,6 @@ import ( "os" "testing" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lgt" - "github.com/samber/lo" "github.com/stretchr/testify/require" @@ -20,6 +17,8 @@ import ( "github.com/neilotoole/sq/drivers/postgres" "github.com/neilotoole/sq/drivers/sqlite3" "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/lg" + "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" "github.com/neilotoole/sq/libsq/source/metadata" @@ -96,7 +95,6 @@ func TestCmdInspect_json_yaml(t *testing.T) { //nolint:tparallel for _, tblName := range gotTableNames { tblName := tblName t.Run(tblName, func(t *testing.T) { - // t.Parallel() tu.SkipShort(t, true) tr2 := testrun.New(lg.NewContext(th.Context, lgt.New(t)), t, tr) err := tr2.Exec("inspect", "."+tblName, fmt.Sprintf("--%s", tf.format)) diff --git a/cli/cmd_src_test.go b/cli/cmd_src_test.go index 59ac08567..0e007335d 100644 --- a/cli/cmd_src_test.go +++ b/cli/cmd_src_test.go @@ -14,14 +14,13 @@ import ( func TestCmdSrc(t *testing.T) { ctx := context.Background() th := testh.New(t) - _ = th tr := testrun.New(ctx, t, nil).Add() - // err := tr.Exec("src") - // require.NoError(t, err) + err := tr.Exec("src") + require.NoError(t, err) tr.Reset().Add(*th.Source(sakila.CSVActor)) - err := tr.Exec("src") + err = tr.Exec("src") require.NoError(t, err) err = tr.Reset().Exec(".data | .[0:5]") diff --git a/cli/cmd_x.go b/cli/cmd_x.go index 2d4c08357..3a7b9b353 100644 --- a/cli/cmd_x.go +++ b/cli/cmd_x.go @@ -158,7 +158,6 @@ func execXDownloadCmd(cmd *cobra.Command, args []string) error { c := httpz.NewClient( httpz.DefaultUserAgent, httpz.OptResponseTimeout(time.Second*15), - // httpz.OptRequestTimeout(time.Second*2), httpz.OptRequestDelay(time.Second*5), ) dl, err := download.New(fakeSrc.Handle, c, u.String(), cacheDir) diff --git 
a/cli/config/yamlstore/yamlstore.go b/cli/config/yamlstore/yamlstore.go index 5a192aa99..0b8b5acb5 100644 --- a/cli/config/yamlstore/yamlstore.go +++ b/cli/config/yamlstore/yamlstore.go @@ -218,7 +218,6 @@ func (fs *Store) fileExists() bool { } // acquireLock acquires the config lock, and returns an unlock func. -// This is an internal convenience method. func (fs *Store) acquireLock(ctx context.Context) (unlock func(), err error) { lock, err := fs.Lockfile() if err != nil { diff --git a/cli/diff/data_naive.go b/cli/diff/data_naive.go index 324010145..2a21f1213 100644 --- a/cli/diff/data_naive.go +++ b/cli/diff/data_naive.go @@ -7,8 +7,6 @@ import ( "slices" "time" - "github.com/neilotoole/sq/libsq/core/progress" - "github.com/samber/lo" "golang.org/x/sync/errgroup" @@ -19,6 +17,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/driver" ) diff --git a/cli/diff/diff.go b/cli/diff/diff.go index 92e26f585..ce7afd6ac 100644 --- a/cli/diff/diff.go +++ b/cli/diff/diff.go @@ -14,10 +14,9 @@ import ( udiff "github.com/neilotoole/sq/cli/diff/internal/go-udiff" "github.com/neilotoole/sq/cli/diff/internal/go-udiff/myers" - "github.com/neilotoole/sq/libsq/core/progress" - "github.com/neilotoole/sq/cli/output" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/metadata" diff --git a/cli/logging.go b/cli/logging.go index bee8bbb4d..48a2116c7 100644 --- a/cli/logging.go +++ b/cli/logging.go @@ -11,14 +11,13 @@ import ( "strconv" "strings" - "github.com/neilotoole/sq/libsq/core/ioz/httpz" - "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/config" 
"github.com/neilotoole/sq/cli/flag" "github.com/neilotoole/sq/cli/output/format" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz/httpz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/devlog" "github.com/neilotoole/sq/libsq/core/lg/lga" @@ -114,7 +113,7 @@ func defaultLogging(ctx context.Context, osArgs []string, cfg *config.Config, } closer = logFile.Close - // Determine if we're logging dev mode (format.Text). + // Determine if we're logging in dev mode (format.Text). devMode := OptLogFormat.Default() != format.JSON switch getLogFormat(ctx, osArgs, cfg) { //nolint:exhaustive case format.Text: diff --git a/cli/output.go b/cli/output.go index b69928fb9..78675b420 100644 --- a/cli/output.go +++ b/cli/output.go @@ -27,7 +27,6 @@ import ( "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/stringz" @@ -467,11 +466,7 @@ func getPrinting(cmd *cobra.Command, clnup *cleanup.Cleanup, opts options.Option clnup.Add(pb.Stop) // On first write to stdout, we remove the progress widget. 
- out2 = ioz.NotifyOnceWriter(out2, func() { - lg.FromContext(ctx).Debug("Output stream is being written to; removing progress widget") - pb.Stop() - lg.FromContext(ctx).Debug("Progress widget should be removed now") - }) + out2 = ioz.NotifyOnceWriter(out2, pb.Stop) cmd.SetContext(progress.NewContext(ctx, pb)) } diff --git a/cli/output/jsonw/configwriter.go b/cli/output/jsonw/configwriter.go index 7692ebaa5..bd96b6b69 100644 --- a/cli/output/jsonw/configwriter.go +++ b/cli/output/jsonw/configwriter.go @@ -27,9 +27,8 @@ func (w *configWriter) CacheLocation(loc string) error { return writeJSON(w.out, w.pr, m) } -// CacheInfo implements output.ConfigWriter. It simply -// delegates to CacheLocation. -func (w *configWriter) CacheInfo(loc string, enabled bool, size int64) error { +// CacheStat implements output.ConfigWriter. +func (w *configWriter) CacheStat(loc string, enabled bool, size int64) error { type cacheInfo struct { Location string `json:"location"` Enabled bool `json:"enabled"` diff --git a/cli/output/tablew/configwriter.go b/cli/output/tablew/configwriter.go index 2035b164c..95b7334e8 100644 --- a/cli/output/tablew/configwriter.go +++ b/cli/output/tablew/configwriter.go @@ -45,9 +45,8 @@ func (w *configWriter) CacheLocation(loc string) error { return errz.Err(err) } -// CacheInfo implements output.ConfigWriter. It simply -// delegates to CacheLocation. -func (w *configWriter) CacheInfo(loc string, enabled bool, size int64) error { +// CacheStat implements output.ConfigWriter. +func (w *configWriter) CacheStat(loc string, enabled bool, size int64) error { const sp = " " s := loc + sp if enabled { diff --git a/cli/output/writers.go b/cli/output/writers.go index 1579c81ae..f58e5b088 100644 --- a/cli/output/writers.go +++ b/cli/output/writers.go @@ -146,9 +146,9 @@ type ConfigWriter interface { // CacheLocation prints the cache location. CacheLocation(loc string) error - // CacheInfo prints cache info. 
Set arg size to -1 to indicate + // CacheStat prints cache info. Set arg size to -1 to indicate // that the size of the cache could not be calculated. - CacheInfo(loc string, enabled bool, size int64) error + CacheStat(loc string, enabled bool, size int64) error } // Writers is a container for the various output Writers. diff --git a/cli/output/yamlw/configwriter.go b/cli/output/yamlw/configwriter.go index 80fc181e9..0b1874d2d 100644 --- a/cli/output/yamlw/configwriter.go +++ b/cli/output/yamlw/configwriter.go @@ -45,9 +45,8 @@ func (w *configWriter) CacheLocation(loc string) error { return writeYAML(w.out, w.p, m) } -// CacheInfo implements output.ConfigWriter. It simply -// delegates to CacheLocation. -func (w *configWriter) CacheInfo(loc string, enabled bool, size int64) error { +// CacheStat implements output.ConfigWriter. +func (w *configWriter) CacheStat(loc string, enabled bool, size int64) error { type cacheInfo struct { Location string `yaml:"location"` Enabled bool `yaml:"enabled"` diff --git a/cli/run.go b/cli/run.go index 0a9029037..319dee325 100644 --- a/cli/run.go +++ b/cli/run.go @@ -8,16 +8,12 @@ import ( "path/filepath" "time" - "github.com/neilotoole/sq/cli/flag" - - "github.com/neilotoole/sq/libsq/core/ioz/checksum" - "github.com/neilotoole/sq/libsq/core/ioz/lockfile" - "github.com/spf13/cobra" "github.com/neilotoole/sq/cli/config" "github.com/neilotoole/sq/cli/config/yamlstore" v0_34_0 "github.com/neilotoole/sq/cli/config/yamlstore/upgrades/v0.34.0" //nolint:revive + "github.com/neilotoole/sq/cli/flag" "github.com/neilotoole/sq/cli/run" "github.com/neilotoole/sq/drivers/csv" "github.com/neilotoole/sq/drivers/json" @@ -30,6 +26,8 @@ import ( "github.com/neilotoole/sq/drivers/xlsx" "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" + "github.com/neilotoole/sq/libsq/core/ioz/lockfile" 
"github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/slogbuf" @@ -90,7 +88,6 @@ func newRun(ctx context.Context, stdin *os.File, stdout, stderr io.Writer, args log, logHandler, logCloser, logErr := defaultLogging(ctx, args, ru.Config) ru.Cleanup = cleanup.New() - // FIXME: re-enable log closing ru.LogCloser = logCloser if logErr != nil { stderrLog, h := stderrLogger() @@ -277,8 +274,8 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { dr.AddProvider(xlsx.Type, &xlsx.Provider{Log: log, Ingester: ru.Grips, Files: ru.Files}) ru.Files.AddDriverDetectors(xlsx.DetectXLSX) // One day we may have more supported user driver genres. - userDriverImporters := map[string]userdriver.ImportFunc{ - xmlud.Genre: xmlud.Import, + userDriverImporters := map[string]userdriver.IngestFunc{ + xmlud.Genre: xmlud.Ingest, } for i, udd := range cfg.Ext.UserDrivers { @@ -303,7 +300,7 @@ func FinishRunInit(ctx context.Context, ru *run.Run) error { udp := &userdriver.Provider{ Log: log, DriverDef: udd, - ImportFn: importFn, + IngestFn: importFn, Ingester: ru.Grips, Files: ru.Files, } diff --git a/cli/terminal_windows.go b/cli/terminal_windows.go index 447683c4c..c3bafaf89 100644 --- a/cli/terminal_windows.go +++ b/cli/terminal_windows.go @@ -4,7 +4,6 @@ import ( "io" "os" - "golang.org/x/sys/windows" "golang.org/x/term" ) @@ -21,6 +20,10 @@ func isTerminal(w io.Writer) bool { // isColorTerminal returns true if w is a colorable terminal. // It respects [NO_COLOR], [FORCE_COLOR] and TERM=dumb environment variables. // +// Acknowledgement: This function is lifted from neilotoole/jsoncolor, but +// it was contributed by @hermannm. 
+// - https://github.com/neilotoole/jsoncolor/pull/27 +// // [NO_COLOR]: https://no-color.org/ // [FORCE_COLOR]: https://force-color.org/ func isColorTerminal(w io.Writer) bool { diff --git a/cli/testrun/testrun.go b/cli/testrun/testrun.go index fd1f58824..b99c87d89 100644 --- a/cli/testrun/testrun.go +++ b/cli/testrun/testrun.go @@ -12,8 +12,6 @@ import ( "sync" "testing" - "github.com/neilotoole/sq/testh" - "github.com/stretchr/testify/require" "github.com/neilotoole/sq/cli" @@ -25,6 +23,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lgt" "github.com/neilotoole/sq/libsq/core/options" "github.com/neilotoole/sq/libsq/source" + "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/tu" ) @@ -121,6 +120,13 @@ func newRun(ctx context.Context, t testing.TB, // The Files instance needs unique dirs for temp and cache because // the test runs may execute in parallel inside the same test binary // process, thus breaking the pid-based lockfile mechanism. + + // If cacheDir was supplied, use that one, because it's probably the + // cache dir from a previous run, that we want to reuse. If not supplied, + // create a unique cache dir for this run. + // The Files instance generally needs unique dirs for temp and cache because + // the test runs may execute in parallel inside the same test binary + // process, thus breaking the pid-based lockfile mechanism. if cacheDir == "" { cacheDir = tu.CacheDir(t, false) } diff --git a/drivers/csv/csv.go b/drivers/csv/csv.go index 700399345..1664f7a8f 100644 --- a/drivers/csv/csv.go +++ b/drivers/csv/csv.go @@ -65,7 +65,7 @@ func (d *driveri) DriverMetadata() driver.Metadata { return md } -// Open implements driver.GripOpener. +// Open implements driver.Driver. 
func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, error) { log := lg.FromContext(ctx) log.Debug(lgm.OpenSrc, lga.Src, src) diff --git a/drivers/csv/csv_test.go b/drivers/csv/csv_test.go index abae10c2e..2334dd102 100644 --- a/drivers/csv/csv_test.go +++ b/drivers/csv/csv_test.go @@ -9,12 +9,10 @@ import ( "testing" "time" - "github.com/neilotoole/sq/testh/fixt" - "golang.org/x/exp/maps" - "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" "github.com/neilotoole/sq/cli/testrun" "github.com/neilotoole/sq/drivers/csv" @@ -27,6 +25,7 @@ import ( "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" "github.com/neilotoole/sq/testh" + "github.com/neilotoole/sq/testh/fixt" "github.com/neilotoole/sq/testh/sakila" ) diff --git a/drivers/json/ingest.go b/drivers/json/ingest.go index e798bf4e5..e15356c3c 100644 --- a/drivers/json/ingest.go +++ b/drivers/json/ingest.go @@ -1,6 +1,6 @@ package json -// ingest.go contains functionality common to the +// xmlud.go contains functionality common to the // various JSON import mechanisms. import ( diff --git a/drivers/json/json.go b/drivers/json/json.go index a40d80310..44d0c4e3e 100644 --- a/drivers/json/json.go +++ b/drivers/json/json.go @@ -90,7 +90,7 @@ func (d *driveri) DriverMetadata() driver.Metadata { return md } -// Open implements driver.GripOpener. +// Open implements driver.Driver. 
func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, error) { log := lg.FromContext(ctx) log.Debug(lgm.OpenSrc, lga.Src, src) diff --git a/drivers/mysql/metadata.go b/drivers/mysql/metadata.go index 7b726712d..10b8ca470 100644 --- a/drivers/mysql/metadata.go +++ b/drivers/mysql/metadata.go @@ -10,8 +10,6 @@ import ( "strings" "time" - "github.com/neilotoole/sq/libsq/core/progress" - "github.com/go-sql-driver/mysql" "github.com/samber/lo" "golang.org/x/sync/errgroup" @@ -22,6 +20,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/record" "github.com/neilotoole/sq/libsq/core/sqlz" "github.com/neilotoole/sq/libsq/core/stringz" @@ -172,7 +171,7 @@ func getNewRecordFunc(rowMeta record.Meta) driver.NewRecordFunc { } // getTableMetadata gets the metadata for a single table. It is the -// implementation of driver.Grip.Table. +// implementation of Grip.TableMetadata. func getTableMetadata(ctx context.Context, db sqlz.DB, tblName string) (*metadata.Table, error) { query := `SELECT TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE, TABLE_COMMENT, (DATA_LENGTH + INDEX_LENGTH) AS table_size, (SELECT COUNT(*) FROM ` + "`" + tblName + "`" + `) AS row_count diff --git a/drivers/mysql/mysql.go b/drivers/mysql/mysql.go index 6d772a5b0..345882e37 100644 --- a/drivers/mysql/mysql.go +++ b/drivers/mysql/mysql.go @@ -459,7 +459,7 @@ func (d *driveri) getTableRecordMeta(ctx context.Context, db sqlz.DB, tblName st return destCols, nil } -// Open implements driver.GripOpener. +// Open implements driver.Driver. 
func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) diff --git a/drivers/postgres/metadata.go b/drivers/postgres/metadata.go index d82b12f2b..66ed131ec 100644 --- a/drivers/postgres/metadata.go +++ b/drivers/postgres/metadata.go @@ -9,8 +9,6 @@ import ( "strconv" "strings" - "github.com/neilotoole/sq/libsq/core/progress" - "golang.org/x/sync/errgroup" "github.com/neilotoole/sq/libsq/core/errz" @@ -19,6 +17,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/record" "github.com/neilotoole/sq/libsq/core/sqlz" "github.com/neilotoole/sq/libsq/core/stringz" diff --git a/drivers/postgres/postgres.go b/drivers/postgres/postgres.go index 916a5ffcd..f000f1551 100644 --- a/drivers/postgres/postgres.go +++ b/drivers/postgres/postgres.go @@ -133,7 +133,7 @@ func (d *driveri) Renderer() *render.Renderer { return r } -// Open implements driver.GripOpener. +// Open implements driver.Driver. 
func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) diff --git a/drivers/sqlite3/metadata.go b/drivers/sqlite3/metadata.go index 601dcf9da..11e904ad3 100644 --- a/drivers/sqlite3/metadata.go +++ b/drivers/sqlite3/metadata.go @@ -9,12 +9,11 @@ import ( "reflect" "strings" - "github.com/neilotoole/sq/libsq/core/progress" - "github.com/neilotoole/sq/libsq/core/kind" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/record" "github.com/neilotoole/sq/libsq/core/sqlz" "github.com/neilotoole/sq/libsq/core/stringz" diff --git a/drivers/sqlite3/pragma.go b/drivers/sqlite3/pragma.go index 63b99e4c6..c383c82b3 100644 --- a/drivers/sqlite3/pragma.go +++ b/drivers/sqlite3/pragma.go @@ -6,11 +6,10 @@ import ( "fmt" "strings" - "github.com/neilotoole/sq/libsq/core/progress" - "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/sqlz" ) diff --git a/drivers/sqlite3/sqlite3.go b/drivers/sqlite3/sqlite3.go index eb30fccc9..83967762c 100644 --- a/drivers/sqlite3/sqlite3.go +++ b/drivers/sqlite3/sqlite3.go @@ -14,8 +14,6 @@ import ( "strings" "time" - "github.com/neilotoole/sq/libsq/core/ioz" - _ "github.com/mattn/go-sqlite3" // Import for side effect of loading the driver "github.com/shopspring/decimal" @@ -23,6 +21,7 @@ import ( "github.com/neilotoole/sq/libsq/ast" "github.com/neilotoole/sq/libsq/ast/render" "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/jointype" "github.com/neilotoole/sq/libsq/core/kind" 
"github.com/neilotoole/sq/libsq/core/lg" @@ -125,7 +124,7 @@ func (d *driveri) DriverMetadata() driver.Metadata { } } -// Open implements driver.GripOpener. +// Open implements driver.Driver. func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) diff --git a/drivers/sqlserver/grip.go b/drivers/sqlserver/grip.go index 159034a1d..962de3912 100644 --- a/drivers/sqlserver/grip.go +++ b/drivers/sqlserver/grip.go @@ -5,8 +5,6 @@ import ( "database/sql" "log/slog" - "github.com/neilotoole/sq/libsq/core/lg" - "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/progress" @@ -43,11 +41,7 @@ func (g *grip) Source() *source.Source { // TableMetadata implements driver.Grip. func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { bar := progress.FromContext(ctx).NewUnitCounter(g.Source().Handle+"."+tblName+": read schema", "item") - defer func() { - lg.FromContext(ctx).Warn("Before bar stop") - bar.Stop() - lg.FromContext(ctx).Warn("After bar stop") - }() + defer bar.Stop() ctx = progress.NewBarContext(ctx, bar) const query = `SELECT TABLE_CATALOG, TABLE_SCHEMA, TABLE_TYPE diff --git a/drivers/sqlserver/metadata.go b/drivers/sqlserver/metadata.go index a11ec97de..3c98ca2c6 100644 --- a/drivers/sqlserver/metadata.go +++ b/drivers/sqlserver/metadata.go @@ -8,8 +8,6 @@ import ( "strconv" "strings" - "github.com/neilotoole/sq/libsq/core/progress" - "github.com/c2h5oh/datasize" "golang.org/x/sync/errgroup" @@ -18,6 +16,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/record" "github.com/neilotoole/sq/libsq/core/sqlz" "github.com/neilotoole/sq/libsq/driver" diff 
--git a/drivers/sqlserver/properties.go b/drivers/sqlserver/properties.go index 563ce957a..db98549f6 100644 --- a/drivers/sqlserver/properties.go +++ b/drivers/sqlserver/properties.go @@ -3,10 +3,9 @@ package sqlserver import ( "context" - "github.com/neilotoole/sq/libsq/core/progress" - "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/sqlz" ) diff --git a/drivers/sqlserver/sqlserver.go b/drivers/sqlserver/sqlserver.go index 3fdd0d12b..de76665b9 100644 --- a/drivers/sqlserver/sqlserver.go +++ b/drivers/sqlserver/sqlserver.go @@ -156,7 +156,7 @@ func (d *driveri) Renderer() *render.Renderer { return r } -// Open implements driver.GripOpener. +// Open implements driver.Driver. func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) diff --git a/drivers/userdriver/grip..go b/drivers/userdriver/grip..go new file mode 100644 index 000000000..4e658f2ff --- /dev/null +++ b/drivers/userdriver/grip..go @@ -0,0 +1,65 @@ +package userdriver + +import ( + "context" + "database/sql" + "log/slog" + + "github.com/neilotoole/sq/libsq/core/lg/lga" + "github.com/neilotoole/sq/libsq/core/lg/lgm" + "github.com/neilotoole/sq/libsq/driver" + "github.com/neilotoole/sq/libsq/source" + "github.com/neilotoole/sq/libsq/source/metadata" +) + +// grip implements driver.Grip. +type grip struct { + log *slog.Logger + src *source.Source + impl driver.Grip +} + +// DB implements driver.Grip. +func (g *grip) DB(ctx context.Context) (*sql.DB, error) { + return g.impl.DB(ctx) +} + +// SQLDriver implements driver.Grip. +func (g *grip) SQLDriver() driver.SQLDriver { + return g.impl.SQLDriver() +} + +// Source implements driver.Grip. +func (g *grip) Source() *source.Source { + return g.src +} + +// TableMetadata implements driver.Grip. 
+func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { + return g.impl.TableMetadata(ctx, tblName) +} + +// SourceMetadata implements driver.Grip. +func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { + meta, err := g.impl.SourceMetadata(ctx, noSchema) + if err != nil { + return nil, err + } + + meta.Handle = g.src.Handle + meta.Location = g.src.Location + meta.Name, err = source.LocationFileName(g.src) + if err != nil { + return nil, err + } + + meta.FQName = meta.Name + return meta, nil +} + +// Close implements driver.Grip. +func (g *grip) Close() error { + g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) + + return g.impl.Close() +} diff --git a/drivers/userdriver/userdriver.go b/drivers/userdriver/userdriver.go index 5bafb609c..27d76c812 100644 --- a/drivers/userdriver/userdriver.go +++ b/drivers/userdriver/userdriver.go @@ -7,7 +7,6 @@ package userdriver import ( "context" - "database/sql" "io" "log/slog" @@ -19,13 +18,11 @@ import ( "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" - "github.com/neilotoole/sq/libsq/source/metadata" ) -// ImportFunc is a function that can import +// IngestFunc is a function that can ingest // data (as defined in def) to destGrip. -type ImportFunc func(ctx context.Context, def *DriverDef, - data io.Reader, destGrip driver.Grip) error +type IngestFunc func(ctx context.Context, def *DriverDef, data io.Reader, destGrip driver.Grip) error // Provider implements driver.Provider for a DriverDef. type Provider struct { @@ -33,7 +30,7 @@ type Provider struct { DriverDef *DriverDef Ingester driver.GripOpenIngester Files *source.Files - ImportFn ImportFunc + IngestFn IngestFunc } // DriverFor implements driver.Provider. 
@@ -47,7 +44,7 @@ func (p *Provider) DriverFor(typ drivertype.Type) (driver.Driver, error) { typ: typ, def: p.DriverDef, ingester: p.Ingester, - ingestFn: p.ImportFn, + ingestFn: p.IngestFn, files: p.Files, }, nil } @@ -67,7 +64,7 @@ type driveri struct { def *DriverDef files *source.Files ingester driver.GripOpenIngester - ingestFn ImportFunc + ingestFn IngestFunc } // DriverMetadata implements driver.Driver. @@ -80,7 +77,7 @@ func (d *driveri) DriverMetadata() driver.Metadata { } } -// Open implements driver.GripOpener. +// Open implements driver.Driver. func (d *driveri) Open(ctx context.Context, src *source.Source) (driver.Grip, error) { log := lg.FromContext(ctx).With(lga.Src, src) log.Debug(lgm.OpenSrc) @@ -121,55 +118,3 @@ func (d *driveri) ValidateSource(src *source.Source) (*source.Source, error) { func (d *driveri) Ping(ctx context.Context, src *source.Source) error { return d.files.Ping(ctx, src) } - -// grip implements driver.Grip. -type grip struct { - log *slog.Logger - src *source.Source - impl driver.Grip -} - -// DB implements driver.Grip. -func (g *grip) DB(ctx context.Context) (*sql.DB, error) { - return g.impl.DB(ctx) -} - -// SQLDriver implements driver.Grip. -func (g *grip) SQLDriver() driver.SQLDriver { - return g.impl.SQLDriver() -} - -// Source implements driver.Grip. -func (g *grip) Source() *source.Source { - return g.src -} - -// TableMetadata implements driver.Grip. -func (g *grip) TableMetadata(ctx context.Context, tblName string) (*metadata.Table, error) { - return g.impl.TableMetadata(ctx, tblName) -} - -// SourceMetadata implements driver.Grip. 
-func (g *grip) SourceMetadata(ctx context.Context, noSchema bool) (*metadata.Source, error) { - meta, err := g.impl.SourceMetadata(ctx, noSchema) - if err != nil { - return nil, err - } - - meta.Handle = g.src.Handle - meta.Location = g.src.Location - meta.Name, err = source.LocationFileName(g.src) - if err != nil { - return nil, err - } - - meta.FQName = meta.Name - return meta, nil -} - -// Close implements driver.Grip. -func (g *grip) Close() error { - g.log.Debug(lgm.CloseDB, lga.Handle, g.src.Handle) - - return g.impl.Close() -} diff --git a/drivers/userdriver/xmlud/xmlimport.go b/drivers/userdriver/xmlud/xmlud.go similarity index 78% rename from drivers/userdriver/xmlud/xmlimport.go rename to drivers/userdriver/xmlud/xmlud.go index e28883c53..1f0ba0aa2 100644 --- a/drivers/userdriver/xmlud/xmlimport.go +++ b/drivers/userdriver/xmlud/xmlud.go @@ -14,8 +14,6 @@ import ( "strconv" "strings" - "github.com/neilotoole/sq/libsq/core/sqlz" - "github.com/neilotoole/sq/drivers/userdriver" "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" @@ -23,6 +21,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/sqlmodel" + "github.com/neilotoole/sq/libsq/core/sqlz" "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" ) @@ -30,10 +29,10 @@ import ( // Genre is the user driver genre that this package supports. const Genre = "xml" -// Import implements userdriver.ImportFunc. -func Import(ctx context.Context, def *userdriver.DriverDef, data io.Reader, destGrip driver.Grip) error { +// Ingest implements userdriver.IngestFunc. 
+func Ingest(ctx context.Context, def *userdriver.DriverDef, data io.Reader, destGrip driver.Grip) error { if def.Genre != Genre { - return errz.Errorf("xmlud.Import does not support genre {%s}", def.Genre) + return errz.Errorf("xmlud.Ingest does not support genre {%s}", def.Genre) } log := lg.FromContext(ctx) @@ -42,7 +41,7 @@ func Import(ctx context.Context, def *userdriver.DriverDef, data io.Reader, dest return err } - im := &importer{ + ing := &ingester{ log: log, destGrip: destGrip, destDB: db, @@ -58,16 +57,16 @@ func Import(ctx context.Context, def *userdriver.DriverDef, data io.Reader, dest msgOnce: map[string]struct{}{}, } - if err = im.execIngest(ctx); err != nil { - lg.WarnIfFuncError(log, "xml ingest: cleanup", im.clnup.Run) + if err = ing.execIngest(ctx); err != nil { + lg.WarnIfFuncError(log, "xml ingest: cleanup", ing.clnup.Run) return errz.Wrap(err, "xml ingest") } - return errz.Wrap(im.clnup.Run(), "xml ingest: cleanup") + return errz.Wrap(ing.clnup.Run(), "xml ingest: cleanup") } -// importer does the work of importing data from XML. -type importer struct { +// ingester does the work of importing data from XML. +type ingester struct { log *slog.Logger def *userdriver.DriverDef data io.Reader @@ -91,7 +90,7 @@ type importer struct { // update's WHERE clause. execUpdateFns map[string]func(ctx context.Context, updateVals, whereArgs []any) error - // clnup holds cleanup funcs that should be run when the importer + // clnup holds cleanup funcs that should be run when the ingester // finishes. 
clnup *cleanup.Cleanup @@ -99,13 +98,13 @@ type importer struct { msgOnce map[string]struct{} } -func (im *importer) execIngest(ctx context.Context) error { //nolint:gocognit - err := im.createTables(ctx) +func (in *ingester) execIngest(ctx context.Context) error { //nolint:gocognit + err := in.createTables(ctx) if err != nil { return err } - decoder := xml.NewDecoder(im.data) + decoder := xml.NewDecoder(in.data) for { t, err := decoder.Token() if t == nil { @@ -117,33 +116,33 @@ func (im *importer) execIngest(ctx context.Context) error { //nolint:gocognit switch elem := t.(type) { case xml.StartElement: - im.selStack.push(elem.Name.Local) - if im.isRootSelector() { + in.selStack.push(elem.Name.Local) + if in.isRootSelector() { continue } - if im.isRowSelector() { + if in.isRowSelector() { // We found a new row... - prevRow := im.rowStack.peek() + prevRow := in.rowStack.peek() if prevRow != nil { // Because the new row might require the primary key of the prev row, // we need to save the previous row, to ensure its primary key is // generated. 
- err = im.saveRow(ctx, prevRow) + err = in.saveRow(ctx, prevRow) if err != nil { return err } } var curRow *rowState - curRow, err = im.buildRow() + curRow, err = in.buildRow() if err != nil { return err } - im.rowStack.push(curRow) + in.rowStack.push(curRow) - err = im.handleElemAttrs(elem, curRow) + err = in.handleElemAttrs(elem, curRow) if err != nil { return err } @@ -152,43 +151,43 @@ func (im *importer) execIngest(ctx context.Context) error { //nolint:gocognit } // It's not a row element, it's a col element - curRow := im.rowStack.peek() + curRow := in.rowStack.peek() if curRow == nil { return errz.Errorf("unable to parse XML: no current row on stack for elem {%s}", elem.Name.Local) } - col := curRow.tbl.ColBySelector(im.selStack.selector()) + col := curRow.tbl.ColBySelector(in.selStack.selector()) if col == nil { - if msg, ok := im.msgOncef("Skip: element {%s} is not a column of table {%s}", elem.Name.Local, + if msg, ok := in.msgOncef("Skip: element {%s} is not a column of table {%s}", elem.Name.Local, curRow.tbl.Name); ok { - im.log.Debug(msg) + in.log.Debug(msg) } continue } curRow.curCol = col - err = im.handleElemAttrs(elem, curRow) + err = in.handleElemAttrs(elem, curRow) if err != nil { return err } case xml.EndElement: - if im.isRowSelector() { - row := im.rowStack.peek() + if in.isRowSelector() { + row := in.rowStack.peek() if row.dirty() { - err = im.saveRow(ctx, row) + err = in.saveRow(ctx, row) if err != nil { return err } } - im.rowStack.pop() + in.rowStack.pop() } - im.selStack.pop() + in.selStack.pop() case xml.CharData: data := string(elem) - curRow := im.rowStack.peek() + curRow := in.rowStack.peek() if curRow == nil { continue @@ -198,7 +197,7 @@ func (im *importer) execIngest(ctx context.Context) error { //nolint:gocognit continue } - val, err := im.convertVal(curRow.tbl.Name, curRow.curCol, data) + val, err := in.convertVal(curRow.tbl.Name, curRow.curCol, data) if err != nil { return err } @@ -211,7 +210,7 @@ func (im *importer) 
execIngest(ctx context.Context) error { //nolint:gocognit return nil } -func (im *importer) convertVal(tbl string, col *userdriver.ColMapping, data any) (any, error) { +func (in *ingester) convertVal(tbl string, col *userdriver.ColMapping, data any) (any, error) { const errTpl = `conversion error: %s.%s: expected "%s" but got %T(%v)` const errTplMsg = `conversion error: %s.%s: expected "%s" but got %T(%v): %v` @@ -275,22 +274,22 @@ func (im *importer) convertVal(tbl string, col *userdriver.ColMapping, data any) } } -func (im *importer) handleElemAttrs(elem xml.StartElement, curRow *rowState) error { +func (in *ingester) handleElemAttrs(elem xml.StartElement, curRow *rowState) error { if len(elem.Attr) > 0 { - baseSel := im.selStack.selector() + baseSel := in.selStack.selector() for _, attr := range elem.Attr { attrSel := baseSel + "/@" + attr.Name.Local attrCol := curRow.tbl.ColBySelector(attrSel) if attrCol == nil { - if msg, ok := im.msgOncef("Skip: attr {%s} is not a column of table {%s}", attrSel, curRow.tbl.Name); ok { - im.log.Debug(msg) + if msg, ok := in.msgOncef("Skip: attr {%s} is not a column of table {%s}", attrSel, curRow.tbl.Name); ok { + in.log.Debug(msg) } continue } // We have found the col matching the attribute - val, err := im.convertVal(curRow.tbl.Name, attrCol, attr.Value) + val, err := in.convertVal(curRow.tbl.Name, attrCol, attr.Value) if err != nil { return err } @@ -304,7 +303,7 @@ func (im *importer) handleElemAttrs(elem xml.StartElement, curRow *rowState) err // setForeignColsVals sets the values of any column that needs to be // populated from a foreign key. -func (im *importer) setForeignColsVals(row *rowState) error { +func (in *ingester) setForeignColsVals(row *rowState) error { // check if we need to populate any of the row's values with // foreign key data (e.g. from parent table). 
for _, col := range row.tbl.Cols { @@ -321,7 +320,7 @@ func (im *importer) setForeignColsVals(row *rowState) error { fkName := parts[1] - parentRow := im.rowStack.peekN(1) + parentRow := in.rowStack.peekN(1) if parentRow == nil { return errz.Errorf("unable to find parent() table for foreign key for %s.%s", row.tbl.Name, col.Name) } @@ -337,7 +336,7 @@ func (im *importer) setForeignColsVals(row *rowState) error { return nil } -func (im *importer) setSequenceColsVals(row *rowState, nextSeqVal int64) { +func (in *ingester) setSequenceColsVals(row *rowState, nextSeqVal int64) { seqColNames := userdriver.NamesFromCols(row.tbl.SequenceCols()) for _, seqColName := range seqColNames { if _, ok := row.savedColVals[seqColName]; ok { @@ -360,7 +359,7 @@ func (im *importer) setSequenceColsVals(row *rowState, nextSeqVal int64) { // Probably safer to override the value. row.dirtyColVals[seqColName] = nextSeqVal - im.log.Warn("%s.%s is a auto-generated sequence() column: ignoring the value found in input", + in.log.Warn("%s.%s is a auto-generated sequence() column: ignoring the value found in input", row.tbl.Name, seqColName) continue } @@ -370,19 +369,19 @@ func (im *importer) setSequenceColsVals(row *rowState, nextSeqVal int64) { } } -func (im *importer) saveRow(ctx context.Context, row *rowState) error { +func (in *ingester) saveRow(ctx context.Context, row *rowState) error { if !row.dirty() { return nil } - tblDef, ok := im.tblDefs[row.tbl.Name] + tblDef, ok := in.tblDefs[row.tbl.Name] if !ok { return errz.Errorf("unable to find definition for table {%s}", row.tbl.Name) } if row.created() { // Row already exists in the db - err := im.dbUpdate(ctx, row) + err := in.dbUpdate(ctx, row) if err != nil { return errz.Wrapf(err, "failed to update table {%s}", tblDef.Name) } @@ -395,14 +394,14 @@ func (im *importer) saveRow(ctx context.Context, row *rowState) error { // Maintain the table's sequence. 
Note that we always increment the // seq val even if there are no sequence cols for this table. - prevSeqVal := im.tblSequence[tblDef.Name] + prevSeqVal := in.tblSequence[tblDef.Name] nextSeqVal := prevSeqVal + 1 - im.tblSequence[tblDef.Name] = nextSeqVal + in.tblSequence[tblDef.Name] = nextSeqVal - im.setSequenceColsVals(row, nextSeqVal) + in.setSequenceColsVals(row, nextSeqVal) // Set any foreign cols - err := im.setForeignColsVals(row) + err := in.setForeignColsVals(row) if err != nil { return err } @@ -414,7 +413,7 @@ func (im *importer) saveRow(ctx context.Context, row *rowState) error { } } - err = im.dbInsert(ctx, row) + err = in.dbInsert(ctx, row) if err != nil { return errz.Wrapf(err, "failed to insert to table {%s}", tblDef.Name) } @@ -424,7 +423,7 @@ func (im *importer) saveRow(ctx context.Context, row *rowState) error { } // dbInsert inserts row's dirty col values to row's table. -func (im *importer) dbInsert(ctx context.Context, row *rowState) error { +func (in *ingester) dbInsert(ctx context.Context, row *rowState) error { tblName := row.tbl.Name colNames := make([]string, len(row.dirtyColVals)) vals := make([]any, len(row.dirtyColVals)) @@ -438,16 +437,16 @@ func (im *importer) dbInsert(ctx context.Context, row *rowState) error { // We cache the prepared insert statements. cacheKey := "##insert_func__" + tblName + "__" + strings.Join(colNames, ",") - execInsertFn, ok := im.execInsertFns[cacheKey] + execInsertFn, ok := in.execInsertFns[cacheKey] if !ok { // Nothing cached, prepare the insert statement and insert munge func - stmtExecer, err := im.destGrip.SQLDriver().PrepareInsertStmt(ctx, im.destDB, tblName, colNames, 1) + stmtExecer, err := in.destGrip.SQLDriver().PrepareInsertStmt(ctx, in.destDB, tblName, colNames, 1) if err != nil { return err } // Make sure we close stmt eventually. 
- im.clnup.AddC(stmtExecer) + in.clnup.AddC(stmtExecer) execInsertFn = func(ctx context.Context, vals []any) error { // Munge vals so that they're as the target DB expects @@ -461,7 +460,7 @@ func (im *importer) dbInsert(ctx context.Context, row *rowState) error { } // Cache the execInsertFn. - im.execInsertFns[cacheKey] = execInsertFn + in.execInsertFns[cacheKey] = execInsertFn } err := execInsertFn(ctx, vals) @@ -474,8 +473,8 @@ func (im *importer) dbInsert(ctx context.Context, row *rowState) error { // dbUpdate updates row's table with row's dirty values, using row's // primary key cols as the args to the WHERE clause. -func (im *importer) dbUpdate(ctx context.Context, row *rowState) error { - drvr := im.destGrip.SQLDriver() +func (in *ingester) dbUpdate(ctx context.Context, row *rowState) error { + drvr := in.destGrip.SQLDriver() tblName := row.tbl.Name pkColNames := row.tbl.PrimaryKey @@ -510,16 +509,16 @@ func (im *importer) dbUpdate(ctx context.Context, row *rowState) error { // We cache the prepared statement. cacheKey := "##update_func__" + tblName + "__" + strings.Join(colNames, ",") + whereClause - execUpdateFn, ok := im.execUpdateFns[cacheKey] + execUpdateFn, ok := in.execUpdateFns[cacheKey] if !ok { // Nothing cached, prepare the update statement and munge func - stmtExecer, err := drvr.PrepareUpdateStmt(ctx, im.destDB, tblName, colNames, whereClause) + stmtExecer, err := drvr.PrepareUpdateStmt(ctx, in.destDB, tblName, colNames, whereClause) if err != nil { return err } // Make sure we close stmt eventually. - im.clnup.AddC(stmtExecer) + in.clnup.AddC(stmtExecer) execUpdateFn = func(ctx context.Context, updateVals, whereArgs []any) error { // Munge vals so that they're as the target DB expects @@ -535,7 +534,7 @@ func (im *importer) dbUpdate(ctx context.Context, row *rowState) error { } // Cache the execInsertFn. 
- im.execUpdateFns[cacheKey] = execUpdateFn + in.execUpdateFns[cacheKey] = execUpdateFn } err := execUpdateFn(ctx, dirtyVals, pkVals) @@ -546,10 +545,10 @@ func (im *importer) dbUpdate(ctx context.Context, row *rowState) error { return nil } -func (im *importer) buildRow() (*rowState, error) { - tbl := im.def.TableBySelector(im.selStack.selector()) +func (in *ingester) buildRow() (*rowState, error) { + tbl := in.def.TableBySelector(in.selStack.selector()) if tbl == nil { - return nil, errz.Errorf("no tbl matching current selector: %s", im.selStack.selector()) + return nil, errz.Errorf("no tbl matching current selector: %s", in.selStack.selector()) } r := &rowState{tbl: tbl} @@ -568,47 +567,47 @@ func (im *importer) buildRow() (*rowState, error) { return r, nil } -func (im *importer) createTables(ctx context.Context) error { - for i := range im.def.Tables { - tblDef, err := userdriver.ToTableDef(im.def.Tables[i]) +func (in *ingester) createTables(ctx context.Context) error { + for i := range in.def.Tables { + tblDef, err := userdriver.ToTableDef(in.def.Tables[i]) if err != nil { return err } - im.tblDefs[tblDef.Name] = tblDef + in.tblDefs[tblDef.Name] = tblDef - err = im.destGrip.SQLDriver().CreateTable(ctx, im.destDB, tblDef) + err = in.destGrip.SQLDriver().CreateTable(ctx, in.destDB, tblDef) if err != nil { return err } - im.log.Debug("Created table", lga.Target, source.Target(im.destGrip.Source(), tblDef.Name)) + in.log.Debug("Created table", lga.Target, source.Target(in.destGrip.Source(), tblDef.Name)) } return nil } // isRootSelector returns true if the current selector matches the root selector. -func (im *importer) isRootSelector() bool { - return im.selStack.selector() == im.def.Selector +func (in *ingester) isRootSelector() bool { + return in.selStack.selector() == in.def.Selector } // isRowSelector returns true if entity referred to by the current selector // maps to a table row (as opposed to a column). 
-func (im *importer) isRowSelector() bool { - return im.def.TableBySelector(im.selStack.selector()) != nil +func (in *ingester) isRowSelector() bool { + return in.def.TableBySelector(in.selStack.selector()) != nil } // msgOncef is used to prevent repeated logging of a message. The // method returns ok=true and the formatted string if the formatted // string has not been previous seen by msgOncef. -func (im *importer) msgOncef(format string, a ...any) (msg string, ok bool) { +func (in *ingester) msgOncef(format string, a ...any) (msg string, ok bool) { msg = fmt.Sprintf(format, a...) - if _, exists := im.msgOnce[msg]; exists { + if _, exists := in.msgOnce[msg]; exists { // msg already seen, return ok=false. return "", false } - im.msgOnce[msg] = struct{}{} + in.msgOnce[msg] = struct{}{} return msg, true } diff --git a/drivers/userdriver/xmlud/xmlimport_test.go b/drivers/userdriver/xmlud/xmlud_test.go similarity index 96% rename from drivers/userdriver/xmlud/xmlimport_test.go rename to drivers/userdriver/xmlud/xmlud_test.go index ee816f29f..0b641a935 100644 --- a/drivers/userdriver/xmlud/xmlimport_test.go +++ b/drivers/userdriver/xmlud/xmlud_test.go @@ -4,8 +4,6 @@ import ( "bytes" "testing" - "github.com/neilotoole/sq/testh/tu" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -16,6 +14,7 @@ import ( "github.com/neilotoole/sq/testh" "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/testsrc" + "github.com/neilotoole/sq/testh/tu" ) const ( @@ -46,7 +45,7 @@ func TestImport_Ppl(t *testing.T) { tu.OpenFileCount(t, true) data := proj.ReadFile("drivers/userdriver/xmlud/testdata/people.xml") - err = xmlud.Import(th.Context, udDef, bytes.NewReader(data), grip) + err = xmlud.Ingest(th.Context, udDef, bytes.NewReader(data), grip) require.NoError(t, err) tu.OpenFileCount(t, true) @@ -93,7 +92,7 @@ func TestImport_RSS(t *testing.T) { }) data := 
proj.ReadFile("drivers/userdriver/xmlud/testdata/nytimes_local.rss.xml") - err = xmlud.Import(th.Context, udDef, bytes.NewReader(data), scratchDB) + err = xmlud.Ingest(th.Context, udDef, bytes.NewReader(data), scratchDB) require.NoError(t, err) srcMeta, err := scratchDB.SourceMetadata(th.Context, false) diff --git a/drivers/xlsx/grip.go b/drivers/xlsx/grip.go index 074a1616e..0e5704cc7 100644 --- a/drivers/xlsx/grip.go +++ b/drivers/xlsx/grip.go @@ -15,9 +15,7 @@ import ( // grip implements driver.Grip. It implements a deferred ingest // of the Excel data. type grip struct { - // REVISIT: do we need grip.log, or can we use lg.FromContext? - log *slog.Logger - + log *slog.Logger src *source.Source files *source.Files dbGrip driver.Grip diff --git a/drivers/xlsx/ingest.go b/drivers/xlsx/ingest.go index 42c411959..778565274 100644 --- a/drivers/xlsx/ingest.go +++ b/drivers/xlsx/ingest.go @@ -188,7 +188,7 @@ func ingestSheetToTable(ctx context.Context, destGrip driver.Grip, sheetTbl *she batchSize := driver.MaxBatchRows(drvr, len(destColKinds)) bi, err := driver.NewBatchInsert( ctx, - "Ingest "+sheet.name, + fmt.Sprintf("Ingest {%s}", sheet.name), drvr, conn, tblDef.Name, diff --git a/drivers/xlsx/xlsx.go b/drivers/xlsx/xlsx.go index fb9a91da1..bf94adb41 100644 --- a/drivers/xlsx/xlsx.go +++ b/drivers/xlsx/xlsx.go @@ -57,7 +57,7 @@ func (d *Driver) DriverMetadata() driver.Metadata { } } -// Open implements driver.GripOpener. +// Open implements driver.Driver. 
func (d *Driver) Open(ctx context.Context, src *source.Source) (driver.Grip, error) { log := lg.FromContext(ctx).With(lga.Src, src) log.Debug(lgm.OpenSrc, lga.Src, src) diff --git a/drivers/xlsx/xlsx_test.go b/drivers/xlsx/xlsx_test.go index 765d81036..da416cec6 100644 --- a/drivers/xlsx/xlsx_test.go +++ b/drivers/xlsx/xlsx_test.go @@ -7,8 +7,6 @@ import ( "testing" "time" - "github.com/neilotoole/sq/testh/fixt" - "github.com/samber/lo" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" @@ -26,6 +24,7 @@ import ( "github.com/neilotoole/sq/libsq/driver" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/testh" + "github.com/neilotoole/sq/testh/fixt" "github.com/neilotoole/sq/testh/proj" "github.com/neilotoole/sq/testh/sakila" "github.com/neilotoole/sq/testh/tu" diff --git a/libsq/ast/handle.go b/libsq/ast/handle.go index 77e9744a8..682ad07ec 100644 --- a/libsq/ast/handle.go +++ b/libsq/ast/handle.go @@ -3,9 +3,10 @@ package ast import ( "slices" + "github.com/samber/lo" + "github.com/neilotoole/sq/libsq/ast/internal/slq" "github.com/neilotoole/sq/libsq/core/tablefq" - "github.com/samber/lo" ) // HandleNode models a source handle such as "@sakila". diff --git a/libsq/core/ioz/checksum/checksum.go b/libsq/core/ioz/checksum/checksum.go index 34be28855..80cb9c4c3 100644 --- a/libsq/core/ioz/checksum/checksum.go +++ b/libsq/core/ioz/checksum/checksum.go @@ -1,3 +1,6 @@ +// Package checksum provides functions for working with checksums. +// It uses crc32 for the checksum algorithm, resulting in checksum +// values like "3af3aaad". package checksum import ( @@ -12,9 +15,8 @@ import ( "strconv" "strings" - "github.com/neilotoole/sq/libsq/core/ioz" - "github.com/neilotoole/sq/libsq/core/errz" + "github.com/neilotoole/sq/libsq/core/ioz" ) // Sum returns the hash of b as a hex string. 
@@ -178,6 +180,8 @@ func ForHTTPHeader(u string, header http.Header) Checksum { // both compressed and uncompressed responses. // // Our hack for now it to trim the "-df" suffix from the Etag. +// +// REVISIT: ForHTTPResponse is no longer used. It should be removed. func ForHTTPResponse(resp *http.Response) Checksum { if resp == nil { return "" diff --git a/libsq/core/ioz/checksum/checksum_test.go b/libsq/core/ioz/checksum/checksum_test.go index c5a3fdfd5..8cb2f7278 100644 --- a/libsq/core/ioz/checksum/checksum_test.go +++ b/libsq/core/ioz/checksum/checksum_test.go @@ -9,7 +9,7 @@ import ( "github.com/neilotoole/sq/libsq/core/ioz/checksum" ) -func TestHash(t *testing.T) { +func TestSum(t *testing.T) { got := checksum.Sum(nil) require.Equal(t, "", got) got = checksum.Sum([]byte{}) diff --git a/libsq/core/ioz/contextio/contextio.go b/libsq/core/ioz/contextio/contextio.go index b5555e241..25d2e2ad9 100644 --- a/libsq/core/ioz/contextio/contextio.go +++ b/libsq/core/ioz/contextio/contextio.go @@ -16,6 +16,7 @@ limitations under the License. // This code is lifted from github.com/dolmen-go/contextio. +// Package contextio provides io decorators that are context-aware. package contextio import ( diff --git a/libsq/core/ioz/download/download.go b/libsq/core/ioz/download/download.go index d4b0244a4..256a64a6c 100644 --- a/libsq/core/ioz/download/download.go +++ b/libsq/core/ioz/download/download.go @@ -1,5 +1,5 @@ // Package download provides a mechanism for getting files from -// HTTP URLs, making use of a mostly RFC-compliant cache. +// HTTP/S URLs, making use of a mostly RFC-compliant cache. // // Acknowledgement: This package is a heavily customized fork // of https://github.com/gregjones/httpcache, via bitcomplete/download. @@ -404,7 +404,7 @@ func (dl *Download) state(req *http.Request) State { } // Filesize returns the size of the downloaded file. This should -// be invoked after the download has completed. 
+// only be invoked after the download has completed. func (dl *Download) Filesize(ctx context.Context) (int64, error) { dl.mu.Lock() defer dl.mu.Unlock() @@ -437,8 +437,8 @@ func (dl *Download) Filesize(ctx context.Context) (int64, error) { return fi.Size(), nil } -// CacheFile returns the path to the cached file, if it exists, -// and has been fully downloaded. +// CacheFile returns the path to the cached file, if it exists and has +// been fully downloaded. func (dl *Download) CacheFile(ctx context.Context) (fp string, err error) { dl.mu.Lock() defer dl.mu.Unlock() diff --git a/libsq/core/ioz/download/handler.go b/libsq/core/ioz/download/handler.go index 265e6df63..a31cabfb3 100644 --- a/libsq/core/ioz/download/handler.go +++ b/libsq/core/ioz/download/handler.go @@ -19,7 +19,7 @@ type Handler struct { // Uncached is invoked when the download is not cached. The handler should // return an ioz.WriteErrorCloser, which the download contents will be written // to (as well as being written to the disk cache). On success, the dest - // io.WriteCloser is closed. If an error occurs during download or writing, + // writer is closed. If an error occurs during download or writing, // WriteErrorCloser.Error is invoked (but Close is not invoked). If the // handler returns a nil dest, the Download will log a warning and return. Uncached func() (dest ioz.WriteErrorCloser) diff --git a/libsq/core/ioz/httpz/httpz.go b/libsq/core/ioz/httpz/httpz.go index e1502d3ab..e614565d1 100644 --- a/libsq/core/ioz/httpz/httpz.go +++ b/libsq/core/ioz/httpz/httpz.go @@ -90,7 +90,7 @@ func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) } -// ResponseLogValue implements slog.LogValuer for resp. +// ResponseLogValue implements slog.LogValuer for http.Response. func ResponseLogValue(resp *http.Response) slog.Value { if resp == nil { return slog.Value{} @@ -126,7 +126,7 @@ func ResponseLogValue(resp *http.Response) slog.Value { return slog.GroupValue(attrs...) 
} -// RequestLogValue implements slog.LogValuer for req. +// RequestLogValue implements slog.LogValuer for http.Request. func RequestLogValue(req *http.Request) slog.Value { if req == nil { return slog.Value{} diff --git a/libsq/core/ioz/httpz/opts.go b/libsq/core/ioz/httpz/opts.go index 53e85e820..f31b77549 100644 --- a/libsq/core/ioz/httpz/opts.go +++ b/libsq/core/ioz/httpz/opts.go @@ -59,20 +59,6 @@ func OptUserAgent(ua string) TripFunc { } } -// contextCause is a TripFunc that extracts the context.Cause error -// from the request context, if any, and returns it as the error. -func contextCause() TripFunc { - return func(next http.RoundTripper, req *http.Request) (*http.Response, error) { - resp, err := next.RoundTrip(req) - if err != nil { - if cause := context.Cause(req.Context()); cause != nil { - err = cause - } - } - return resp, err - } -} - // DefaultUserAgent is the default User-Agent header value, // as used by [NewDefaultClient]. var DefaultUserAgent = OptUserAgent(buildinfo.Get().UserAgent()) @@ -222,3 +208,17 @@ func OptRequestDelay(delay time.Duration) TripFunc { return next.RoundTrip(req) } } + +// contextCause returns a TripFunc that extracts the context.Cause error +// from the request context, if any, and returns it as the error. +func contextCause() TripFunc { + return func(next http.RoundTripper, req *http.Request) (*http.Response, error) { + resp, err := next.RoundTrip(req) + if err != nil { + if cause := context.Cause(req.Context()); cause != nil { + err = cause + } + } + return resp, err + } +} diff --git a/libsq/core/ioz/lockfile/lockfile.go b/libsq/core/ioz/lockfile/lockfile.go index 78921731f..8f5e14ab1 100644 --- a/libsq/core/ioz/lockfile/lockfile.go +++ b/libsq/core/ioz/lockfile/lockfile.go @@ -29,10 +29,9 @@ func New(fp string) (Lockfile, error) { return Lockfile(lf), nil } -// Lock attempts to acquire the lockfile, retrying if necessary, -// until the timeout expires. If timeout is zero, retry will not occur. 
-// On success, nil is returned. An error is returned if the lock cannot -// be acquired for any reason. +// Lock attempts to acquire the lock, retrying if necessary, until the timeout +// expires. If timeout is zero, retry will not occur. On success, nil is +// returned. An error is returned if the lock cannot be acquired for any reason. func (l Lockfile) Lock(ctx context.Context, timeout time.Duration) error { log := lg.FromContext(ctx).With(lga.Lock, l, lga.Timeout, timeout) diff --git a/libsq/core/loz/loz.go b/libsq/core/loz/loz.go index 3109d3738..6265c0f91 100644 --- a/libsq/core/loz/loz.go +++ b/libsq/core/loz/loz.go @@ -155,7 +155,7 @@ func ZeroIfNil[T comparable](t *T) T { } // Take returns true if ch is non-nil and a value is available -// from ch, or false otherwise. This is useful in for a succinct +// from ch, or false otherwise. This is useful for a succinct // "if done" idiom, e.g.: // // if someCondition && loz.Take(doneCh) { diff --git a/libsq/core/progress/progress.go b/libsq/core/progress/progress.go index 2775d9c1f..9f433410e 100644 --- a/libsq/core/progress/progress.go +++ b/libsq/core/progress/progress.go @@ -23,7 +23,6 @@ import ( "time" "github.com/samber/lo" - mpb "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" @@ -75,9 +74,8 @@ func FromContext(ctx context.Context) *Progress { type barCtxKey struct{} -// NewBarContext returns ctx with bar added as a value. This -// context can be used in conjunction with progress.Incr to increment -// the progress bar. +// NewBarContext returns ctx with bar added as a value. This context can +// be used in conjunction with progress.Incr to increment the progress bar. func NewBarContext(ctx context.Context, bar *Bar) context.Context { if ctx == nil { ctx = context.Background() @@ -91,9 +89,10 @@ func NewBarContext(ctx context.Context, bar *Bar) context.Context { // It safe to invoke Incr on a nil context or a context that doesn't // contain a Bar. 
// -// NOTE: This is a bit of an experiment. I'm a bit hesitant in going even -// further with context-based logic, as it's not clear to me that it's -// a good path to be on. So, it's possible this mechanism may be removed. +// NOTE: This context-based incrementing is a bit of an experiment. I'm +// a bit hesitant in going even further with context-based logic, as it's not +// clear to me that it's a good idea to lean on context so much. +// So, it's possible this mechanism may be removed in the future. func Incr(ctx context.Context, n int) { if ctx == nil { return @@ -131,7 +130,6 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors delay: delay, stoppedCh: make(chan struct{}), stopOnce: &sync.Once{}, - // refreshCh: make(chan any, 100), } // Note that p.ctx is not the same as the arg ctx. This is a bit of a hack @@ -153,29 +151,11 @@ func New(ctx context.Context, out io.Writer, delay time.Duration, colors *Colors opts := []mpb.ContainerOption{ mpb.WithOutput(out), mpb.WithWidth(boxWidth), - - // FIXME: switch back to auto refresh? - // mpb.WithRefreshRate(refreshRate), - // mpb.WithManualRefresh(p.refreshCh), mpb.WithAutoRefresh(), // Needed for color in Windows, apparently } p.pc = mpb.NewWithContext(ctx, opts...) p.pcInitFn = nil - // go func() { - // for { - // select { - // case <-p.stoppedCh: - // return - // case <-p.ctx.Done(): - // return - // case <-p.refreshCh: - // default: - // // p.refreshCh <- time.Now() - // time.Sleep(refreshRate) - // } - // } - // }() } // REVISIT: The delay init of pc is no longer required. We can just @@ -215,8 +195,6 @@ type Progress struct { stoppedCh chan struct{} stopOnce *sync.Once - // refreshCh chan any - ctx context.Context cancelFn context.CancelFunc @@ -292,10 +270,6 @@ func (p *Progress) doStop() { <-b.barStoppedCh // Wait for bar to stop } - // p.refreshCh <- time.Now() - - // close(p.stoppedCh) - // So, now we REALLY want to wait for the progress widget // to finish. 
Alas, the pc.Wait method doesn't seem to // always remove the bars from the terminal. So, we do @@ -308,12 +282,7 @@ func (p *Progress) doStop() { <-p.ctx.Done() // We shouldn't need this extra call to pc.Wait, // but it shouldn't hurt? - // time.Sleep(time.Millisecond) // FIXME: delete p.pc.Wait() - - // And a tiny sleep, which again, hopefully can be removed - // at some point. - // time.Sleep(time.Millisecond) // FIXME: delete close(p.stoppedCh) }) @@ -434,7 +403,10 @@ type Bar struct { // https://github.com/vbauerster/mpb/issues/136 // // Until that bug is fixed, the Bar is lazily initialized - // after the render delay expires. + // after the render delay expires. In fact, even when the + // bug is fixed, we may just stick with the lazy initialization + // mechanism, as it allows us to set the render delay on a + // per-bar basis, which is not possible with the mpb package. barInitOnce *sync.Once barInitFn func() @@ -473,6 +445,8 @@ func (b *Bar) Incr(n int) { } return default: + // The bar hasn't been initialized yet, so we stash + // the increment count for later use. b.incrStash.Add(int64(n)) } } @@ -506,10 +480,7 @@ func (b *Bar) doStop() { // We *probably* only need to call b.bar.Abort() here? b.bar.SetTotal(-1, true) b.bar.Abort(true) - // b.p.refreshCh <- time.Now() b.bar.Wait() - // b.p.refreshCh <- time.Now() - close(b.barStoppedCh) lg.FromContext(b.p.ctx).Debug("Stopped progress bar") }) diff --git a/libsq/core/progress/style.go b/libsq/core/progress/style.go index 5b48061eb..8ef5b1451 100644 --- a/libsq/core/progress/style.go +++ b/libsq/core/progress/style.go @@ -17,9 +17,6 @@ const ( refreshRate = 150 * time.Millisecond ) -// @download_16b8a3b1: http start ∙●∙ -// @download_16b8a3b1: download ∙ 14.4 MiB / 427.6 MiB 3.4 - // DefaultColors returns the default colors used for the progress bars. 
func DefaultColors() *Colors {
	return &Colors{
diff --git a/libsq/dbwriter.go b/libsq/dbwriter.go
index 3aea29b93..2f5dacc34 100644
--- a/libsq/dbwriter.go
+++ b/libsq/dbwriter.go
@@ -142,7 +142,7 @@ func (w *DBWriter) Open(ctx context.Context, cancelFn context.CancelFunc, recMet
 	defer func() {
 		// When the inserter goroutine finishes:
 		// - we close errCh (indicates that the DBWriter is done)
-		// - and mark wg as done, which the Stop method depends upon.
+		// - and mark wg as done, which the Wait method depends upon.
 		close(w.errCh)
 		w.wg.Done()
 	}()
@@ -164,7 +164,7 @@ func (w *DBWriter) Open(ctx context.Context, cancelFn context.CancelFunc, recMet
 
 	// Tell batch inserter that we're done sending records
 	close(w.bi.RecordCh)
-	err = <-w.bi.ErrCh // Stop for batch inserter to complete
+	err = <-w.bi.ErrCh // Wait for batch inserter to complete
 	if err != nil {
 		lg.FromContext(ctx).Error(err.Error())
 		w.addErrs(err)
diff --git a/libsq/driver/grip.go b/libsq/driver/grip.go
index f1c642f8a..e4908747a 100644
--- a/libsq/driver/grip.go
+++ b/libsq/driver/grip.go
@@ -16,17 +16,15 @@ import (
 // encapsulates a sql.DB instance. The realized sql.DB instance can be
 // accessed via the DB method.
 type Grip interface {
-	// DB returns the sql.DB object for this Grip.
-	// This operation may take a long time if opening the DB requires
-	// an ingest of data (but note that when an ingest step occurs is
-	// driver-dependent).
+	// DB returns the sql.DB object for this Grip. This operation may take
+	// some time to complete if opening the DB requires an ingest of data.
 	DB(ctx context.Context) (*sql.DB, error)
 
 	// SQLDriver returns the underlying database driver. The type of the SQLDriver
 	// may be different from the driver type reported by the Source.
 	SQLDriver() SQLDriver
 
-	// FIXME: Add a method: SourceDriver() Driver.
+	// TODO: Add a method: SourceDriver() Driver.
 
 	// Source returns the source for which this Grip was opened.
Source() *source.Source diff --git a/libsq/driver/grips.go b/libsq/driver/grips.go index 813fbc73a..0264198ed 100644 --- a/libsq/driver/grips.go +++ b/libsq/driver/grips.go @@ -9,16 +9,15 @@ import ( "sync" "time" - "github.com/neilotoole/sq/libsq/core/ioz/checksum" - "github.com/neilotoole/sq/libsq/core/stringz" - "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/checksum" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/source" "github.com/neilotoole/sq/libsq/source/drivertype" ) @@ -51,16 +50,12 @@ func NewGrips(drvrs Provider, files *source.Files, scratchSrcFn ScratchSrcFunc) } } -// Open returns an opened Grip for src. The returned Grip -// may be cached and returned on future invocations for the -// same source (where each source fields is identical). -// Thus, the caller should typically not close -// the Grip: it will be closed via d.Close. +// Open returns an opened Grip for src. The returned Grip may be cached and +// returned on future invocations for the identical source. Thus, the caller +// should typically not close the Grip: it will be closed via d.Close. // // NOTE: This entire logic re caching/not-closing is a bit sketchy, // and needs to be revisited. -// -// Open implements GripOpener. func (gs *Grips) Open(ctx context.Context, src *source.Source) (Grip, error) { lg.FromContext(ctx).Debug(lgm.OpenSrc, lga.Src, src) gs.mu.Lock() @@ -123,8 +118,8 @@ func (gs *Grips) doOpen(ctx context.Context, src *source.Source) (Grip, error) { } // OpenEphemeral returns an ephemeral scratch Grip instance. 
It is not -// necessary for the caller to close the returned Grip as -// its Close method will be invoked by Grips.Close. +// necessary for the caller to close the returned Grip as its Close method +// will be invoked by Grips.Close. func (gs *Grips) OpenEphemeral(ctx context.Context) (Grip, error) { const msgCloseDB = "Close ephemeral db" gs.mu.Lock() @@ -352,8 +347,7 @@ func (gs *Grips) openCachedGripFor(ctx context.Context, src *source.Source) (bac return backingGrip, true, nil } -// OpenJoin opens an appropriate Grip for use as -// a work DB for joining across sources. +// OpenJoin opens an appropriate Grip for use as a work DB for joining across sources. // // REVISIT: There is much work to be done on this method. Ultimately OpenJoin // should be able to inspect the join srcs and use heuristics to determine diff --git a/libsq/query_test.go b/libsq/query_test.go index 026a79e8a..2ef2b462f 100644 --- a/libsq/query_test.go +++ b/libsq/query_test.go @@ -180,11 +180,9 @@ func doExecQueryTestCase(t *testing.T, tc queryTestCase) { require.NoError(t, err) th := testh.New(t) - sources := th.Grips() - qc := &libsq.QueryContext{ Collection: coll, - Grips: sources, + Grips: th.Grips(), Args: tc.args, } diff --git a/libsq/source/cache.go b/libsq/source/cache.go index 6d2794357..a6de1160d 100644 --- a/libsq/source/cache.go +++ b/libsq/source/cache.go @@ -9,8 +9,6 @@ import ( "strings" "time" - "github.com/neilotoole/sq/libsq/core/progress" - "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" "github.com/neilotoole/sq/libsq/core/ioz/checksum" @@ -19,6 +17,7 @@ import ( "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" "github.com/neilotoole/sq/libsq/core/options" + "github.com/neilotoole/sq/libsq/core/progress" "github.com/neilotoole/sq/libsq/core/stringz" "github.com/neilotoole/sq/libsq/source/drivertype" ) diff --git a/libsq/source/files.go b/libsq/source/files.go index 
9e6ad1262..65f47f86a 100644 --- a/libsq/source/files.go +++ b/libsq/source/files.go @@ -11,8 +11,6 @@ import ( "sync" "time" - "github.com/neilotoole/sq/libsq/core/ioz/lockfile" - "github.com/neilotoole/fscache" "github.com/neilotoole/sq/libsq/core/cleanup" @@ -22,6 +20,7 @@ import ( "github.com/neilotoole/sq/libsq/core/ioz/contextio" "github.com/neilotoole/sq/libsq/core/ioz/download" "github.com/neilotoole/sq/libsq/core/ioz/httpz" + "github.com/neilotoole/sq/libsq/core/ioz/lockfile" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" @@ -199,7 +198,7 @@ func (fs *Files) AddStdin(ctx context.Context, f *os.File) error { return errz.Wrapf(err, "failed to add %s to fscache", StdinHandle) } -// addStdin synchronously copies f to fs's cache. f is closed +// addStdin asynchronously copies f to fs's cache. f is closed // when the async copy completes. This method should only be used // for stdin; for regular files, use Files.addRegularFile. 
func (fs *Files) addStdin(ctx context.Context, handle string, f *os.File) error { diff --git a/libsq/source/files_test.go b/libsq/source/files_test.go index 246242af1..375721026 100644 --- a/libsq/source/files_test.go +++ b/libsq/source/files_test.go @@ -7,13 +7,12 @@ import ( "path/filepath" "testing" - "github.com/neilotoole/sq/drivers/json" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" "github.com/neilotoole/sq/drivers/csv" + "github.com/neilotoole/sq/drivers/json" "github.com/neilotoole/sq/drivers/mysql" "github.com/neilotoole/sq/drivers/postgres" "github.com/neilotoole/sq/drivers/sqlite3" diff --git a/libsq/source/lock.go b/libsq/source/lock.go deleted file mode 100644 index 5ae34a42b..000000000 --- a/libsq/source/lock.go +++ /dev/null @@ -1,35 +0,0 @@ -package source - -import ( - "github.com/neilotoole/sq/libsq/core/errz" - "github.com/neilotoole/sq/libsq/core/ioz/lockfile" -) - -// NewLock returns a new source.Lock instance. -// -// REVISIT: We may not actually use source.Lock at all, and -// instead stick with ioz/lockfile.Lockfile. 
-func NewLock(src *Source, pidfile string) (Lock, error) { - lf, err := lockfile.New(pidfile) - if err != nil { - return Lock{}, errz.Err(err) - } - - return Lock{ - Lockfile: lf, - src: src, - }, nil -} - -type Lock struct { - lockfile.Lockfile - src *Source -} - -func (l Lock) Source() *Source { - return l.src -} - -func (l Lock) String() string { - return l.src.Handle + ": " + string(l.Lockfile) -} diff --git a/testh/testh.go b/testh/testh.go index 53afef8bf..096910236 100644 --- a/testh/testh.go +++ b/testh/testh.go @@ -13,8 +13,6 @@ import ( "sync" "testing" - "github.com/neilotoole/sq/libsq/core/ioz/lockfile" - "github.com/samber/lo" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -39,6 +37,7 @@ import ( "github.com/neilotoole/sq/libsq/core/cleanup" "github.com/neilotoole/sq/libsq/core/errz" "github.com/neilotoole/sq/libsq/core/ioz" + "github.com/neilotoole/sq/libsq/core/ioz/lockfile" "github.com/neilotoole/sq/libsq/core/lg" "github.com/neilotoole/sq/libsq/core/lg/lga" "github.com/neilotoole/sq/libsq/core/lg/lgm" @@ -192,13 +191,13 @@ func (h *Helper) init() { h.addUserDrivers() h.run = &run.Run{ - Stdin: os.Stdin, - Out: os.Stdout, - ErrOut: os.Stdin, - Config: cfg, - ConfigStore: config.DiscardStore{}, - - DriverRegistry: h.registry, + Stdin: os.Stdin, + Out: os.Stdout, + ErrOut: os.Stdin, + Config: cfg, + ConfigStore: config.DiscardStore{}, + OptionsRegistry: optRegistry, + DriverRegistry: h.registry, } }) } @@ -729,8 +728,8 @@ func (h *Helper) addUserDrivers() { userDriverDefs := DriverDefsFrom(h.T, testsrc.PathDriverDefPpl, testsrc.PathDriverDefRSS) // One day we may have more supported user driver genres. 
- userDriverImporters := map[string]userdriver.ImportFunc{ - xmlud.Genre: xmlud.Import, + userDriverImporters := map[string]userdriver.IngestFunc{ + xmlud.Genre: xmlud.Ingest, } for _, userDriverDef := range userDriverDefs { @@ -748,7 +747,7 @@ func (h *Helper) addUserDrivers() { udp := &userdriver.Provider{ Log: h.Log, DriverDef: userDriverDef, - ImportFn: importFn, + IngestFn: importFn, Ingester: h.grips, Files: h.files, } diff --git a/testh/tu/tu.go b/testh/tu/tu.go index ed99e5c35..59320a111 100644 --- a/testh/tu/tu.go +++ b/testh/tu/tu.go @@ -156,6 +156,9 @@ func SliceFieldKeyValues(keyFieldName, valFieldName string, slice any) map[any]a // // Note that this function uses reflection, and may panic. It is only // to be used by test code. +// +// REVISIT: This function predates generics. It can probably be +// removed, or at a minimum, moved to pkg loz. func AnySlice(slice any) []any { if slice == nil { return nil @@ -306,7 +309,7 @@ func (t *tWriter) Write(p []byte) (n int, err error) { } // Chdir changes the working directory to dir, or if dir is empty, -// to a temp dir. On test end, the original working dir is restored, +// to a temp dir. On test conclusion, the original working dir is restored, // and the temp dir deleted (if applicable). The absolute path // of the changed working dir is returned. func Chdir(t testing.TB, dir string) (absDir string) { From 32cdfa22675d738ab14606eb0d8130e1604c0b1c Mon Sep 17 00:00:00 2001 From: neilotoole Date: Sun, 14 Jan 2024 18:24:57 -0700 Subject: [PATCH 195/195] terminal_windows.go issue, again --- cli/terminal_windows.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cli/terminal_windows.go b/cli/terminal_windows.go index c3bafaf89..40b2573d8 100644 --- a/cli/terminal_windows.go +++ b/cli/terminal_windows.go @@ -4,6 +4,7 @@ import ( "io" "os" + "golang.org/x/sys/windows" "golang.org/x/term" )