diff --git a/goreq.go b/goreq.go index 7a949e2..825736d 100644 --- a/goreq.go +++ b/goreq.go @@ -25,6 +25,7 @@ type itimeout interface { Timeout() bool } type Request struct { + Client *http.Client headers []headerTuple cookies []*http.Cookie Method string @@ -243,7 +244,6 @@ func prepareRequestBody(b interface{}) (io.Reader, error) { var DefaultDialer = &net.Dialer{Timeout: 1000 * time.Millisecond} var DefaultTransport http.RoundTripper = &http.Transport{Dial: DefaultDialer.Dial, Proxy: http.ProxyFromEnvironment} -var DefaultClient = &http.Client{Transport: DefaultTransport} var proxyTransport http.RoundTripper var proxyClient *http.Client @@ -286,22 +286,28 @@ func (r Request) WithProxyConnectHeader(name string, value string) Request { } func (r Request) Do() (*Response, error) { - var client = DefaultClient var transport = DefaultTransport + + if r.Client == nil { + // use a client with a cookie jar if necessary. We create a new client not + // to modify the default one. + if r.CookieJar != nil { + r.Client = &http.Client{ + Transport: transport, + Jar: r.CookieJar, + } + } else { + r.Client = &http.Client{ + Transport: transport, + } + } + } + var resUri string var redirectFailed bool r.Method = valueOrDefault(r.Method, "GET") - // use a client with a cookie jar if necessary. We create a new client not - // to modify the default one. - if r.CookieJar != nil { - client = &http.Client{ - Transport: transport, - Jar: r.CookieJar, - } - } - if r.Proxy != "" { proxyUrl, err := url.Parse(r.Proxy) if err != nil { @@ -317,26 +323,25 @@ func (r Request) Do() (*Response, error) { } //If jar is specified new client needs to be built - if proxyTransport == nil || client.Jar != nil { + if proxyTransport == nil || r.Client.Jar != nil { proxyTransport = &http.Transport{ Dial: DefaultDialer.Dial, Proxy: http.ProxyURL(proxyUrl), ProxyConnectHeader: proxyHeader, } - proxyClient = &http.Client{Transport: proxyTransport, Jar: client.Jar} + proxyClient = &http.Client{Transport: proxyTransport, Jar: r.Client.Jar} } else if proxyTransport, ok := proxyTransport.(*http.Transport); ok { proxyTransport.Proxy = http.ProxyURL(proxyUrl) proxyTransport.ProxyConnectHeader = proxyHeader } transport = proxyTransport - client = proxyClient + r.Client = proxyClient } - client.CheckRedirect = func(req *http.Request, via []*http.Request) error { - + r.Client.CheckRedirect = func(req *http.Request, via []*http.Request) error { if len(via) > r.MaxRedirects { redirectFailed = true - return errors.New("Error redirecting. MaxRedirects reached") + return errors.New("Error redirecting. MaxRedirects reached.") } resUri = req.URL.String() @@ -367,7 +372,7 @@ func (r Request) Do() (*Response, error) { timeout := false if r.Timeout > 0 { - client.Timeout = r.Timeout + r.Client.Timeout = r.Timeout } if r.ShowDebug { @@ -381,8 +386,7 @@ func (r Request) Do() (*Response, error) { if r.OnBeforeRequest != nil { r.OnBeforeRequest(&r, req) } - res, err := client.Do(req) - + res, err := r.Client.Do(req) if err != nil { if !timeout { if t, ok := err.(itimeout); ok { @@ -403,11 +407,11 @@ func (r Request) Do() (*Response, error) { } else { response = &Response{res, resUri, nil, req} } - } - //If redirect fails and we haven't set a redirect count we shouldn't return an error - if redirectFailed && r.MaxRedirects == 0 { - return response, nil + //If redirect fails and we haven't set a redirect count we shouldn't return an error + if r.MaxRedirects == 0 { + return response, nil + } } return response, &Error{timeout: timeout, Err: err} diff --git a/goreq_test.go b/goreq_test.go index 3b15657..c6475ae 100644 --- a/goreq_test.go +++ b/goreq_test.go @@ -886,7 +886,7 @@ func TestRequest(t *testing.T) { Expect(res.StatusCode).Should(Equal(200)) }) - g.It("Should not create a body by defualt", func() { + g.It("Should not create a body by default", func() { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { b, _ := ioutil.ReadAll(r.Body) Expect(b).Should(HaveLen(0)) @@ -904,13 +904,13 @@ func TestRequest(t *testing.T) { defer ts.Close() req := Request{ + Client: &http.Client{Transport: DefaultTransport}, Insecure: true, Uri: ts.URL, Host: "foobar.com", } res, _ := req.Do() - - Expect(DefaultClient.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify).Should(Equal(true)) + Expect(req.Client.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify).Should(Equal(true)) Expect(res.StatusCode).Should(Equal(200)) }) g.It("Should work if a different transport is specified", func() { @@ -922,13 +922,14 @@ func TestRequest(t *testing.T) { DefaultTransport = &http.Transport{Dial: DefaultDialer.Dial} req := Request{ + Client: &http.Client{Transport: DefaultTransport}, Insecure: true, Uri: ts.URL, Host: "foobar.com", } res, _ := req.Do() - Expect(DefaultClient.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify).Should(Equal(true)) + Expect(req.Client.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify).Should(Equal(true)) Expect(res.StatusCode).Should(Equal(200)) DefaultTransport = currentTransport