Skip to content

Commit

Permalink
make lazy wheel work against tensorflow-gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
cosmicexplorer committed Aug 7, 2023
1 parent 8c35424 commit 33f2431
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 24 deletions.
98 changes: 75 additions & 23 deletions src/pip/_internal/network/lazy_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

__all__ = ["HTTPRangeRequestUnsupported", "dist_from_wheel_url"]
__all__ = ["HTTPRangeRequestUnsupported", "dist_from_wheel_url", "LazyHTTPFile"]

import io
import logging
Expand All @@ -22,6 +22,7 @@
from pip._internal.metadata import BaseDistribution, MemoryWheel, get_wheel_distribution
from pip._internal.network.session import PipSession as Session
from pip._internal.network.utils import HEADERS
from pip._internal.utils.logging import indent_log

logger = logging.getLogger(__name__)

Expand All @@ -40,6 +41,11 @@ def dist_from_wheel_url(name: str, url: str, session: Session) -> BaseDistributi
"""
try:
with LazyHTTPFile(url, session) as lazy_file:
with indent_log():
logger.debug("begin prefetching for %s", name)
lazy_file.prefetch_contiguous_dist_info(name)
logger.debug("done prefetching for %s", name)

# For read-only ZIP files, ZipFile only needs methods read,
# seek, seekable and tell, not the whole IO protocol.
wheel = MemoryWheel(lazy_file.name, lazy_file)
Expand Down Expand Up @@ -145,6 +151,11 @@ def __next__(self) -> bytes:
raise NotImplementedError


# The central directory for tensorflow_gpu-2.5.3-cp38-cp38-manylinux2010_x86_64.whl is
# 944931 bytes, for a 459424488 byte file (the file is about 486x as large as its
# central directory), so a 1 MB initial fetch covers that whole directory.
_DEFAULT_INITIAL_FETCH = 1_000_000


class LazyHTTPFile(ReadOnlyIOWrapper):
"""File-like object mapped to a ZIP file over HTTP.
Expand All @@ -159,7 +170,10 @@ class LazyHTTPFile(ReadOnlyIOWrapper):
_domains_without_negative_range: ClassVar[set[str]] = set()

