diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index b9eee968ba..11c50b9e8c 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -900,6 +900,7 @@ from .types.reasoning_engine import ReasoningEngineSpec from .types.reasoning_engine_execution_service import QueryReasoningEngineRequest from .types.reasoning_engine_execution_service import QueryReasoningEngineResponse +from .types.reasoning_engine_execution_service import StreamQueryReasoningEngineRequest from .types.reasoning_engine_service import CreateReasoningEngineOperationMetadata from .types.reasoning_engine_service import CreateReasoningEngineRequest from .types.reasoning_engine_service import DeleteReasoningEngineRequest @@ -1845,6 +1846,7 @@ "QueryExtensionResponse", "QueryReasoningEngineRequest", "QueryReasoningEngineResponse", + "StreamQueryReasoningEngineRequest", "QuestionAnsweringCorrectnessInput", "QuestionAnsweringCorrectnessInstance", "QuestionAnsweringCorrectnessResult", diff --git a/google/cloud/aiplatform_v1beta1/gapic_metadata.json b/google/cloud/aiplatform_v1beta1/gapic_metadata.json index 45885c6f3b..7da138f5be 100644 --- a/google/cloud/aiplatform_v1beta1/gapic_metadata.json +++ b/google/cloud/aiplatform_v1beta1/gapic_metadata.json @@ -4682,6 +4682,11 @@ "methods": [ "query_reasoning_engine" ] + }, + "StreamQueryReasoningEngine": { + "methods": [ + "stream_query_reasoning_engine" + ] } } }, @@ -4692,6 +4697,11 @@ "methods": [ "query_reasoning_engine" ] + }, + "StreamQueryReasoningEngine": { + "methods": [ + "stream_query_reasoning_engine" + ] } } }, @@ -4702,6 +4712,11 @@ "methods": [ "query_reasoning_engine" ] + }, + "StreamQueryReasoningEngine": { + "methods": [ + "stream_query_reasoning_engine" + ] } } } diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/client.py 
b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/client.py index 418d44c4a6..0c9d574882 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/client.py @@ -23,6 +23,7 @@ MutableMapping, MutableSequence, Optional, + Iterable, Sequence, Tuple, Type, @@ -49,6 +50,7 @@ OptionalRetry = Union[retries.Retry, object, None] # type: ignore from google.cloud.aiplatform_v1beta1.types import reasoning_engine_execution_service +from google.api import httpbody_pb2 # type: ignore from google.cloud.location import locations_pb2 # type: ignore from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore @@ -799,6 +801,134 @@ def sample_query_reasoning_engine(): # Done; return the response. return response + def stream_query_reasoning_engine(self, + request: Optional[Union[reasoning_engine_execution_service.StreamQueryReasoningEngineRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, str]] = (), + ) -> Iterable[httpbody_pb2.HttpBody]: + r"""Streams queries using a reasoning engine. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. 
+ # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.cloud import aiplatform_v1beta1 + + def sample_stream_query_reasoning_engine(): + # Create a client + client = aiplatform_v1beta1.ReasoningEngineExecutionServiceClient() + + # Initialize request argument(s) + request = aiplatform_v1beta1.StreamQueryReasoningEngineRequest( + name="name_value", + ) + + # Make the request + stream = client.stream_query_reasoning_engine(request=request) + + # Handle the response + for response in stream: + print(response) + + Args: + request (Union[google.cloud.aiplatform_v1beta1.types.StreamQueryReasoningEngineRequest, dict]): + The request object. Request message for + [ReasoningEngineExecutionService.StreamQuery][]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + Iterable[google.api.httpbody_pb2.HttpBody]: + Message that represents an arbitrary HTTP body. It should only be used for + payload formats that can't be represented as JSON, + such as raw binary or an HTML page. + + This message can be used both in streaming and + non-streaming API methods in the request as well as + the response. + + It can be used as a top-level request field, which is + convenient if one wants to extract parameters from + either the URL or HTTP template into the request + fields and also want access to the raw HTTP body. + + Example: + + message GetResourceRequest { + // A unique request id. string request_id = 1; + + // The raw HTTP body is bound to this field. 
+ google.api.HttpBody http_body = 2; + + } + + service ResourceService { + rpc GetResource(GetResourceRequest) + returns (google.api.HttpBody); + + rpc UpdateResource(google.api.HttpBody) + returns (google.protobuf.Empty); + + } + + Example with streaming methods: + + service CaldavService { + rpc GetCalendar(stream google.api.HttpBody) + returns (stream google.api.HttpBody); + + rpc UpdateCalendar(stream google.api.HttpBody) + returns (stream google.api.HttpBody); + + } + + Use of this type only changes how the request and + response bodies are handled, all other features will + continue to work unchanged. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, reasoning_engine_execution_service.StreamQueryReasoningEngineRequest): + request = reasoning_engine_execution_service.StreamQueryReasoningEngineRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.stream_query_reasoning_engine] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata(( + ("name", request.name), + )), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. 
+ return response + def __enter__(self) -> "ReasoningEngineExecutionServiceClient": return self diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/base.py index 47a10dd450..c76fd3b768 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/base.py @@ -27,6 +27,7 @@ from google.oauth2 import service_account # type: ignore from google.cloud.aiplatform_v1beta1.types import reasoning_engine_execution_service +from google.api import httpbody_pb2 # type: ignore from google.cloud.location import locations_pb2 # type: ignore from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore @@ -138,6 +139,11 @@ def _prep_wrapped_messages(self, client_info): default_timeout=None, client_info=client_info, ), + self.stream_query_reasoning_engine: gapic_v1.method.wrap_method( + self.stream_query_reasoning_engine, + default_timeout=None, + client_info=client_info, + ), self.get_location: gapic_v1.method.wrap_method( self.get_location, default_timeout=None, @@ -211,6 +217,15 @@ def query_reasoning_engine( ]: raise NotImplementedError() + @property + def stream_query_reasoning_engine(self) -> Callable[ + [reasoning_engine_execution_service.StreamQueryReasoningEngineRequest], + Union[ + httpbody_pb2.HttpBody, + Awaitable[httpbody_pb2.HttpBody] + ]]: + raise NotImplementedError() + @property def list_operations( self, diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/grpc.py index c25c774018..43cd63c2ad 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/grpc.py 
+++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/grpc.py @@ -25,6 +25,7 @@ import grpc # type: ignore from google.cloud.aiplatform_v1beta1.types import reasoning_engine_execution_service +from google.api import httpbody_pb2 # type: ignore from google.cloud.location import locations_pb2 # type: ignore from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore @@ -270,6 +271,32 @@ def query_reasoning_engine( ) return self._stubs["query_reasoning_engine"] + @property + def stream_query_reasoning_engine(self) -> Callable[ + [reasoning_engine_execution_service.StreamQueryReasoningEngineRequest], + httpbody_pb2.HttpBody]: + r"""Return a callable for the stream query reasoning engine method over gRPC. + + Streams queries using a reasoning engine. + + Returns: + Callable[[~.StreamQueryReasoningEngineRequest], + ~.HttpBody]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if 'stream_query_reasoning_engine' not in self._stubs: + self._stubs['stream_query_reasoning_engine'] = self.grpc_channel.unary_stream( + '/google.cloud.aiplatform.v1beta1.ReasoningEngineExecutionService/StreamQueryReasoningEngine', + request_serializer=reasoning_engine_execution_service.StreamQueryReasoningEngineRequest.serialize, + response_deserializer=httpbody_pb2.HttpBody.FromString, + ) + return self._stubs['stream_query_reasoning_engine'] + def close(self): self.grpc_channel.close() diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest.py index 9ceb0520d7..b9cbce3caa 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest.py @@ -35,6 +35,7 @@ from google.cloud.aiplatform_v1beta1.types import reasoning_engine_execution_service +from google.api import httpbody_pb2 # type: ignore from google.longrunning import operations_pb2 # type: ignore @@ -77,6 +78,14 @@ def post_query_reasoning_engine(self, response): logging.log(f"Received response: {response}") return response + def pre_stream_query_reasoning_engine(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_stream_query_reasoning_engine(self, response): + logging.log(f"Received response: {response}") + return response + transport = ReasoningEngineExecutionServiceRestTransport(interceptor=MyCustomReasoningEngineExecutionServiceInterceptor()) client = ReasoningEngineExecutionServiceClient(transport=transport) @@ -109,6 +118,23 @@ def post_query_reasoning_engine( """ return response + def pre_stream_query_reasoning_engine(self, request: reasoning_engine_execution_service.StreamQueryReasoningEngineRequest, metadata: Sequence[Tuple[str, str]]) -> 
Tuple[reasoning_engine_execution_service.StreamQueryReasoningEngineRequest, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for stream_query_reasoning_engine + + Override in a subclass to manipulate the request or metadata + before they are sent to the ReasoningEngineExecutionService server. + """ + return request, metadata + + def post_stream_query_reasoning_engine(self, response: rest_streaming.ResponseIterator) -> rest_streaming.ResponseIterator: + """Post-rpc interceptor for stream_query_reasoning_engine + + Override in a subclass to manipulate the response + after it is returned by the ReasoningEngineExecutionService server but before + it is returned to user code. + """ + return response + def pre_get_location( self, request: locations_pb2.GetLocationRequest, @@ -527,6 +553,129 @@ def __call__( resp = self._interceptor.post_query_reasoning_engine(resp) return resp + class _StreamQueryReasoningEngine(_BaseReasoningEngineExecutionServiceRestTransport._BaseStreamQueryReasoningEngine, ReasoningEngineExecutionServiceRestStub): + def __hash__(self): + return hash("ReasoningEngineExecutionServiceRestTransport.StreamQueryReasoningEngine") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None): + + uri = transcoded_request['uri'] + method = transcoded_request['method'] + headers = dict(metadata) + headers['Content-Type'] = 'application/json' + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + stream=True, + ) + return response + + def __call__(self, + request: reasoning_engine_execution_service.StreamQueryReasoningEngineRequest, *, + retry: OptionalRetry=gapic_v1.method.DEFAULT, + timeout: Optional[float]=None, + metadata: Sequence[Tuple[str, str]]=(), + ) -> rest_streaming.ResponseIterator: + r"""Call the stream query reasoning + 
engine method over HTTP. + + Args: + request (~.reasoning_engine_execution_service.StreamQueryReasoningEngineRequest): + The request object. Request message for + [ReasoningEngineExecutionService.StreamQuery][]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + ~.httpbody_pb2.HttpBody: + Message that represents an arbitrary HTTP body. It + should only be used for payload formats that can't be + represented as JSON, such as raw binary or an HTML page. + + This message can be used both in streaming and + non-streaming API methods in the request as well as the + response. + + It can be used as a top-level request field, which is + convenient if one wants to extract parameters from + either the URL or HTTP template into the request fields + and also want access to the raw HTTP body. + + Example: + + :: + + message GetResourceRequest { + // A unique request id. + string request_id = 1; + + // The raw HTTP body is bound to this field. + google.api.HttpBody http_body = 2; + + } + + service ResourceService { + rpc GetResource(GetResourceRequest) + returns (google.api.HttpBody); + rpc UpdateResource(google.api.HttpBody) + returns (google.protobuf.Empty); + + } + + Example with streaming methods: + + :: + + service CaldavService { + rpc GetCalendar(stream google.api.HttpBody) + returns (stream google.api.HttpBody); + rpc UpdateCalendar(stream google.api.HttpBody) + returns (stream google.api.HttpBody); + + } + + Use of this type only changes how the request and + response bodies are handled, all other features will + continue to work unchanged. 
+ + """ + + http_options = _BaseReasoningEngineExecutionServiceRestTransport._BaseStreamQueryReasoningEngine._get_http_options() + request, metadata = self._interceptor.pre_stream_query_reasoning_engine(request, metadata) + transcoded_request = _BaseReasoningEngineExecutionServiceRestTransport._BaseStreamQueryReasoningEngine._get_transcoded_request(http_options, request) + + body = _BaseReasoningEngineExecutionServiceRestTransport._BaseStreamQueryReasoningEngine._get_request_body_json(transcoded_request) + + # Jsonify the query params + query_params = _BaseReasoningEngineExecutionServiceRestTransport._BaseStreamQueryReasoningEngine._get_query_params_json(transcoded_request) + + # Send the request + response = ReasoningEngineExecutionServiceRestTransport._StreamQueryReasoningEngine._get_response(self._host, metadata, query_params, self._session, timeout, transcoded_request, body) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = rest_streaming.ResponseIterator(response, httpbody_pb2.HttpBody) + resp = self._interceptor.post_stream_query_reasoning_engine(resp) + return resp + @property def query_reasoning_engine( self, @@ -538,6 +687,14 @@ def query_reasoning_engine( # In C++ this would require a dynamic_cast return self._QueryReasoningEngine(self._session, self._host, self._interceptor) # type: ignore + @property + def stream_query_reasoning_engine(self) -> Callable[ + [reasoning_engine_execution_service.StreamQueryReasoningEngineRequest], + httpbody_pb2.HttpBody]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. 
+ # In C++ this would require a dynamic_cast + return self._StreamQueryReasoningEngine(self._session, self._host, self._interceptor) # type: ignore + @property def get_location(self): return self._GetLocation(self._session, self._host, self._interceptor) # type: ignore diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest_base.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest_base.py index f265ed3f62..7d80faffed 100644 --- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest_base.py +++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest_base.py @@ -28,6 +28,7 @@ from google.cloud.aiplatform_v1beta1.types import reasoning_engine_execution_service +from google.api import httpbody_pb2 # type: ignore from google.longrunning import operations_pb2 # type: ignore @@ -156,6 +157,53 @@ def _get_query_params_json(transcoded_request): query_params["$alt"] = "json;enum-encoding=int" return query_params + class _BaseStreamQueryReasoningEngine: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict} + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [{ + 'method': 'post', + 'uri': '/v1beta1/{name=projects/*/locations/*/reasoningEngines/*}:streamQuery', + 'body': '*', + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = reasoning_engine_execution_service.StreamQueryReasoningEngineRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def 
_get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request['body'], + use_integers_for_enums=True + ) + return body + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json_format.MessageToJson( + transcoded_request['query_params'], + use_integers_for_enums=True, + )) + query_params.update(_BaseReasoningEngineExecutionServiceRestTransport._BaseStreamQueryReasoningEngine._get_unset_required_fields(query_params)) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + class _BaseGetLocation: def __hash__(self): # pragma: NO COVER return NotImplementedError("__hash__ must be implemented.") diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index e9209d5bc2..819f8e21dd 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -997,6 +997,7 @@ from .reasoning_engine_execution_service import ( QueryReasoningEngineRequest, QueryReasoningEngineResponse, + StreamQueryReasoningEngineRequest, ) from .reasoning_engine_service import ( CreateReasoningEngineOperationMetadata, @@ -2025,6 +2026,7 @@ "ReasoningEngineSpec", "QueryReasoningEngineRequest", "QueryReasoningEngineResponse", + "StreamQueryReasoningEngineRequest", "CreateReasoningEngineOperationMetadata", "CreateReasoningEngineRequest", "DeleteReasoningEngineRequest", diff --git a/google/cloud/aiplatform_v1beta1/types/reasoning_engine_execution_service.py b/google/cloud/aiplatform_v1beta1/types/reasoning_engine_execution_service.py index bb5620c39f..ad7367e664 100644 --- a/google/cloud/aiplatform_v1beta1/types/reasoning_engine_execution_service.py +++ b/google/cloud/aiplatform_v1beta1/types/reasoning_engine_execution_service.py @@ -27,6 +27,7 @@ manifest={ "QueryReasoningEngineRequest", "QueryReasoningEngineResponse", + "StreamQueryReasoningEngineRequest", }, ) @@ 
-43,6 +44,10 @@ class QueryReasoningEngineRequest(proto.Message): Optional. Input content provided by users in JSON object format. Examples include text query, function calling parameters, media bytes, etc. + class_method (str): + Optional. Class method to be used for the + query. It is optional and defaults to "query" if + unspecified. """ name: str = proto.Field( @@ -54,6 +59,10 @@ class QueryReasoningEngineRequest(proto.Message): number=2, message=struct_pb2.Struct, ) + class_method: str = proto.Field( + proto.STRING, + number=3, + ) class QueryReasoningEngineResponse(proto.Message): @@ -72,4 +81,36 @@ class QueryReasoningEngineResponse(proto.Message): ) +class StreamQueryReasoningEngineRequest(proto.Message): + r"""Request message for [ReasoningEngineExecutionService.StreamQuery][]. + + Attributes: + name (str): + Required. The name of the ReasoningEngine resource to use. + Format: + ``projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}`` + input (google.protobuf.struct_pb2.Struct): + Optional. Input content provided by users in + JSON object format. Examples include text query, + function calling parameters, media bytes, etc. + class_method (str): + Optional. Class method to be used for the stream query. It + is optional and defaults to "stream_query" if unspecified. 
+ """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + input: struct_pb2.Struct = proto.Field( + proto.MESSAGE, + number=2, + message=struct_pb2.Struct, + ) + class_method: str = proto.Field( + proto.STRING, + number=3, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/vertexai/reasoning_engines/_reasoning_engines.py b/vertexai/reasoning_engines/_reasoning_engines.py index d9a46fe563..67f67c0efe 100644 --- a/vertexai/reasoning_engines/_reasoning_engines.py +++ b/vertexai/reasoning_engines/_reasoning_engines.py @@ -20,7 +20,7 @@ import sys import tarfile import typing -from typing import Optional, Protocol, Sequence, Union, List +from typing import Iterable, List, Optional, Protocol, Sequence, Union, cast from google.api_core import exceptions from google.cloud import storage @@ -30,6 +30,7 @@ from google.cloud.aiplatform_v1beta1 import types from google.cloud.aiplatform_v1beta1.types import reasoning_engine_service from vertexai.reasoning_engines import _utils +from google.api import httpbody_pb2 # type: ignore from google.protobuf import field_mask_pb2 @@ -50,6 +51,24 @@ def query(self, **kwargs): """Runs the Reasoning Engine to serve the user query.""" +@typing.runtime_checkable +class StreamQueryable(Protocol): + """Protocol for Reasoning Engine applications that can stream responses.""" + + @abc.abstractmethod + def stream_query(self, **kwargs): + """Stream responses to serve the user query.""" + + +@typing.runtime_checkable +class OperationRegistrable(Protocol): + """Protocol for applications that have registered operations.""" + + @abc.abstractmethod + def register_operations(self, **kwargs): + """Register the user provided operations (modes and methods).""" + + @typing.runtime_checkable class Cloneable(Protocol): """Protocol for Reasoning Engine applications that can be cloned.""" @@ -59,7 +78,64 @@ def clone(self): """Return a clone of the object.""" -class ReasoningEngine(base.VertexAiResourceNounWithFutureManager, 
Queryable): +def _query_operation(name: str, description: str = None): + def _method(self, **kwargs) -> _utils.JsonDict: + response = self.execution_api_client.query_reasoning_engine( + request=types.QueryReasoningEngineRequest( + name=self.resource_name, + input=kwargs, + class_method=name, + ), + ) + output = _utils.to_dict(response) + if "output" in output: + return output.get("output") + return output + _method.__name__ = name + if description: + _method.__doc__ = description + return _method + + +def _stream_query_operation(name: str, description: str = None): + def _method(self, **kwargs) -> Iterable[httpbody_pb2.HttpBody]: + response = self.execution_api_client.stream_query_reasoning_engine( + request=types.StreamQueryReasoningEngineRequest( + name=self.resource_name, + input=kwargs, + class_method=name, + ), + ) + for chunk in response: + yield chunk + _method.__name__ = name + if description: + _method.__doc__ = description + return _method + + +def _generate_methods(obj: "ReasoningEngine"): + spec = _utils.to_dict(obj.gca_resource.spec) + class_methods = spec.get("classMethods") or spec.get("class_methods", []) + for class_method in class_methods: + api_mode = class_method.get("api_mode") + method_name = class_method.get("name") + method_description = class_method.get("description") + if api_mode == "": + query_method = _query_operation( + name=method_name, + description=method_description, + ) + setattr(obj, method_name, query_method.__get__(obj)) + elif api_mode == "stream": + stream_query_method = _stream_query_operation( + name=method_name, + description=method_description, + ) + setattr(obj, method_name, stream_query_method.__get__(obj)) + + +class ReasoningEngine(base.VertexAiResourceNounWithFutureManager): """Represents a Vertex AI Reasoning Engine resource.""" client_class = aip_utils.ReasoningEngineClientWithOverride @@ -84,6 +160,10 @@ def __init__(self, reasoning_engine_name: str): 
client_class=aip_utils.ReasoningEngineExecutionClientWithOverride, ) self._gca_resource = self._get_gca_resource(resource_name=reasoning_engine_name) + try: + _generate_methods(self) + except Exception as e: + _LOGGER.warning(f"failed to generate methods: {e}") self._operation_schemas = None @property @@ -242,14 +322,32 @@ def create( reasoning_engine_spec = types.ReasoningEngineSpec( package_spec=package_spec, ) - try: - schema_dict = _utils.generate_schema( - reasoning_engine.query, - schema_name=f"{type(reasoning_engine).__name__}_query", + if not isinstance(reasoning_engine, OperationRegistrable): + def register_operations(self): + operations = {} + if isinstance(reasoning_engine, Queryable): + operations[""] = ["query"] + if isinstance(reasoning_engine, StreamQueryable): + operations["stream"] = ["stream_query"] + return operations + reasoning_engine.register_operations = ( + register_operations.__get__(reasoning_engine) ) - reasoning_engine_spec.class_methods.append(_utils.to_proto(schema_dict)) - except Exception as e: - _LOGGER.warning(f"failed to generate schema: {e}") + if not reasoning_engine.register_operations(): + raise ValueError("ReasoningEngine does not support any operations.") + for mode, method_names in reasoning_engine.register_operations().items(): + for method_name in method_names: + try: + schema_dict = _utils.generate_schema( + getattr(reasoning_engine, method_name), + schema_name=method_name, + ) + except Exception as e: + schema_dict = {} + _LOGGER.warning(f"failed to generate schema: {e}") + class_method = _utils.to_proto(schema_dict) + class_method["api_mode"] = mode + reasoning_engine_spec.class_methods.append(class_method) operation_future = sdk_resource.api_client.create_reasoning_engine( parent=initializer.global_config.common_location_path( project=sdk_resource.project, location=sdk_resource.location @@ -279,6 +377,10 @@ def create( credentials=sdk_resource.credentials, location_override=sdk_resource.location, ) + try: + 
_generate_methods(sdk_resource) + except Exception as e: + _LOGGER.warning(f"failed to generate methods: {e}") sdk_resource._operation_schemas = None return sdk_resource @@ -415,30 +517,6 @@ def operation_schemas(self) -> Sequence[_utils.JsonDict]: self._operation_schemas = spec.get("class_methods", []) return self._operation_schemas - def query(self, **kwargs) -> _utils.JsonDict: - """Runs the Reasoning Engine to serve the user query. - - This will be based on the `.query(...)` method of the python object that - was passed in when creating the Reasoning Engine. - - Args: - **kwargs: - Optional. The arguments of the `.query(...)` method. - - Returns: - dict[str, Any]: The response from serving the user query. - """ - response = self.execution_api_client.query_reasoning_engine( - request=types.QueryReasoningEngineRequest( - name=self.resource_name, - input=kwargs, - ), - ) - output = _utils.to_dict(response) - if "output" in output: - return output.get("output") - return output - def _validate_sys_version_or_raise(sys_version: str) -> None: """Tries to validate the python system version."""