Skip to content

Commit

Permalink
fix: fp16 GPU training (#3105)
Browse files Browse the repository at this point in the history
* cast ivf centroids to the same type as vectors before inference
* cast codebook to the same type as ivf centroid before passing to rust
* fix typing syntax
  • Loading branch information
chebbyChefNEQ authored Nov 8, 2024
1 parent 8155c25 commit 822cf82
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 5 deletions.
2 changes: 1 addition & 1 deletion python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1999,7 +1999,7 @@ def create_index(
)
values = pa.array(pq_codebook.reshape(-1))
pq_codebook = pa.FixedSizeListArray.from_arrays(
values, num_sub_vectors * 256
values, pq_codebook.shape[2]
)
pq_codebook_batch = pa.RecordBatch.from_arrays(
[pq_codebook], ["_pq_codebook"]
Expand Down
17 changes: 14 additions & 3 deletions python/python/lance/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
import re
import tempfile
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Tuple, Union

import pyarrow as pa
from tqdm.auto import tqdm
Expand Down Expand Up @@ -138,7 +138,7 @@ def train_pq_codebook_on_accelerator(
accelerator: Union[str, "torch.Device"],
num_sub_vectors: int,
batch_size: int = 1024 * 10 * 4,
) -> (np.ndarray, List[Any]):
) -> Tuple[np.ndarray, List[Any]]:
"""Use accelerator (GPU or MPS) to train pq codebook."""

from .torch.data import LanceDataset as TorchDataset
Expand Down Expand Up @@ -210,7 +210,7 @@ def train_ivf_centroids_on_accelerator(
sample_rate: int = 256,
max_iters: int = 50,
filter_nan: bool = True,
) -> (np.ndarray, Any):
) -> Tuple[np.ndarray, Any]:
"""Use accelerator (GPU or MPS) to train kmeans."""

from .cuvs.kmeans import KMeans as KMeansCuVS
Expand Down Expand Up @@ -622,6 +622,7 @@ def one_pass_train_ivf_pq_on_accelerator(
pq_codebook, kmeans_list = train_pq_codebook_on_accelerator(
dataset_residuals, metric_type, accelerator, num_sub_vectors, batch_size
)
pq_codebook = pq_codebook.astype(dtype=centroids.dtype)
return centroids, kmeans, pq_codebook, kmeans_list


Expand Down Expand Up @@ -691,6 +692,7 @@ def one_pass_assign_ivf_pq_on_accelerator(

def _partition_and_pq_codes_assignment() -> Iterable[pa.RecordBatch]:
with torch.no_grad():
first_iter = True
for batch in loader:
vecs = (
batch[column]
Expand All @@ -711,6 +713,15 @@ def _partition_and_pq_codes_assignment() -> Iterable[pa.RecordBatch]:
vecs = vecs[mask_gpu]

residual_vecs = vecs - ivf_kmeans.centroids[partitions]
# cast centroids to the same dtype as vecs
if first_iter:
first_iter = False
logging.info("Residual shape: %s", residual_vecs.shape)
for kmeans in pq_kmeans_list:
cents: torch.Tensor = kmeans.centroids
kmeans.centroids = cents.to(
dtype=vecs.dtype, device=ivf_kmeans.device
)
pq_codes = torch.stack(
[
pq_kmeans_list[i].transform(
Expand Down
18 changes: 17 additions & 1 deletion python/python/tests/test_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@
from lance.vector import vec_to_table # noqa: E402


def create_table(nvec=1000, ndim=128, nans=0, nullify=False):
def create_table(nvec=1000, ndim=128, nans=0, nullify=False, dtype=np.float32):
mat = np.random.randn(nvec, ndim)
if nans > 0:
nans_mat = np.empty((nans, ndim))
nans_mat[:] = np.nan
mat = np.concatenate((mat, nans_mat), axis=0)
mat = mat.astype(dtype)
price = np.random.rand(nvec + nans) * 100

def gen_str(n):
Expand Down Expand Up @@ -164,6 +165,21 @@ def test_invalid_subvectors_cuda(tmp_path):
)


@pytest.mark.cuda
def test_f16_cuda(tmp_path):
    """Index a float16 vector table with IVF_PQ on the CUDA accelerator.

    Writes a table created with ``dtype=np.float16`` to ``tmp_path``, builds
    the index through the one-pass IVF/PQ path, and hands the resulting
    dataset to ``validate_vector_index``.
    """
    fp16_table = create_table(dtype=np.float16)
    index_options = {
        "index_type": "IVF_PQ",
        "num_partitions": 4,
        "num_sub_vectors": 16,
        "accelerator": "cuda",
        "one_pass_ivfpq": True,
    }
    indexed = lance.write_dataset(fp16_table, tmp_path).create_index(
        "vector", **index_options
    )
    validate_vector_index(indexed, "vector")


def test_index_with_nans(tmp_path):
# 1024 rows, the entire table should be sampled
tbl = create_table(nvec=1000, nans=24)
Expand Down
3 changes: 3 additions & 0 deletions rust/lance-index/src/vector/ivf/shuffler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,9 @@ impl IvfShuffler {
location!(),
))?;

info!("Writing unsorted data to disk at {}", path);
info!("with schema: {:?}", schema);

let mut file_writer = FileWriter::<ManifestDescribing>::with_object_writer(
writer,
Schema::try_from(schema.as_ref())?,
Expand Down

0 comments on commit 822cf82

Please sign in to comment.