Skip to content

Commit

Permalink
Override wrong python https proxy on Windows (#371)
Browse files Browse the repository at this point in the history
Summary:
Fixes #355

This is a patch that works around a bug in Python's `urllib` that affects TorchData. See: https://bugs.python.org/issue42627

Simply put, Python uses an `https` proxy for URLs with the `https` scheme, but the Windows platform expects an `http` proxy for these URLs.

I do not know how to write a test for this on Windows; I am open to suggestions.

Pull Request resolved: #371

Reviewed By: NivekT

Differential Revision: D35935252

Pulled By: ejguan

fbshipit-source-id: 3a23f1e60916942901a2f6aeb05f9d7a8014c5ed
  • Loading branch information
ejguan authored and facebook-github-bot committed Apr 26, 2022
1 parent 8d40227 commit 41e16d2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
4 changes: 1 addition & 3 deletions test/test_remote_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ def test_http_reader_iterdatapipe(self):
self.assertTrue(io.BufferedReader, type(stream))

# __len__ Test: returns the length of source DataPipe
source_dp = IterableWrapper([file_url])
http_dp = HttpReader(source_dp)
self.assertEqual(1, len(http_dp))
self.assertEqual(1, len(http_reader_dp))

def test_on_disk_cache_holder_iterdatapipe(self):
tar_file_url = "https://raw.githubusercontent.com/pytorch/data/main/test/_fakedata/csv.tar.gz"
Expand Down
28 changes: 23 additions & 5 deletions torchdata/datapipes/iter/load/online.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,42 @@
# LICENSE file in the root directory of this source tree.

import re
from typing import Iterator, Optional, Tuple
from urllib.parse import urlparse
import urllib

from typing import Dict, Iterator, Optional, Tuple

import requests

from requests.exceptions import HTTPError, RequestException

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.utils import StreamWrapper


# TODO: Remove this helper function when https://bugs.python.org/issue42627 is resolved
def _get_proxies() -> Optional[Dict[str, str]]:
import os

if os.name == "nt":
proxies = urllib.request.getproxies()
address = proxies.get("https")
# The default proxy type of Windows is HTTP
if address and address.startswith("https"):
address = "http" + address[5:]
proxies["https"] = address
return proxies
return None


def _get_response_from_http(url: str, *, timeout: Optional[float]) -> Tuple[str, StreamWrapper]:
try:
with requests.Session() as session:
proxies = _get_proxies()
if timeout is None:
r = session.get(url, stream=True)
r = session.get(url, stream=True, proxies=proxies)
else:
r = session.get(url, timeout=timeout, stream=True)
r = session.get(url, timeout=timeout, stream=True, proxies=proxies)
return url, StreamWrapper(r.raw)
except HTTPError as e:
raise Exception(f"Could not get the file. [HTTP Error] {e.response}.")
Expand Down Expand Up @@ -166,7 +184,7 @@ def __init__(self, source_datapipe: IterDataPipe[str], *, timeout: Optional[floa

def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
for url in self.source_datapipe:
parts = urlparse(url)
parts = urllib.parse.urlparse(url)

if re.match(r"(drive|docs)[.]google[.]com", parts.netloc):
yield _get_response_from_google_drive(url, timeout=self.timeout)
Expand Down

0 comments on commit 41e16d2

Please sign in to comment.