Skip to content

Commit

Permalink
fix franela#131: Concurrency issue with CheckRedirect of DefaultClient
Browse files Browse the repository at this point in the history
  • Loading branch information
vvelikodny authored and Bruno Barbosa committed Feb 1, 2019
1 parent 3dcc108 commit 25124ff
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 28 deletions.
52 changes: 28 additions & 24 deletions goreq.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type itimeout interface {
Timeout() bool
}
type Request struct {
Client *http.Client
headers []headerTuple
cookies []*http.Cookie
Method string
Expand Down Expand Up @@ -243,7 +244,6 @@ func prepareRequestBody(b interface{}) (io.Reader, error) {

var DefaultDialer = &net.Dialer{Timeout: 1000 * time.Millisecond}
var DefaultTransport http.RoundTripper = &http.Transport{Dial: DefaultDialer.Dial, Proxy: http.ProxyFromEnvironment}
var DefaultClient = &http.Client{Transport: DefaultTransport}

var proxyTransport http.RoundTripper
var proxyClient *http.Client
Expand Down Expand Up @@ -286,22 +286,28 @@ func (r Request) WithProxyConnectHeader(name string, value string) Request {
}

func (r Request) Do() (*Response, error) {
var client = DefaultClient
var transport = DefaultTransport

if r.Client == nil {
// use a client with a cookie jar if necessary. We create a new client not
// to modify the default one.
if r.CookieJar != nil {
r.Client = &http.Client{
Transport: transport,
Jar: r.CookieJar,
}
} else {
r.Client = &http.Client{
Transport: transport,
}
}
}

var resUri string
var redirectFailed bool

r.Method = valueOrDefault(r.Method, "GET")

// use a client with a cookie jar if necessary. We create a new client not
// to modify the default one.
if r.CookieJar != nil {
client = &http.Client{
Transport: transport,
Jar: r.CookieJar,
}
}

if r.Proxy != "" {
proxyUrl, err := url.Parse(r.Proxy)
if err != nil {
Expand All @@ -317,26 +323,25 @@ func (r Request) Do() (*Response, error) {
}

//If jar is specified new client needs to be built
if proxyTransport == nil || client.Jar != nil {
if proxyTransport == nil || r.Client.Jar != nil {
proxyTransport = &http.Transport{
Dial: DefaultDialer.Dial,
Proxy: http.ProxyURL(proxyUrl),
ProxyConnectHeader: proxyHeader,
}
proxyClient = &http.Client{Transport: proxyTransport, Jar: client.Jar}
proxyClient = &http.Client{Transport: proxyTransport, Jar: r.Client.Jar}
} else if proxyTransport, ok := proxyTransport.(*http.Transport); ok {
proxyTransport.Proxy = http.ProxyURL(proxyUrl)
proxyTransport.ProxyConnectHeader = proxyHeader
}
transport = proxyTransport
client = proxyClient
r.Client = proxyClient
}

client.CheckRedirect = func(req *http.Request, via []*http.Request) error {

r.Client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) > r.MaxRedirects {
redirectFailed = true
return errors.New("Error redirecting. MaxRedirects reached")
return errors.New("Error redirecting. MaxRedirects reached.")
}

resUri = req.URL.String()
Expand Down Expand Up @@ -367,7 +372,7 @@ func (r Request) Do() (*Response, error) {

timeout := false
if r.Timeout > 0 {
client.Timeout = r.Timeout
r.Client.Timeout = r.Timeout
}

if r.ShowDebug {
Expand All @@ -381,8 +386,7 @@ func (r Request) Do() (*Response, error) {
if r.OnBeforeRequest != nil {
r.OnBeforeRequest(&r, req)
}
res, err := client.Do(req)

res, err := r.Client.Do(req)
if err != nil {
if !timeout {
if t, ok := err.(itimeout); ok {
Expand All @@ -403,11 +407,11 @@ func (r Request) Do() (*Response, error) {
} else {
response = &Response{res, resUri, nil, req}
}
}

//If redirect fails and we haven't set a redirect count we shouldn't return an error
if redirectFailed && r.MaxRedirects == 0 {
return response, nil
//If redirect fails and we haven't set a redirect count we shouldn't return an error
if r.MaxRedirects == 0 {
return response, nil
}
}

return response, &Error{timeout: timeout, Err: err}
Expand Down
9 changes: 5 additions & 4 deletions goreq_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -886,7 +886,7 @@ func TestRequest(t *testing.T) {
Expect(res.StatusCode).Should(Equal(200))
})

g.It("Should not create a body by defualt", func() {
g.It("Should not create a body by default", func() {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
b, _ := ioutil.ReadAll(r.Body)
Expect(b).Should(HaveLen(0))
Expand All @@ -904,13 +904,13 @@ func TestRequest(t *testing.T) {
defer ts.Close()

req := Request{
Client: &http.Client{Transport: DefaultTransport},
Insecure: true,
Uri: ts.URL,
Host: "foobar.com",
}
res, _ := req.Do()

Expect(DefaultClient.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify).Should(Equal(true))
Expect(req.Client.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify).Should(Equal(true))
Expect(res.StatusCode).Should(Equal(200))
})
g.It("Should work if a different transport is specified", func() {
Expand All @@ -922,13 +922,14 @@ func TestRequest(t *testing.T) {
DefaultTransport = &http.Transport{Dial: DefaultDialer.Dial}

req := Request{
Client: &http.Client{Transport: DefaultTransport},
Insecure: true,
Uri: ts.URL,
Host: "foobar.com",
}
res, _ := req.Do()

Expect(DefaultClient.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify).Should(Equal(true))
Expect(req.Client.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify).Should(Equal(true))
Expect(res.StatusCode).Should(Equal(200))

DefaultTransport = currentTransport
Expand Down

0 comments on commit 25124ff

Please sign in to comment.