Skip to content

Commit

Permalink
some cosmetic and usability changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rusq committed Mar 26, 2023
1 parent 6f02abb commit 204024c
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 49 deletions.
10 changes: 8 additions & 2 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ type Provider interface {
Validate() error
// Test tests if credentials are valid.
Test(ctx context.Context) error
// Client returns an authenticated HTTP client
HTTPClient() (*http.Client, error)
}

var (
Expand Down Expand Up @@ -114,9 +116,9 @@ func (s simpleProvider) Test(ctx context.Context) error {
ctx, task := trace.NewTask(ctx, "TestAuth")
defer task.End()

httpCl, err := chttp.New("https://slack.com", s.Cookies())
httpCl, err := s.HTTPClient()
if err != nil {
return err
return &Error{Err: err}
}
cl := slack.New(s.Token, slack.OptionHTTPClient(httpCl))

Expand All @@ -127,3 +129,7 @@ func (s simpleProvider) Test(ctx context.Context) error {
}
return nil
}

func (s simpleProvider) HTTPClient() (*http.Client, error) {
return chttp.New("https://slack.com", s.Cookies())
}
2 changes: 1 addition & 1 deletion auth/auth_error.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ type Error struct {
}

func (ae *Error) Error() string {
return fmt.Sprintf("failed to authenticate: %s", ae.Err)
return fmt.Sprintf("authentication error: %s", ae.Err)
}

func (ae *Error) Unwrap() error {
Expand Down
2 changes: 1 addition & 1 deletion export/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func New(sd *slackdump.Session, fs fsadapter.FS, cfg Config) *Export {
if cfg.Logger == nil {
cfg.Logger = logger.Default
}
network.Logger = cfg.Logger
network.SetLogger(cfg.Logger)

se := &Export{
fs: fs,
Expand Down
2 changes: 1 addition & 1 deletion internal/chunk/recorder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func TestRecorder_worker(t *testing.T) {
if time.Since(start) > 50*time.Millisecond {
t.Errorf("worker took too long to exit")
}
const want = "{\"t\":0,\"ts\":0,\"n\":0,\"id\":\"C123\",\"m\":[{\"text\":\"hello\",\"replace_original\":false,\"delete_original\":false,\"metadata\":{\"event_type\":\"\",\"event_payload\":null},\"blocks\":null}]}\n"
const want = "{\"t\":0,\"ts\":0,\"id\":\"C123\",\"n\":0,\"m\":[{\"text\":\"hello\",\"replace_original\":false,\"delete_original\":false,\"metadata\":{\"event_type\":\"\",\"event_payload\":null},\"blocks\":null}]}\n"

if !assert.Equal(t, want, buf.String()) {
t.Errorf("unexpected output: %s", buf.String())
Expand Down
6 changes: 5 additions & 1 deletion internal/network/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@ const (
Tier2 Tier = 20
Tier3 Tier = 50
// Tier4 Tier = 100

// secPerMin is the number of seconds in a minute, it is here to allow easy
// modification of the program, should this value change.
secPerMin = 60.0
)

// NewLimiter returns throttler with rateLimit requests per minute.
// optionally caller may specify the boost
func NewLimiter(t Tier, burst uint, boost int) *rate.Limiter {
callsPerSec := float64(int(t)+boost) / 60.0
callsPerSec := float64(int(t)+boost) / secPerMin
l := rate.NewLimiter(rate.Limit(callsPerSec), int(burst))
return l
}
65 changes: 41 additions & 24 deletions internal/network/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ package network

import (
"context"
"errors"
"fmt"
"net/http"
"runtime/trace"
"sync"
"time"

"errors"

"github.com/slack-go/slack"
"golang.org/x/time/rate"

Expand All @@ -20,30 +20,36 @@ const (
defNumAttempts = 3
)

// MaxAllowedWaitTime is the maximum time to wait for a transient error. The
// wait time for a transient error depends on the current retry attempt number
// and is calculated as: (attempt+2)^3 seconds, capped at MaxAllowedWaitTime.
var MaxAllowedWaitTime = 5 * time.Minute

// Logger is the package logger.
var Logger logger.Interface = logger.Default
var (
// maxAllowedWaitTime is the maximum time to wait for a transient error.
// The wait time for a transient error depends on the current retry
// attempt number and is calculated as: (attempt+2)^3 seconds, capped at
// maxAllowedWaitTime.
maxAllowedWaitTime = 5 * time.Minute
lg logger.Interface = logger.Default
// waitFn returns the amount of time to wait before retrying depending on
// the current attempt. This variable exists to reduce the test time.
waitFn = cubicWait

mu sync.RWMutex
)

// ErrRetryFailed is returned if number of retry attempts exceeded the retry attempts limit and
// function wasn't able to complete without errors.
var ErrRetryFailed = errors.New("callback was not able to complete without errors within the allowed number of retries")
var ErrRetryFailed = errors.New("callback was unable to complete without errors within the allowed number of retries")

// withRetry will run the callback function fn. If the function returns
// slack.RateLimitedError, it will delay, and then call it again up to
// maxAttempts times. It will return an error if it runs out of attempts.
func WithRetry(ctx context.Context, l *rate.Limiter, maxAttempts int, fn func() error) error {
func WithRetry(ctx context.Context, lim *rate.Limiter, maxAttempts int, fn func() error) error {
var ok bool
if maxAttempts == 0 {
maxAttempts = defNumAttempts
}
for attempt := 0; attempt < maxAttempts; attempt++ {
var err error
trace.WithRegion(ctx, "withRetry.wait", func() {
err = l.Wait(ctx)
err = lim.Wait(ctx)
})
if err != nil {
return err
Expand All @@ -67,7 +73,7 @@ func WithRetry(ctx context.Context, l *rate.Limiter, maxAttempts int, fn func()
} else if errors.As(cbErr, &sce) {
if sce.Code >= http.StatusInternalServerError && sce.Code <= 599 {
// possibly transient error
delay := waitTime(attempt)
delay := waitFn(attempt)
tracelogf(ctx, "info", "got server error %d, sleeping %s", sce.Code, delay)
time.Sleep(delay)
continue
Expand All @@ -82,30 +88,41 @@ func WithRetry(ctx context.Context, l *rate.Limiter, maxAttempts int, fn func()
return nil
}

// waitTime returns the amount of time to wait before retrying depending on
// the current attempt. This variable exists to reduce the test time.
var waitTime = cubicWait

// cubicWait is the wait time function. Time is calculated as (x+2)^3 seconds,
// where x is the current attempt number. The maximum wait time is capped at 5
// minutes.
func cubicWait(attempt int) time.Duration {
x := attempt + 2 // this is to ensure that we sleep at least 8 seconds.
delay := time.Duration(x*x*x) * time.Second
if delay > MaxAllowedWaitTime {
return MaxAllowedWaitTime
if delay > maxAllowedWaitTime {
return maxAllowedWaitTime
}
return delay
}

func tracelogf(ctx context.Context, category string, fmt string, a ...any) {
mu.RLock()
defer mu.RUnlock()

trace.Logf(ctx, category, fmt, a...)
l().Debugf(fmt, a...)
lg.Debugf(fmt, a...)
}

func l() logger.Interface {
if Logger == nil {
return logger.Default
// SetLogger sets the package logger.
func SetLogger(l logger.Interface) {
mu.Lock()
defer mu.Unlock()
if l == nil {
l = logger.Default
return
}
return Logger
lg = l
}

// SetMaxAllowedWaitTime sets the maximum time to wait for a transient error.
func SetMaxAllowedWaitTime(d time.Duration) {
mu.Lock()
defer mu.Unlock()

maxAllowedWaitTime = d
}
12 changes: 6 additions & 6 deletions internal/network/network_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ func Test_withRetry(t *testing.T) {
}

func Test500ErrorHandling(t *testing.T) {
waitTime = func(attempt int) time.Duration { return 50 * time.Millisecond }
waitFn = func(attempt int) time.Duration { return 50 * time.Millisecond }
defer func() {
waitTime = cubicWait
waitFn = cubicWait
}()

var codes = []int{500, 502, 503, 504, 598}
Expand Down Expand Up @@ -187,8 +187,8 @@ func Test500ErrorHandling(t *testing.T) {
}

dur := time.Since(start)
if dur < waitTime(testRetryCount-1)-waitThreshold || waitTime(testRetryCount-1)+waitThreshold < dur {
t.Errorf("expected duration to be around %s, got %s", waitTime(testRetryCount), dur)
if dur < waitFn(testRetryCount-1)-waitThreshold || waitFn(testRetryCount-1)+waitThreshold < dur {
t.Errorf("expected duration to be around %s, got %s", waitFn(testRetryCount), dur)
}

})
Expand Down Expand Up @@ -242,8 +242,8 @@ func Test_cubicWait(t *testing.T) {
{"attempt 1", args{1}, 27 * time.Second},
{"attempt 2", args{2}, 64 * time.Second},
{"attempt 2", args{4}, 216 * time.Second},
{"attempt 100", args{5}, MaxAllowedWaitTime}, // check if capped properly
{"attempt 100", args{1000}, MaxAllowedWaitTime},
{"attempt 100", args{5}, maxAllowedWaitTime}, // check if capped properly
{"attempt 100", args{1000}, maxAllowedWaitTime},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
17 changes: 4 additions & 13 deletions slackdump.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ import (
"github.com/slack-go/slack"
"golang.org/x/time/rate"

"github.com/rusq/chttp"

"github.com/rusq/fsadapter"
"github.com/rusq/slackdump/v2/auth"
"github.com/rusq/slackdump/v2/internal/network"
Expand Down Expand Up @@ -120,14 +118,13 @@ func New(ctx context.Context, prov auth.Provider, opts ...Option) (*Session, err
return nil, fmt.Errorf("auth provider validation error: %s", err)
}

httpCl, err := chttp.New("https://slack.com", prov.Cookies())
httpCl, err := prov.HTTPClient()
if err != nil {
return nil, err
}

cl := slack.New(prov.SlackToken(), slack.OptionHTTPClient(httpCl))

authTestResp, err := cl.AuthTestContext(ctx)
authResp, err := cl.AuthTestContext(ctx)
if err != nil {
return nil, &auth.Error{Err: err}
}
Expand All @@ -137,12 +134,13 @@ func New(ctx context.Context, prov auth.Provider, opts ...Option) (*Session, err
cfg: defConfig,
uc: new(usercache),

wspInfo: authTestResp,
wspInfo: authResp,
log: logger.Default,
}
for _, opt := range opts {
opt(sd)
}
network.SetLogger(sd.log) // set the logger for the network package

if err := sd.cfg.limits.Validate(); err != nil {
var vErr validator.ValidationErrors
Expand All @@ -152,8 +150,6 @@ func New(ctx context.Context, prov auth.Provider, opts ...Option) (*Session, err
return nil, err
}

sd.propagateLogger()

return sd, nil
}

Expand Down Expand Up @@ -187,11 +183,6 @@ func withRetry(ctx context.Context, l *rate.Limiter, maxAttempts int, fn func()
return network.WithRetry(ctx, l, maxAttempts, fn)
}

// propagateLogger propagates the slackdump logger to some dumb packages.
func (s *Session) propagateLogger() {
network.Logger = s.log
}

// Info returns a workspace information. Slackdump retrieves workspace
// information during the initialisation when performing authentication test,
// so no API call is involved at this point.
Expand Down

0 comments on commit 204024c

Please sign in to comment.