diff --git a/.github/workflows/pr_local_integration_tests.yml b/.github/workflows/pr_local_integration_tests.yml index 2825b96f48..d3488cd08c 100644 --- a/.github/workflows/pr_local_integration_tests.yml +++ b/.github/workflows/pr_local_integration_tests.yml @@ -50,7 +50,7 @@ jobs: uses: actions/cache@v4 with: path: ${{ steps.uv-cache.outputs.dir }} - key: ${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-uv-${{ hashFiles(format('**/py{0}-ci-requirements.txt', env.PYTHON)) }} + key: ${{ runner.os }}-${{ matrix.python-version }}-uv-${{ hashFiles(format('**/py{0}-ci-requirements.txt', matrix.python-version)) }} - name: Install dependencies run: make install-python-dependencies-ci - name: Test local integration tests diff --git a/Makefile b/Makefile index de2ee568b6..bef7437bc8 100644 --- a/Makefile +++ b/Makefile @@ -268,7 +268,7 @@ test-python-universal-postgres-online: not test_snowflake" \ sdk/python/tests - test-python-universal-mysql-online: +test-python-universal-mysql-online: PYTHONPATH='.' \ FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.mysql_online_store.mysql_repo_configuration \ PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.mysql \ @@ -292,7 +292,11 @@ test-python-universal-cassandra: FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.cassandra_online_store.cassandra_repo_configuration \ PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.cassandra \ python -m pytest -x --integration \ - sdk/python/tests + sdk/python/tests/integration/offline_store/test_feature_logging.py \ + --ignore=sdk/python/tests/integration/offline_store/test_validation.py \ + -k "not test_snowflake and \ + not test_spark_materialization_consistency and \ + not test_universal_materialization" test-python-universal-hazelcast: PYTHONPATH='.' 
\ @@ -330,7 +334,7 @@ test-python-universal-cassandra-no-cloud-providers: not test_snowflake" \ sdk/python/tests - test-python-universal-elasticsearch-online: +test-python-universal-elasticsearch-online: PYTHONPATH='.' \ FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.elasticsearch_online_store.elasticsearch_repo_configuration \ PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.elasticsearch \ @@ -349,6 +353,14 @@ test-python-universal-cassandra-no-cloud-providers: not test_snowflake" \ sdk/python/tests +test-python-universal-milvus-online: + PYTHONPATH='.' \ + FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.milvus_online_store.milvus_repo_configuration \ + PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.milvus \ + python -m pytest -n 8 --integration \ + -k "test_retrieve_online_milvus_documents" \ + sdk/python/tests --ignore=sdk/python/tests/integration/offline_store/test_dqm_validation.py + test-python-universal-singlestore-online: PYTHONPATH='.' \ FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.singlestore_repo_configuration \ diff --git a/docs/how-to-guides/customizing-feast/adding-support-for-a-new-online-store.md b/docs/how-to-guides/customizing-feast/adding-support-for-a-new-online-store.md index 5e26f133ce..ee75aa6b74 100644 --- a/docs/how-to-guides/customizing-feast/adding-support-for-a-new-online-store.md +++ b/docs/how-to-guides/customizing-feast/adding-support-for-a-new-online-store.md @@ -25,7 +25,7 @@ OnlineStore class names must end with the OnlineStore suffix! ### Contrib online stores -New online stores go in `sdk/python/feast/infra/online_stores/contrib/`. +New online stores go in `sdk/python/feast/infra/online_stores/`. #### What is a contrib plugin?
diff --git a/sdk/python/docs/source/feast.infra.online_stores.milvus_online_store.rst b/sdk/python/docs/source/feast.infra.online_stores.milvus_online_store.rst index ee9faa55dc..5ae3015bf3 100644 --- a/sdk/python/docs/source/feast.infra.online_stores.milvus_online_store.rst +++ b/sdk/python/docs/source/feast.infra.online_stores.milvus_online_store.rst @@ -4,6 +4,14 @@ feast.infra.online\_stores.milvus\_online\_store package Submodules ---------- +feast.infra.online\_stores.milvus\_online\_store.milvus module +-------------------------------------------------------------- + +.. automodule:: feast.infra.online_stores.milvus_online_store.milvus + :members: + :undoc-members: + :show-inheritance: + feast.infra.online\_stores.milvus\_online\_store.milvus\_repo\_configuration module ----------------------------------------------------------------------------------- diff --git a/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py new file mode 100644 index 0000000000..a1a4a3a5fe --- /dev/null +++ b/sdk/python/feast/infra/online_stores/milvus_online_store/milvus.py @@ -0,0 +1,428 @@ +from datetime import datetime +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union + +from pydantic import StrictStr +from pymilvus import ( + Collection, + CollectionSchema, + DataType, + FieldSchema, + connections, +) +from pymilvus.orm.connections import Connections + +from feast import Entity +from feast.feature_view import FeatureView +from feast.infra.infra_object import InfraObject +from feast.infra.key_encoding_utils import ( + serialize_entity_key, +) +from feast.infra.online_stores.online_store import OnlineStore +from feast.infra.online_stores.vector_store import VectorStoreConfig +from feast.protos.feast.core.InfraObject_pb2 import InfraObject as InfraObjectProto +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto +from 
feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto +from feast.protos.feast.types.Value_pb2 import Value as ValueProto +from feast.repo_config import FeastConfigBaseModel, RepoConfig +from feast.type_map import PROTO_VALUE_TO_VALUE_TYPE_MAP +from feast.types import ( + VALUE_TYPES_TO_FEAST_TYPES, + Array, + ComplexFeastType, + PrimitiveFeastType, + ValueType, +) +from feast.utils import ( + _build_retrieve_online_document_record, + _serialize_vector_to_float_list, + to_naive_utc, +) + +PROTO_TO_MILVUS_TYPE_MAPPING: Dict[ValueType, DataType] = { + PROTO_VALUE_TO_VALUE_TYPE_MAP["bytes_val"]: DataType.VARCHAR, + PROTO_VALUE_TO_VALUE_TYPE_MAP["bool_val"]: DataType.BOOL, + PROTO_VALUE_TO_VALUE_TYPE_MAP["string_val"]: DataType.VARCHAR, + PROTO_VALUE_TO_VALUE_TYPE_MAP["float_val"]: DataType.FLOAT, + PROTO_VALUE_TO_VALUE_TYPE_MAP["double_val"]: DataType.DOUBLE, + PROTO_VALUE_TO_VALUE_TYPE_MAP["int32_val"]: DataType.INT32, + PROTO_VALUE_TO_VALUE_TYPE_MAP["int64_val"]: DataType.INT64, + PROTO_VALUE_TO_VALUE_TYPE_MAP["float_list_val"]: DataType.FLOAT_VECTOR, + PROTO_VALUE_TO_VALUE_TYPE_MAP["int32_list_val"]: DataType.FLOAT_VECTOR, + PROTO_VALUE_TO_VALUE_TYPE_MAP["int64_list_val"]: DataType.FLOAT_VECTOR, + PROTO_VALUE_TO_VALUE_TYPE_MAP["double_list_val"]: DataType.FLOAT_VECTOR, + PROTO_VALUE_TO_VALUE_TYPE_MAP["bool_list_val"]: DataType.BINARY_VECTOR, +} + +FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING: Dict[ + Union[PrimitiveFeastType, Array, ComplexFeastType], DataType +] = {} + +for value_type, feast_type in VALUE_TYPES_TO_FEAST_TYPES.items(): + if isinstance(feast_type, PrimitiveFeastType): + milvus_type = PROTO_TO_MILVUS_TYPE_MAPPING.get(value_type) + if milvus_type: + FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = milvus_type + elif isinstance(feast_type, Array): + base_type = feast_type.base_type + base_value_type = base_type.to_value_type() + if base_value_type in [ + ValueType.INT32, + ValueType.INT64, + ValueType.FLOAT, + ValueType.DOUBLE, + ]: + 
FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.FLOAT_VECTOR + elif base_value_type == ValueType.STRING: + FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.VARCHAR + elif base_value_type == ValueType.BOOL: + FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING[feast_type] = DataType.BINARY_VECTOR + + +class MilvusOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig): + """ + Configuration for the Milvus online store. + NOTE: The class *must* end with the `OnlineStoreConfig` suffix. + """ + + type: Literal["milvus"] = "milvus" + + host: Optional[StrictStr] = "localhost" + port: Optional[int] = 19530 + index_type: Optional[str] = "IVF_FLAT" + metric_type: Optional[str] = "L2" + embedding_dim: Optional[int] = 128 + vector_enabled: Optional[bool] = True + nlist: Optional[int] = 128 + + +class MilvusOnlineStore(OnlineStore): + """ + Milvus implementation of the online store interface. + + Attributes: + _collections: Dictionary to cache Milvus collections. + """ + + _conn: Optional[Connections] = None + _collections: Dict[str, Collection] = {} + + def _connect(self, config: RepoConfig) -> connections: + if not self._conn: + if not connections.has_connection("feast"): + self._conn = connections.connect( + alias="feast", + host=config.online_store.host, + port=str(config.online_store.port), + ) + return self._conn + + def _get_collection(self, config: RepoConfig, table: FeatureView) -> Collection: + collection_name = _table_id(config.project, table) + if collection_name not in self._collections: + self._connect(config) + + # Create a composite key by combining entity fields + composite_key_name = ( + "_".join([field.name for field in table.entity_columns]) + "_pk" + ) + + fields = [ + FieldSchema( + name=composite_key_name, + dtype=DataType.VARCHAR, + max_length=512, + is_primary=True, + ), + FieldSchema(name="event_ts", dtype=DataType.INT64), + FieldSchema(name="created_ts", dtype=DataType.INT64), + ] + fields_to_exclude = [ + "event_ts", + "created_ts", 
+ ] + fields_to_add = [f for f in table.schema if f.name not in fields_to_exclude] + for field in fields_to_add: + dtype = FEAST_PRIMITIVE_TO_MILVUS_TYPE_MAPPING.get(field.dtype) + if dtype: + if dtype == DataType.FLOAT_VECTOR: + fields.append( + FieldSchema( + name=field.name, + dtype=dtype, + dim=config.online_store.embedding_dim, + ) + ) + elif dtype == DataType.VARCHAR: + fields.append( + FieldSchema( + name=field.name, + dtype=dtype, + max_length=512, + ) + ) + else: + fields.append(FieldSchema(name=field.name, dtype=dtype)) + + schema = CollectionSchema( + fields=fields, description="Feast feature view data" + ) + collection = Collection(name=collection_name, schema=schema, using="feast") + if not collection.has_index(): + index_params = { + "index_type": config.online_store.index_type, + "metric_type": config.online_store.metric_type, + "params": {"nlist": config.online_store.nlist}, + } + for vector_field in schema.fields: + if vector_field.dtype in [ + DataType.FLOAT_VECTOR, + DataType.BINARY_VECTOR, + ]: + collection.create_index( + field_name=vector_field.name, index_params=index_params + ) + collection.load() + self._collections[collection_name] = collection + return self._collections[collection_name] + + def online_write_batch( + self, + config: RepoConfig, + table: FeatureView, + data: List[ + Tuple[ + EntityKeyProto, + Dict[str, ValueProto], + datetime, + Optional[datetime], + ] + ], + progress: Optional[Callable[[int], Any]], + ) -> None: + collection = self._get_collection(config, table) + entity_batch_to_insert = [] + for entity_key, values_dict, timestamp, created_ts in data: + # need to construct the composite primary key also need to handle the fact that entities are a list + entity_key_str = serialize_entity_key( + entity_key, + entity_key_serialization_version=config.entity_key_serialization_version, + ).hex() + composite_key_name = ( + "_".join([str(value) for value in entity_key.join_keys]) + "_pk" + ) + timestamp_int = 
int(to_naive_utc(timestamp).timestamp() * 1e6) + created_ts_int = ( + int(to_naive_utc(created_ts).timestamp() * 1e6) if created_ts else 0 + ) + values_dict = _extract_proto_values_to_dict(values_dict) + entity_dict = _extract_proto_values_to_dict( + dict(zip(entity_key.join_keys, entity_key.entity_values)) + ) + values_dict.update(entity_dict) + + single_entity_record = { + composite_key_name: entity_key_str, + "event_ts": timestamp_int, + "created_ts": created_ts_int, + } + single_entity_record.update(values_dict) + entity_batch_to_insert.append(single_entity_record) + + if progress: + progress(1) + + collection.insert(entity_batch_to_insert) + collection.flush() + + def online_read( + self, + config: RepoConfig, + table: FeatureView, + entity_keys: List[EntityKeyProto], + requested_features: Optional[List[str]] = None, + ) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: + raise NotImplementedError + + def update( + self, + config: RepoConfig, + tables_to_delete: Sequence[FeatureView], + tables_to_keep: Sequence[FeatureView], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + partial: bool, + ): + """ + Ensure a collection exists for every table to keep and drop the collections of the tables to delete. + """ + # Local import: ``utility`` is not among this module's top-level pymilvus imports. + from pymilvus import utility + + self._connect(config) + for table in tables_to_keep: + self._get_collection(config, table) + for table in tables_to_delete: + collection_name = _table_id(config.project, table) + # Check existence through ``utility`` on the "feast" alias opened by + # ``_connect``; ``Collection`` offers no ``exists()`` and would otherwise + # look the collection up on the "default" alias. + if utility.has_collection(collection_name, using="feast"): + Collection(name=collection_name, using="feast").drop() + self._collections.pop(collection_name, None) + + def plan( + self, config: RepoConfig, desired_registry_proto: RegistryProto + ) -> List[InfraObject]: + raise NotImplementedError + + def teardown( + self, + config: RepoConfig, + tables: Sequence[FeatureView], + entities: Sequence[Entity], + ): + self._connect(config) + for table in tables: + collection = self._get_collection(config, table) + if collection: + collection.drop() + self._collections.pop(collection.name, None) + + def retrieve_online_documents( + self, + config: RepoConfig, + table: 
FeatureView, + requested_feature: Optional[str], + requested_features: Optional[List[str]], + embedding: List[float], + top_k: int, + distance_metric: Optional[str] = None, + ) -> List[ + Tuple[ + Optional[datetime], + Optional[EntityKeyProto], + Optional[ValueProto], + Optional[ValueProto], + Optional[ValueProto], + ] + ]: + """ + Run a vector search for ``embedding`` against the Milvus collection backing ``table``. + + Raises: + ValueError: if vector search is disabled in the online store config, or if a + requested output field is not part of the collection schema. + """ + collection = self._get_collection(config, table) + if not config.online_store.vector_enabled: + raise ValueError("Vector search is not enabled in the online store config") + + search_params = { + "metric_type": distance_metric or config.online_store.metric_type, + "params": {"nprobe": 10}, + } + + composite_key_name = ( + "_".join([str(field.name) for field in table.entity_columns]) + "_pk" + ) + + output_fields = ( + [composite_key_name] + + (requested_features if requested_features else []) + + ["created_ts", "event_ts"] + ) + # Every output field must exist in the collection schema; collect the + # missing ones first so the error can name them. + schema_field_names = {f.name for f in collection.schema.fields} + missing_fields = [ + field for field in output_fields if field not in schema_field_names + ] + if missing_fields: + raise ValueError( + f"field(s) {missing_fields} not found in collection schema" + ) + + # Note we choose the first vector field as the field to search on. Not ideal but it's something.
+ ann_search_field = None + for field in collection.schema.fields: + if ( + field.dtype in [DataType.FLOAT_VECTOR, DataType.BINARY_VECTOR] + and field.name in output_fields + ): + ann_search_field = field.name + break + + results = collection.search( + data=[embedding], + anns_field=ann_search_field, + param=search_params, + limit=top_k, + output_fields=output_fields, + consistency_level="Strong", + ) + + result_list = [] + for hits in results: + for hit in hits: + single_record = {} + for field in output_fields: + single_record[field] = hit.entity.get(field) + + entity_key_bytes = bytes.fromhex(hit.entity.get(composite_key_name)) + embedding = hit.entity.get(ann_search_field) + serialized_embedding = _serialize_vector_to_float_list(embedding) + distance = hit.distance + event_ts = datetime.fromtimestamp(hit.entity.get("event_ts") / 1e6) + prepared_result = _build_retrieve_online_document_record( + entity_key_bytes, + # This may have a bug + serialized_embedding.SerializeToString(), + embedding, + distance, + event_ts, + config.entity_key_serialization_version, + ) + result_list.append(prepared_result) + return result_list + + +def _table_id(project: str, table: FeatureView) -> str: + return f"{project}_{table.name}" + + +def _extract_proto_values_to_dict(input_dict: Dict[str, Any]) -> Dict[str, Any]: + numeric_vector_list_types = [ + k + for k in PROTO_VALUE_TO_VALUE_TYPE_MAP.keys() + if k is not None and "list" in k and "string" not in k + ] + output_dict = {} + for feature_name, feature_values in input_dict.items(): + for proto_val_type in PROTO_VALUE_TO_VALUE_TYPE_MAP: + if feature_values.HasField(proto_val_type): + if proto_val_type in numeric_vector_list_types: + vector_values = getattr(feature_values, proto_val_type).val + else: + vector_values = getattr(feature_values, proto_val_type) + output_dict[feature_name] = vector_values + return output_dict + + +class MilvusTable(InfraObject): + """ + A Milvus collection managed by Feast. 
+ + Attributes: + host: The host of the Milvus server. + port: The port of the Milvus server. + name: The name of the collection. + """ + + host: str + port: int + + def __init__(self, host: str, port: int, name: str): + super().__init__(name) + self.host = host + self.port = port + self._connect() + + def _connect(self): + return connections.connect(alias="default", host=self.host, port=str(self.port)) + + def to_infra_object_proto(self) -> InfraObjectProto: + # Implement serialization if needed + raise NotImplementedError + + def update(self): + # Implement update logic if needed + raise NotImplementedError + + def teardown(self): + """Drop the collection named ``self.name`` if it exists on the server.""" + # Local import: ``utility`` is not among this module's top-level pymilvus imports. + from pymilvus import utility + + # ``Collection`` offers no ``exists()``; ask the server before dropping. + if utility.has_collection(self.name): + Collection(name=self.name).drop() diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index fe34a12adf..2b8d5174e1 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -81,6 +81,7 @@ "singlestore": "feast.infra.online_stores.singlestore_online_store.singlestore.SingleStoreOnlineStore", "qdrant": "feast.infra.online_stores.cqdrant.QdrantOnlineStore", "couchbase": "feast.infra.online_stores.couchbase_online_store.couchbase.CouchbaseOnlineStore", + "milvus": "feast.infra.online_stores.milvus_online_store.milvus.MilvusOnlineStore", **LEGACY_ONLINE_STORE_CLASS_FOR_TYPE, } diff --git a/sdk/python/feast/type_map.py b/sdk/python/feast/type_map.py index 8a88c24ffc..000e9cdae4 100644 --- a/sdk/python/feast/type_map.py +++ b/sdk/python/feast/type_map.py @@ -523,6 +523,24 @@ def python_values_to_proto_values( return proto_values +PROTO_VALUE_TO_VALUE_TYPE_MAP: Dict[str, ValueType] = { + "int32_val": ValueType.INT32, + "int64_val": ValueType.INT64, + "double_val": ValueType.DOUBLE, + "float_val": ValueType.FLOAT, + "string_val": ValueType.STRING, + "bytes_val": ValueType.BYTES, + "bool_val": ValueType.BOOL, + "int32_list_val": ValueType.INT32_LIST, + "int64_list_val": ValueType.INT64_LIST, + "double_list_val": ValueType.DOUBLE_LIST, + 
"float_list_val": ValueType.FLOAT_LIST, + "string_list_val": ValueType.STRING_LIST, + "bytes_list_val": ValueType.BYTES_LIST, + "bool_list_val": ValueType.BOOL_LIST, +} + + def _proto_value_to_value_type(proto_value: ProtoValue) -> ValueType: """ Returns Feast ValueType given Feast ValueType string. @@ -534,25 +552,9 @@ def _proto_value_to_value_type(proto_value: ProtoValue) -> ValueType: A variant of ValueType. """ proto_str = proto_value.WhichOneof("val") - type_map = { - "int32_val": ValueType.INT32, - "int64_val": ValueType.INT64, - "double_val": ValueType.DOUBLE, - "float_val": ValueType.FLOAT, - "string_val": ValueType.STRING, - "bytes_val": ValueType.BYTES, - "bool_val": ValueType.BOOL, - "int32_list_val": ValueType.INT32_LIST, - "int64_list_val": ValueType.INT64_LIST, - "double_list_val": ValueType.DOUBLE_LIST, - "float_list_val": ValueType.FLOAT_LIST, - "string_list_val": ValueType.STRING_LIST, - "bytes_list_val": ValueType.BYTES_LIST, - "bool_list_val": ValueType.BOOL_LIST, - None: ValueType.NULL, - } - - return type_map[proto_str] + if proto_str is None: + return ValueType.UNKNOWN + return PROTO_VALUE_TO_VALUE_TYPE_MAP[proto_str] def pa_to_feast_value_type(pa_type_as_str: str) -> ValueType: diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index 570a6d4f8d..3d1f921999 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -150,6 +150,7 @@ def retrieve_online_documents( config: RepoConfig, table: FeatureView, requested_feature: str, + requested_features: Optional[List[str]], query: List[float], top_k: int, distance_metric: Optional[str] = None, diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index 4074dcb194..d337d365e9 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -614,6 +614,10 @@ 
def eventually_apply() -> Tuple[None, bool]: online_features = fs.get_online_features( features=features, entity_rows=entity_rows ).to_dict() + + # Debugging print statement + print("Online features values:", online_features["value"]) + assert all(v is None for v in online_features["value"]) @@ -891,3 +895,28 @@ def test_retrieve_online_documents(vectordb_environment, fake_document_data): top_k=2, distance_metric="wrong", ).to_dict() + + +@pytest.mark.integration +@pytest.mark.universal_online_stores(only=["milvus"]) +def test_retrieve_online_milvus_documents(vectordb_environment, fake_document_data): + fs = vectordb_environment.feature_store + df, data_source = fake_document_data + item_embeddings_feature_view = create_item_embeddings_feature_view(data_source) + fs.apply([item_embeddings_feature_view, item()]) + fs.write_to_online_store("item_embeddings", df) + documents = fs.retrieve_online_documents( + feature=None, + features=[ + "item_embeddings:embedding_float", + "item_embeddings:item_id", + "item_embeddings:string_feature", + ], + query=[1.0, 2.0], + top_k=2, + distance_metric="L2", + ).to_dict() + assert len(documents["embedding_float"]) == 2 + + assert len(documents["item_id"]) == 2 + assert documents["item_id"] == [2, 3]