-
Notifications
You must be signed in to change notification settings - Fork 251
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Vision specific extension types (#146)
* image extension array and scalar * break apart lance.types * lint * add benchmark for vectorized IOU computation * box2darray with vectorized iou * fix benchmarks * address PR comments * minor refactor * test label array * lint * switch to fixed sized list array * fix box2darray * lint * minor fix
- Loading branch information
1 parent
73855d1
commit 0758195
Showing
16 changed files
with
917 additions
and
174 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
[bumpversion] | ||
current_version = 0.0.5 | ||
commit = True | ||
tag = True | ||
parse = (?P<major>\d+)\.(?P<minor>\d+)(\.(?P<patch>\d+))?(\.(?P<release>[a-z]+)(?P<build>\d+))? | ||
serialize = | ||
{major}.{minor}.{patch}.{release}{build} | ||
{major}.{minor}.{patch} | ||
message = "Bump version for release: {current_version} -> {new_version}" | ||
|
||
[bumpversion:part:release] | ||
first_value = dev | ||
optional_value = final | ||
values = | ||
dev | ||
final | ||
|
||
[bumpversion:file:./lance/version.py] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) 2022. Lance Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import numpy as np | ||
import pyarrow as pa | ||
from lance.types.box import Box2dArray, Box2dType | ||
|
||
|
||
def iou(is_vectorized: bool, num_boxes: int = 100): | ||
if is_vectorized: | ||
return iou_vectorized(num_boxes) | ||
return iou_naive(num_boxes) | ||
|
||
|
||
def iou_naive(num_boxes: int): | ||
xmin_arr = np.random.randn(num_boxes) + 1 | ||
ymin_arr = np.random.randn(num_boxes) + 1 | ||
xmax_arr = (np.random.randn(num_boxes) + 10) * 10 | ||
ymax_arr = (np.random.randn(num_boxes) + 10) * 10 | ||
ious = np.zeros((num_boxes, num_boxes)) | ||
for i in range(num_boxes): | ||
for j in range(num_boxes): | ||
xmin = max(xmin_arr[i], xmin_arr[j]) | ||
ymin = max(ymin_arr[i], ymin_arr[j]) | ||
xmax = min(xmax_arr[i], xmax_arr[j]) | ||
ymax = min(ymax_arr[i], ymax_arr[j]) | ||
# compute the area of intersection rectangle | ||
inter = max(0, xmax - xmin + 1) * max(0, ymax - ymin + 1) | ||
# compute the area of both the prediction and ground-truth | ||
# rectangles | ||
area_i = ((xmax_arr[i] - xmin_arr[i] + 1) * | ||
(ymax_arr[i] - ymin_arr[i] + 1)) | ||
area_j = ((xmax_arr[j] - xmin_arr[j] + 1) * | ||
(ymax_arr[j] - ymin_arr[j] + 1)) | ||
# compute the intersection over union by taking the intersection | ||
# area and dividing it by the sum of prediction + ground-truth | ||
# areas - the interesection area | ||
ious[i, j] = inter / float(area_i + area_j - inter) | ||
return ious | ||
|
||
|
||
def iou_vectorized(num_boxes: int): | ||
xmin_arr = np.random.randn(num_boxes) + 1 | ||
ymin_arr = np.random.randn(num_boxes) + 1 | ||
xmax_arr = (np.random.randn(num_boxes) + 10) * 10 | ||
ymax_arr = (np.random.randn(num_boxes) + 10) * 10 | ||
storage = pa.StructArray.from_arrays( | ||
[xmin_arr, ymin_arr, xmax_arr, ymax_arr], | ||
names=["xmin", "ymin", "xmax", "ymax"] | ||
) | ||
box_arr = Box2dArray.from_storage(Box2dType(), storage) | ||
return box_arr.iou(box_arr) | ||
|
||
|
||
if __name__ == "__main__": | ||
import time | ||
|
||
n_repeats = 10 | ||
results = {} | ||
for num_boxes in [10, 100, 1000, 10000]: | ||
for is_vectorized in [True, False]: | ||
repeats = [] | ||
for i in range(n_repeats): | ||
start = time.time_ns() | ||
iou(is_vectorized, num_boxes) | ||
end = time.time_ns() | ||
duration_ns = end - start | ||
repeats.append(duration_ns) | ||
results[(num_boxes, is_vectorized)] = np.mean(repeats) | ||
print(results) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Copyright (c) 2022. Lance Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""IO utilities""" | ||
import os | ||
import shutil | ||
from io import BytesIO | ||
from pathlib import Path | ||
from typing import IO, BinaryIO, Optional, Tuple, Union | ||
from urllib.parse import urlparse | ||
|
||
import requests | ||
from pyarrow import fs | ||
from requests.auth import AuthBase | ||
|
||
import lance | ||
from lance.logging import logger | ||
|
||
USER_AGENT = f"User-Agent: Lance/{lance.__version__} ([email protected])" | ||
|
||
|
||
def open_uri( | ||
uri: Union[str, Path], | ||
mode: str = "rb", | ||
http_auth: Optional[Union[AuthBase, Tuple[str, str]]] = None, | ||
http_headers: Optional[dict] = None, | ||
) -> IO: | ||
"""Open URI for reading. Supports the following URI formats: | ||
- File System: ``/path/to/file`` or ``file:///path/to/file`` | ||
- Http(s): ``http://`` or ``https://`` | ||
- AWS S3: ``s3://`` | ||
- Google Cloud Storage: ``gs://`` | ||
Parameters | ||
---------- | ||
uri : str or Path | ||
URI to open | ||
mode : str, default 'rb' | ||
the file mode | ||
http_auth : AuthBase or tuple of (str, str), optional | ||
Authentication details when using http(s) uri's | ||
http_headers : dict, optional | ||
Extra headers when using http(s) uri's | ||
Return | ||
------ | ||
IO | ||
""" | ||
if isinstance(uri, Path): | ||
return uri.open(mode=mode) | ||
parsed_uri = urlparse(uri) | ||
scheme = parsed_uri.scheme | ||
if not scheme or scheme == "file": | ||
# This is a local file | ||
return open(uri, mode=mode) | ||
elif scheme in ("http", "https"): | ||
headers = {} | ||
headers.update(http_headers or {}) | ||
if "User-Agent" not in headers: | ||
headers["User-Agent"] = "lance" | ||
resp = requests.get(uri, auth=http_auth, headers=headers) | ||
return BytesIO(resp.content) | ||
else: | ||
filesystem, path = fs.FileSystem.from_uri(uri) | ||
return filesystem.open_input_file(path) | ||
|
||
|
||
def copy(source: Union[str, Path], dest: Union[str, Path]) -> str: | ||
"""Copy a file from source to destination, and return the URI of | ||
the copied file. | ||
Parameters | ||
---------- | ||
source : str | ||
The source URI to copy from | ||
dest : str | ||
The destination uri or the destination directory. If ``dest`` is | ||
a URI ends with a "/", it represents a directory. | ||
Return | ||
------ | ||
str | ||
Return the URI of destination. | ||
""" | ||
parsed_source = urlparse(source) | ||
if dest and dest.endswith("/"): | ||
dest = os.path.join(dest, os.path.basename(parsed_source.path)) | ||
parsed_dest = urlparse(dest) | ||
logger.debug("Copying %s to %s", source, dest) | ||
|
||
if parsed_dest.scheme == parsed_source.scheme: | ||
# Direct copy with the same file system | ||
filesystem, source_path = fs.FileSystem.from_uri(str(source)) | ||
_, dest_path = fs.FileSystem.from_uri(str(dest)) | ||
filesystem.copy(source_path, dest_path) | ||
return dest | ||
|
||
source_fs, source_path = fs.FileSystem.from_uri(str(source)) | ||
dest_fs, dest_path = fs.FileSystem.from_uri(str(dest)) | ||
with source_fs.open_input_file(source_path) as in_stream: | ||
with dest_fs.open_output_stream(dest_path) as out_stream: | ||
shutil.copyfileobj(in_stream, out_stream) | ||
return dest |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright (c) 2022. Lance Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Logging utilities""" | ||
|
||
import logging | ||
|
||
_LOG_FORMAT = ( | ||
"%(asctime)s %(levelname)s %(name)s (%(filename)s:%(lineno)d): %(message)s" | ||
) | ||
|
||
__all__ = ["logger"] | ||
|
||
logger = None | ||
|
||
|
||
def _set_logger(level=logging.INFO): | ||
global logger | ||
if logger is not None: | ||
return logger | ||
logger = logging.getLogger("Lance") | ||
logger.setLevel(level) | ||
|
||
handler = logging.StreamHandler() | ||
handler.setFormatter(logging.Formatter(_LOG_FORMAT)) | ||
logger.handlers.clear() | ||
logger.addHandler(handler) | ||
|
||
logger.propagate = False | ||
|
||
|
||
_set_logger() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright (c) 2022. Lance Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import base64 | ||
from pathlib import Path | ||
|
||
import requests | ||
import requests_mock | ||
|
||
from lance.io import open_uri | ||
|
||
WIKIPEDIA = ( | ||
"https://upload.wikimedia.org/wikipedia/commons/a/ad/" | ||
"Commodore_Grace_M._Hopper%2C_USN_%28covered%29.jpg" | ||
) | ||
|
||
|
||
def test_open_https_uri(): | ||
with open_uri(WIKIPEDIA) as fobj: | ||
assert len(fobj.read()) > 0 | ||
|
||
|
||
def test_local(tmp_path: Path): | ||
with open_uri(WIKIPEDIA) as fobj: | ||
img_bytes = fobj.read() | ||
with open_uri(tmp_path / "wikipedia.jpg", mode="wb") as fobj: | ||
fobj.write(img_bytes) | ||
with open_uri(tmp_path / "wikipedia.jpg") as fobj: | ||
assert img_bytes == fobj.read() | ||
|
||
|
||
def test_simple_http_credentials(): | ||
with requests_mock.Mocker() as mock: | ||
mock.get("http://test.com", text="{}") | ||
requests.get("http://test.com", auth=("user", "def_not_pass")) | ||
req = mock.request_history[0] | ||
assert req.headers.get("Authorization") == "Basic {}".format( | ||
base64.b64encode(b"user:def_not_pass").decode("utf-8") | ||
) | ||
|
||
|
||
def test_no_http_credentials(): | ||
with requests_mock.Mocker() as mock: | ||
mock.get("http://test.com", text="{}") | ||
requests.get("http://test.com") | ||
req = mock.request_history[0] | ||
assert "Authorization" not in req.headers |
Oops, something went wrong.