From 3484db65c7cacdf3cb1ad7e64936455924019f41 Mon Sep 17 00:00:00 2001 From: Carlana Johnson Date: Sun, 21 Jul 2024 11:11:07 -0400 Subject: [PATCH] Add ErrorTransport --- transport.go | 9 +++++++++ transport_example_test.go | 6 +++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/transport.go b/transport.go index 7b81f08..51dd2c0 100644 --- a/transport.go +++ b/transport.go @@ -115,3 +115,12 @@ func DoerTransport(cl interface { }) Transport { return RoundTripFunc(cl.Do) } + +// ErrorTransport always returns the specified error instead of connecting. +// It is intended for use in testing +// or to prevent accidental use of http.DefaultClient. +func ErrorTransport(err error) Transport { + return RoundTripFunc(func(req *http.Request) (*http.Response, error) { + return nil, err + }) +} diff --git a/transport_example_test.go b/transport_example_test.go index 737351b..256108f 100644 --- a/transport_example_test.go +++ b/transport_example_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "crypto/md5" + "errors" "fmt" "io" "net/http" @@ -111,9 +112,8 @@ func ExampleLogTransport() { fmt.Println("Error!", err) } // Works for bad responses too - baseTrans = requests.RoundTripFunc(func(req *http.Request) (*http.Response, error) { - return nil, fmt.Errorf("can't connect") - }) + baseTrans = requests.ErrorTransport(errors.New("can't connect")) + trans = requests.LogTransport(baseTrans, logger) if err := requests.