diff --git a/auth/auth.go b/auth/auth.go index 24f4fb89..d60e20b7 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -41,6 +41,8 @@ type Provider interface { Validate() error // Test tests if credentials are valid. Test(ctx context.Context) error + // Client returns an authenticated HTTP client + HTTPClient() (*http.Client, error) } var ( @@ -114,9 +116,9 @@ func (s simpleProvider) Test(ctx context.Context) error { ctx, task := trace.NewTask(ctx, "TestAuth") defer task.End() - httpCl, err := chttp.New("https://slack.com", s.Cookies()) + httpCl, err := s.HTTPClient() if err != nil { - return err + return &Error{Err: err} } cl := slack.New(s.Token, slack.OptionHTTPClient(httpCl)) @@ -127,3 +129,7 @@ func (s simpleProvider) Test(ctx context.Context) error { } return nil } + +func (s simpleProvider) HTTPClient() (*http.Client, error) { + return chttp.New("https://slack.com", s.Cookies()) +} diff --git a/auth/auth_error.go b/auth/auth_error.go index d4a137f9..c713b7ef 100644 --- a/auth/auth_error.go +++ b/auth/auth_error.go @@ -9,7 +9,7 @@ type Error struct { } func (ae *Error) Error() string { - return fmt.Sprintf("failed to authenticate: %s", ae.Err) + return fmt.Sprintf("authentication error: %s", ae.Err) } func (ae *Error) Unwrap() error { diff --git a/export/export.go b/export/export.go index 4ad6237f..d0245f00 100644 --- a/export/export.go +++ b/export/export.go @@ -37,7 +37,7 @@ func New(sd *slackdump.Session, fs fsadapter.FS, cfg Config) *Export { if cfg.Logger == nil { cfg.Logger = logger.Default } - network.Logger = cfg.Logger + network.SetLogger(cfg.Logger) se := &Export{ fs: fs, diff --git a/internal/chunk/recorder_test.go b/internal/chunk/recorder_test.go index 59d8ef6e..674d4c69 100644 --- a/internal/chunk/recorder_test.go +++ b/internal/chunk/recorder_test.go @@ -135,7 +135,7 @@ func TestRecorder_worker(t *testing.T) { if time.Since(start) > 50*time.Millisecond { t.Errorf("worker took too long to exit") } - const want = "{\"t\":0,\"ts\":0,\"n\":0,\"id\":\"C123\",\"m\":[{\"text\":\"hello\",\"replace_original\":false,\"delete_original\":false,\"metadata\":{\"event_type\":\"\",\"event_payload\":null},\"blocks\":null}]}\n" + const want = "{\"t\":0,\"ts\":0,\"id\":\"C123\",\"n\":0,\"m\":[{\"text\":\"hello\",\"replace_original\":false,\"delete_original\":false,\"metadata\":{\"event_type\":\"\",\"event_payload\":null},\"blocks\":null}]}\n" if !assert.Equal(t, want, buf.String()) { t.Errorf("unexpected output: %s", buf.String()) diff --git a/internal/network/limiter.go b/internal/network/limiter.go index 98a47a50..27511b2c 100644 --- a/internal/network/limiter.go +++ b/internal/network/limiter.go @@ -14,12 +14,16 @@ const ( Tier2 Tier = 20 Tier3 Tier = 50 // Tier4 Tier = 100 + + // secPerMin is the number of seconds in a minute, it is here to allow easy + // modification of the program, should this value change. + secPerMin = 60.0 ) // NewLimiter returns throttler with rateLimit requests per minute. // optionally caller may specify the boost func NewLimiter(t Tier, burst uint, boost int) *rate.Limiter { - callsPerSec := float64(int(t)+boost) / 60.0 + callsPerSec := float64(int(t)+boost) / secPerMin l := rate.NewLimiter(rate.Limit(callsPerSec), int(burst)) return l } diff --git a/internal/network/network.go b/internal/network/network.go index c37d5a57..87abd614 100644 --- a/internal/network/network.go +++ b/internal/network/network.go @@ -2,13 +2,13 @@ package network import ( "context" + "errors" "fmt" "net/http" "runtime/trace" + "sync" "time" - "errors" - "github.com/slack-go/slack" "golang.org/x/time/rate" @@ -20,22 +20,28 @@ const ( defNumAttempts = 3 ) -// MaxAllowedWaitTime is the maximum time to wait for a transient error. The -// wait time for a transient error depends on the current retry attempt number -// and is calculated as: (attempt+2)^3 seconds, capped at MaxAllowedWaitTime. -var MaxAllowedWaitTime = 5 * time.Minute - -// Logger is the package logger. -var Logger logger.Interface = logger.Default +var ( + // maxAllowedWaitTime is the maximum time to wait for a transient error. + // The wait time for a transient error depends on the current retry + // attempt number and is calculated as: (attempt+2)^3 seconds, capped at + // maxAllowedWaitTime. + maxAllowedWaitTime = 5 * time.Minute + lg logger.Interface = logger.Default + // waitFn returns the amount of time to wait before retrying depending on + // the current attempt. This variable exists to reduce the test time. + waitFn = cubicWait + + mu sync.RWMutex +) // ErrRetryFailed is returned if number of retry attempts exceeded the retry attempts limit and // function wasn't able to complete without errors. -var ErrRetryFailed = errors.New("callback was not able to complete without errors within the allowed number of retries") +var ErrRetryFailed = errors.New("callback was unable to complete without errors within the allowed number of retries") // withRetry will run the callback function fn. If the function returns // slack.RateLimitedError, it will delay, and then call it again up to // maxAttempts times. It will return an error if it runs out of attempts. -func WithRetry(ctx context.Context, l *rate.Limiter, maxAttempts int, fn func() error) error { +func WithRetry(ctx context.Context, lim *rate.Limiter, maxAttempts int, fn func() error) error { var ok bool if maxAttempts == 0 { maxAttempts = defNumAttempts @@ -43,7 +49,7 @@ func WithRetry(ctx context.Context, l *rate.Limiter, maxAttempts int, fn func() for attempt := 0; attempt < maxAttempts; attempt++ { var err error trace.WithRegion(ctx, "withRetry.wait", func() { - err = l.Wait(ctx) + err = lim.Wait(ctx) }) if err != nil { return err @@ -67,7 +73,7 @@ func WithRetry(ctx context.Context, l *rate.Limiter, maxAttempts int, fn func() } else if errors.As(cbErr, &sce) { if sce.Code >= http.StatusInternalServerError && sce.Code <= 599 { // possibly transient error - delay := waitTime(attempt) + delay := waitFn(attempt) tracelogf(ctx, "info", "got server error %d, sleeping %s", sce.Code, delay) time.Sleep(delay) continue @@ -82,30 +88,41 @@ func WithRetry(ctx context.Context, l *rate.Limiter, maxAttempts int, fn func() return nil } -// waitTime returns the amount of time to wait before retrying depending on -// the current attempt. This variable exists to reduce the test time. -var waitTime = cubicWait - // cubicWait is the wait time function. Time is calculated as (x+2)^3 seconds, // where x is the current attempt number. The maximum wait time is capped at 5 // minutes. func cubicWait(attempt int) time.Duration { x := attempt + 2 // this is to ensure that we sleep at least 8 seconds. delay := time.Duration(x*x*x) * time.Second - if delay > MaxAllowedWaitTime { - return MaxAllowedWaitTime + if delay > maxAllowedWaitTime { + return maxAllowedWaitTime } return delay } func tracelogf(ctx context.Context, category string, fmt string, a ...any) { + mu.RLock() + defer mu.RUnlock() + trace.Logf(ctx, category, fmt, a...) - l().Debugf(fmt, a...) + lg.Debugf(fmt, a...) } -func l() logger.Interface { - if Logger == nil { - return logger.Default +// SetLogger sets the package logger. +func SetLogger(l logger.Interface) { + mu.Lock() + defer mu.Unlock() + if l == nil { + l = logger.Default + return } - return Logger + lg = l +} + +// SetMaxAllowedWaitTime sets the maximum time to wait for a transient error. +func SetMaxAllowedWaitTime(d time.Duration) { + mu.Lock() + defer mu.Unlock() + + maxAllowedWaitTime = d } diff --git a/internal/network/network_test.go b/internal/network/network_test.go index 6f3f047c..b5034b19 100644 --- a/internal/network/network_test.go +++ b/internal/network/network_test.go @@ -148,9 +148,9 @@ func Test_withRetry(t *testing.T) { } func Test500ErrorHandling(t *testing.T) { - waitTime = func(attempt int) time.Duration { return 50 * time.Millisecond } + waitFn = func(attempt int) time.Duration { return 50 * time.Millisecond } defer func() { - waitTime = cubicWait + waitFn = cubicWait }() var codes = []int{500, 502, 503, 504, 598} @@ -187,8 +187,8 @@ func Test500ErrorHandling(t *testing.T) { } dur := time.Since(start) - if dur < waitTime(testRetryCount-1)-waitThreshold || waitTime(testRetryCount-1)+waitThreshold < dur { - t.Errorf("expected duration to be around %s, got %s", waitTime(testRetryCount), dur) + if dur < waitFn(testRetryCount-1)-waitThreshold || waitFn(testRetryCount-1)+waitThreshold < dur { + t.Errorf("expected duration to be around %s, got %s", waitFn(testRetryCount), dur) } }) @@ -242,8 +242,8 @@ func Test_cubicWait(t *testing.T) { {"attempt 1", args{1}, 27 * time.Second}, {"attempt 2", args{2}, 64 * time.Second}, {"attempt 2", args{4}, 216 * time.Second}, - {"attempt 100", args{5}, MaxAllowedWaitTime}, // check if capped properly - {"attempt 100", args{1000}, MaxAllowedWaitTime}, + {"attempt 100", args{5}, maxAllowedWaitTime}, // check if capped properly + {"attempt 100", args{1000}, maxAllowedWaitTime}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/slackdump.go b/slackdump.go index da26688f..ff86d61e 100644 --- a/slackdump.go +++ b/slackdump.go @@ -12,8 +12,6 @@ import ( "github.com/slack-go/slack" "golang.org/x/time/rate" - "github.com/rusq/chttp" - "github.com/rusq/fsadapter" "github.com/rusq/slackdump/v2/auth" "github.com/rusq/slackdump/v2/internal/network" @@ -120,14 +118,13 @@ func New(ctx context.Context, prov auth.Provider, opts ...Option) (*Session, err return nil, fmt.Errorf("auth provider validation error: %s", err) } - httpCl, err := chttp.New("https://slack.com", prov.Cookies()) + httpCl, err := prov.HTTPClient() if err != nil { return nil, err } - cl := slack.New(prov.SlackToken(), slack.OptionHTTPClient(httpCl)) - authTestResp, err := cl.AuthTestContext(ctx) + authResp, err := cl.AuthTestContext(ctx) if err != nil { return nil, &auth.Error{Err: err} } @@ -137,12 +134,13 @@ func New(ctx context.Context, prov auth.Provider, opts ...Option) (*Session, err cfg: defConfig, uc: new(usercache), - wspInfo: authTestResp, + wspInfo: authResp, log: logger.Default, } for _, opt := range opts { opt(sd) } + network.SetLogger(sd.log) // set the logger for the network package if err := sd.cfg.limits.Validate(); err != nil { var vErr validator.ValidationErrors @@ -152,8 +150,6 @@ func New(ctx context.Context, prov auth.Provider, opts ...Option) (*Session, err return nil, err } - sd.propagateLogger() - return sd, nil } @@ -187,11 +183,6 @@ func withRetry(ctx context.Context, l *rate.Limiter, maxAttempts int, fn func() return network.WithRetry(ctx, l, maxAttempts, fn) } -// propagateLogger propagates the slackdump logger to some dumb packages. -func (s *Session) propagateLogger() { - network.Logger = s.log -} - // Info returns a workspace information. Slackdump retrieves workspace // information during the initialisation when performing authentication test, // so no API call is involved at this point.