def __init__(
self, url: str, session: Session, initial_chunk_size: int = CONTENT_CHUNK_SIZE
self,
url: str,
session: Session,
initial_chunk_size: int = _DEFAULT_INITIAL_FETCH,
) -> None:
# Add delete=False and print the file's `.name` to debug invalid virtual zips.
super().__init__(cast(BinaryIO, NamedTemporaryFile()))
Expand All @@ -172,21 +186,20 @@ def __init__(

self._length, initial_chunk = self._extract_content_length(initial_chunk_size)
self.truncate(self._length)
# The central directory for
# tensorflow_gpu-2.5.3-cp38-cp38-manylinux2010_x86_64.whl is 944931 bytes, for
# a 459424488 byte file (about 486x as large).
self._minimum_fetch_granularity = max(initial_chunk_size, self._length // 400)
if initial_chunk is None:
# If we could not download any file contents yet (e.g. if negative byte
# ranges were not supported), then download all of this at once, hopefully
# pulling in the entire central directory.
initial_start = max(0, self._length - self._minimum_fetch_granularity)
initial_start = max(0, self._length - initial_chunk_size)
self._download(initial_start, self._length)
else:
self.seek(-len(initial_chunk), io.SEEK_END)
self._file.write(initial_chunk)
self._left.append(self._length - len(initial_chunk))
self._right.append(self._length - 1)
# If we could download file contents, then write them to the end of the
# file and set up our bisect boundaries by hand.
with self._stay():
self.seek(-len(initial_chunk), io.SEEK_END)
self._file.write(initial_chunk)
self._left.append(self._length - len(initial_chunk))
self._right.append(self._length - 1)

def read(self, size: int = -1) -> bytes:
"""Read up to size bytes from the object and return them.
Expand All @@ -195,17 +208,17 @@ def read(self, size: int = -1) -> bytes:
all bytes until EOF are returned. Fewer than
size bytes may be returned if EOF is reached.
"""
# BUG does not download correctly if size is unspecified
cur = self.tell()
logger.debug("read size %d at %d", size, cur)
if size < 0:
assert cur <= self._length
download_size = self._length - cur
elif size == 0:
return b''
return b""
else:
download_size = max(size, self._minimum_fetch_granularity)
download_size = size
stop = min(cur + download_size, self._length)
self._download(cur, stop - 1)
self._download(cur, stop)
return self._file.read(size)

def __enter__(self) -> LazyHTTPFile:
Expand All @@ -221,18 +234,20 @@ def _content_length_from_head(self) -> int:
head = self._session.head(self._url, headers=HEADERS)
head.raise_for_status()
assert head.status_code == codes.ok
return int(head.headers["content-length"])
return int(head.headers["Content-Length"])

@staticmethod
def _parse_full_length_from_content_range(arg: str) -> Optional[int]:
if m := re.match(r"bytes [^/]+/([0-9]+)", arg):
m = re.match(r"bytes [^/]+/([0-9]+)", arg)
if m is not None:
return int(m.group(1))
return None

def _try_initial_chunk_request(self, initial_chunk_size: int) -> tuple[int, bytes]:
headers = HEADERS.copy()
# Perform a negative range index, which is not supported by some servers.
headers["Range"] = f"bytes=-{initial_chunk_size}"
logger.debug("initial bytes request: %s", headers["Range"])
# TODO: Get range requests to be correctly cached
headers["Cache-Control"] = "no-cache"
# TODO: If-Match (etag) to detect file changed during fetch would be a
Expand All @@ -242,7 +257,7 @@ def _try_initial_chunk_request(self, initial_chunk_size: int) -> tuple[int, byte
tail = self._session.get(self._url, headers=headers)
tail.raise_for_status()

response_length = int(tail.headers["content-length"])
response_length = int(tail.headers["Content-Length"])
assert response_length == len(tail.content)

code = tail.status_code
Expand All @@ -255,12 +270,15 @@ def _try_initial_chunk_request(self, initial_chunk_size: int) -> tuple[int, byte
elif code != codes.partial_content:
raise HTTPRangeRequestUnsupported("did not receive partial content or ok")

range_arg = tail.headers["content-range"]
if file_length := self._parse_full_length_from_content_range(range_arg):
range_arg = tail.headers["Content-Range"]
file_length = self._parse_full_length_from_content_range(range_arg)
if file_length is not None:
return (file_length, tail.content)
raise HTTPRangeRequestUnsupported(f"could not parse content-range: {range_arg}")

def _extract_content_length(self, initial_chunk_size: int) -> tuple[int, Optional[bytes]]:
def _extract_content_length(
self, initial_chunk_size: int
) -> tuple[int, Optional[bytes]]:
domain = urlparse(self._url).netloc
if domain in self._domains_without_negative_range:
return (self._content_length_from_head(), None)
Expand All @@ -287,7 +305,7 @@ def _extract_content_length(self, initial_chunk_size: int) -> tuple[int, Optiona
if code == codes.requested_range_not_satisfiable:
# In this case, we don't have any file content yet, but we do know the
# size the file will be, so we can return that and exit here.
range_arg = resp.headers["content-range"]
range_arg = resp.headers["Content-Range"]
if length := self._parse_full_length_from_content_range(range_arg):
return (length, None)
raise HTTPRangeRequestUnsupported(
Expand Down Expand Up @@ -330,7 +348,7 @@ def _stream_response(self, start: int, end: int) -> Response:
# https://www.rfc-editor.org/rfc/rfc9110#field.content-range
headers = HEADERS.copy()
headers["Range"] = f"bytes={start}-{end}"
logger.debug("%s", headers["Range"])
logger.debug("streamed bytes request: %s", headers["Range"])
# TODO: Get range requests to be correctly cached
headers["Cache-Control"] = "no-cache"
# TODO: If-Match (etag) to detect file changed during fetch would be a
Expand Down Expand Up @@ -364,6 +382,8 @@ def _merge(

def _download(self, start: int, end: int) -> None:
"""Download bytes from start (inclusive) to end (exclusive)."""
# Reducing by 1 to get an inclusive end range.
end -= 1
with self._stay():
left = bisect_left(self._right, start)
right = bisect_right(self._left, end)
Expand All @@ -372,3 +392,35 @@ def _download(self, start: int, end: int) -> None:
self.seek(start)
for chunk in response.iter_content(CONTENT_CHUNK_SIZE):
self._file.write(chunk)

def prefetch_contiguous_dist_info(self, name: str) -> None:
    """
    Read contents of entire dist-info section of wheel.

    pip will read every entry in this directory when generating a dist from a wheel,
    so prepopulating the file contents avoids waiting for multiple range requests.

    The downloaded range runs from the local header of the first ``*.dist-info/``
    entry up to (but not including) the header of the next entry outside that
    directory, or up to the start of the central directory when no such entry
    follows.

    :param name: project name; only used to build the error message.
    :raises UnsupportedWheel: if no entry lives under a ``*.dist-info/`` directory.
    """
    dist_info_prefix = re.compile(r"^[^/]*\.dist-info/")
    start: Optional[int] = None
    end: Optional[int] = None

    # ZipFile does not close a file object that was passed in, so leaving the
    # ``with`` block releases the ZipFile without closing ``self``.
    with ZipFile(self) as zf:
        # NOTE(review): assumes the dist-info entries are stored contiguously and
        # that infolist() order follows on-disk order -- confirm for odd wheels.
        for info in zf.infolist():
            in_dist_info = dist_info_prefix.search(info.filename) is not None
            if start is None:
                if in_dist_info:
                    start = info.header_offset
            elif not in_dist_info:
                # First entry after the dist-info run: its header marks the end.
                end = info.header_offset
                break
        if start is None:
            # Spell out the directory suffix instead of interpolating the compiled
            # pattern, whose repr ("re.compile(...)") is meaningless to users.
            raise UnsupportedWheel(
                f"no .dist-info/ directory found for {name} in {self.name}"
            )
        # If the last entries of the zip are the .dist-info/ dir (as usual), then
        # give us everything until the start of the central directory.
        if end is None:
            end = zf.start_dir
    self._download(start, end)
5 changes: 4 additions & 1 deletion src/pip/_internal/utils/wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pip._vendor.packaging.utils import canonicalize_name

from pip._internal.exceptions import UnsupportedWheel
from pip._internal.network.lazy_wheel import LazyHTTPFile

VERSION_COMPATIBLE = (1, 0)

Expand Down Expand Up @@ -69,8 +70,10 @@ def wheel_dist_info_dir(source: ZipFile, name: str) -> str:


def read_wheel_metadata_file(source: ZipFile, path: str) -> bytes:
if isinstance(source.fp, LazyHTTPFile):
logger.debug("extracting entry '%s' from lazy zip '%s'", path, source.fp.name)

try:
logger.debug("extracting entry '%s' from zip '%s'", path, source.fp.name)
return source.read(path)
# BadZipFile for general corruption, KeyError for missing entry,
# and RuntimeError for password-protected files
Expand Down

0 comments on commit 33f2431

Please sign in to comment.