diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py
index b9eee968ba..11c50b9e8c 100644
--- a/google/cloud/aiplatform_v1beta1/__init__.py
+++ b/google/cloud/aiplatform_v1beta1/__init__.py
@@ -900,6 +900,7 @@
 from .types.reasoning_engine import ReasoningEngineSpec
 from .types.reasoning_engine_execution_service import QueryReasoningEngineRequest
 from .types.reasoning_engine_execution_service import QueryReasoningEngineResponse
+from .types.reasoning_engine_execution_service import StreamQueryReasoningEngineRequest
 from .types.reasoning_engine_service import CreateReasoningEngineOperationMetadata
 from .types.reasoning_engine_service import CreateReasoningEngineRequest
 from .types.reasoning_engine_service import DeleteReasoningEngineRequest
@@ -1845,6 +1846,7 @@
     "QueryExtensionResponse",
     "QueryReasoningEngineRequest",
     "QueryReasoningEngineResponse",
+    "StreamQueryReasoningEngineRequest",
     "QuestionAnsweringCorrectnessInput",
     "QuestionAnsweringCorrectnessInstance",
     "QuestionAnsweringCorrectnessResult",
diff --git a/google/cloud/aiplatform_v1beta1/gapic_metadata.json b/google/cloud/aiplatform_v1beta1/gapic_metadata.json
index 45885c6f3b..7da138f5be 100644
--- a/google/cloud/aiplatform_v1beta1/gapic_metadata.json
+++ b/google/cloud/aiplatform_v1beta1/gapic_metadata.json
@@ -4682,6 +4682,11 @@
         "methods": [
           "query_reasoning_engine"
         ]
+      },
+      "StreamQueryReasoningEngine": {
+        "methods": [
+          "stream_query_reasoning_engine"
+        ]
       }
     }
   },
@@ -4692,6 +4697,11 @@
         "methods": [
           "query_reasoning_engine"
         ]
+      },
+      "StreamQueryReasoningEngine": {
+        "methods": [
+          "stream_query_reasoning_engine"
+        ]
       }
     }
   },
@@ -4702,6 +4712,11 @@
         "methods": [
           "query_reasoning_engine"
         ]
+      },
+      "StreamQueryReasoningEngine": {
+        "methods": [
+          "stream_query_reasoning_engine"
+        ]
       }
     }
   }
diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/client.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/client.py
index 418d44c4a6..0c9d574882 100644
--- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/client.py
+++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/client.py
@@ -23,6 +23,7 @@
     MutableMapping,
     MutableSequence,
     Optional,
+    Iterable,
     Sequence,
     Tuple,
     Type,
@@ -49,6 +50,7 @@
 OptionalRetry = Union[retries.Retry, object, None]  # type: ignore

 from google.cloud.aiplatform_v1beta1.types import reasoning_engine_execution_service
+from google.api import httpbody_pb2  # type: ignore
 from google.cloud.location import locations_pb2  # type: ignore
 from google.iam.v1 import iam_policy_pb2  # type: ignore
 from google.iam.v1 import policy_pb2  # type: ignore
@@ -799,6 +801,134 @@ def sample_query_reasoning_engine():
         # Done; return the response.
         return response

+    def stream_query_reasoning_engine(self,
+            request: Optional[Union[reasoning_engine_execution_service.StreamQueryReasoningEngineRequest, dict]] = None,
+            *,
+            retry: OptionalRetry = gapic_v1.method.DEFAULT,
+            timeout: Union[float, object] = gapic_v1.method.DEFAULT,
+            metadata: Sequence[Tuple[str, str]] = (),
+            ) -> Iterable[httpbody_pb2.HttpBody]:
+        r"""Streams queries using a reasoning engine.
+
+        .. code-block:: python
+
+            # This snippet has been automatically generated and should be regarded as a
+            # code template only.
+            # It will require modifications to work:
+            # - It may require correct/in-range values for request initialization.
+            # - It may require specifying regional endpoints when creating the service
+            #   client as shown in:
+            #   https://googleapis.dev/python/google-api-core/latest/client_options.html
+            from google.cloud import aiplatform_v1beta1
+
+            def sample_stream_query_reasoning_engine():
+                # Create a client
+                client = aiplatform_v1beta1.ReasoningEngineExecutionServiceClient()
+
+                # Initialize request argument(s)
+                request = aiplatform_v1beta1.StreamQueryReasoningEngineRequest(
+                    name="name_value",
+                )
+
+                # Make the request
+                stream = client.stream_query_reasoning_engine(request=request)
+
+                # Handle the response
+                for response in stream:
+                    print(response)
+
+        Args:
+            request (Union[google.cloud.aiplatform_v1beta1.types.StreamQueryReasoningEngineRequest, dict]):
+                The request object. Request message for
+                [ReasoningEngineExecutionService.StreamQuery][].
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            Iterable[google.api.httpbody_pb2.HttpBody]:
+                Message that represents an arbitrary HTTP body. It should only be used for
+                payload formats that can't be represented as JSON,
+                such as raw binary or an HTML page.
+
+                This message can be used both in streaming and
+                non-streaming API methods in the request as well as
+                the response.
+
+                It can be used as a top-level request field, which is
+                convenient if one wants to extract parameters from
+                either the URL or HTTP template into the request
+                fields and also want access to the raw HTTP body.
+
+                Example:
+
+                    message GetResourceRequest {
+                      // A unique request id.
+                      string request_id = 1;
+
+                      // The raw HTTP body is bound to this field.
+                      google.api.HttpBody http_body = 2;
+                    }
+
+                    service ResourceService {
+                      rpc GetResource(GetResourceRequest)
+                          returns (google.api.HttpBody);
+
+                      rpc UpdateResource(google.api.HttpBody)
+                          returns (google.protobuf.Empty);
+                    }
+
+                Example with streaming methods:
+
+                    service CaldavService {
+                      rpc GetCalendar(stream google.api.HttpBody)
+                          returns (stream google.api.HttpBody);
+
+                      rpc UpdateCalendar(stream google.api.HttpBody)
+                          returns (stream google.api.HttpBody);
+                    }
+
+                Use of this type only changes how the request and
+                response bodies are handled, all other features will
+                continue to work unchanged.
+
+        """
+        # Create or coerce a protobuf request object.
+        # - Use the request object if provided (there's no risk of modifying the input as
+        #   there are no flattened fields), or create one.
+        if not isinstance(request, reasoning_engine_execution_service.StreamQueryReasoningEngineRequest):
+            request = reasoning_engine_execution_service.StreamQueryReasoningEngineRequest(request)
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.stream_query_reasoning_engine]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((
+                ("name", request.name),
+            )),
+        )
+
+        # Validate the universe domain.
+        self._validate_universe_domain()
+
+        # Send the request.
+        response = rpc(
+            request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
     def __enter__(self) -> "ReasoningEngineExecutionServiceClient":
         return self

diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/base.py
index 47a10dd450..c76fd3b768 100644
--- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/base.py
+++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/base.py
@@ -27,6 +27,7 @@
 from google.oauth2 import service_account  # type: ignore

 from google.cloud.aiplatform_v1beta1.types import reasoning_engine_execution_service
+from google.api import httpbody_pb2  # type: ignore
 from google.cloud.location import locations_pb2  # type: ignore
 from google.iam.v1 import iam_policy_pb2  # type: ignore
 from google.iam.v1 import policy_pb2  # type: ignore
@@ -138,6 +139,11 @@ def _prep_wrapped_messages(self, client_info):
                 default_timeout=None,
                 client_info=client_info,
             ),
+            self.stream_query_reasoning_engine: gapic_v1.method.wrap_method(
+                self.stream_query_reasoning_engine,
+                default_timeout=None,
+                client_info=client_info,
+            ),
             self.get_location: gapic_v1.method.wrap_method(
                 self.get_location,
                 default_timeout=None,
@@ -211,6 +217,15 @@ def query_reasoning_engine(
     ]:
         raise NotImplementedError()

+    @property
+    def stream_query_reasoning_engine(self) -> Callable[
+            [reasoning_engine_execution_service.StreamQueryReasoningEngineRequest],
+            Union[
+                httpbody_pb2.HttpBody,
+                Awaitable[httpbody_pb2.HttpBody]
+            ]]:
+        raise NotImplementedError()
+
     @property
     def list_operations(
         self,
diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/grpc.py
index c25c774018..43cd63c2ad 100644
--- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/grpc.py
+++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/grpc.py
@@ -25,6 +25,7 @@
 import grpc  # type: ignore

 from google.cloud.aiplatform_v1beta1.types import reasoning_engine_execution_service
+from google.api import httpbody_pb2  # type: ignore
 from google.cloud.location import locations_pb2  # type: ignore
 from google.iam.v1 import iam_policy_pb2  # type: ignore
 from google.iam.v1 import policy_pb2  # type: ignore
@@ -270,6 +271,32 @@ def query_reasoning_engine(
             )
         return self._stubs["query_reasoning_engine"]

+    @property
+    def stream_query_reasoning_engine(self) -> Callable[
+            [reasoning_engine_execution_service.StreamQueryReasoningEngineRequest],
+            httpbody_pb2.HttpBody]:
+        r"""Return a callable for the stream query reasoning engine method over gRPC.
+
+        Streams queries using a reasoning engine.
+
+        Returns:
+            Callable[[~.StreamQueryReasoningEngineRequest],
+                    ~.HttpBody]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if 'stream_query_reasoning_engine' not in self._stubs:
+            self._stubs['stream_query_reasoning_engine'] = self.grpc_channel.unary_stream(
+                '/google.cloud.aiplatform.v1beta1.ReasoningEngineExecutionService/StreamQueryReasoningEngine',
+                request_serializer=reasoning_engine_execution_service.StreamQueryReasoningEngineRequest.serialize,
+                response_deserializer=httpbody_pb2.HttpBody.FromString,
+            )
+        return self._stubs['stream_query_reasoning_engine']
+
     def close(self):
         self.grpc_channel.close()

diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest.py
index 9ceb0520d7..b9cbce3caa 100644
--- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest.py
+++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest.py
@@ -35,6 +35,7 @@

 from google.cloud.aiplatform_v1beta1.types import reasoning_engine_execution_service
+from google.api import httpbody_pb2  # type: ignore
 from google.longrunning import operations_pb2  # type: ignore

@@ -77,6 +78,14 @@ def post_query_reasoning_engine(self, response):
                 logging.log(f"Received response: {response}")
                 return response

+            def pre_stream_query_reasoning_engine(self, request, metadata):
+                logging.log(f"Received request: {request}")
+                return request, metadata
+
+            def post_stream_query_reasoning_engine(self, response):
+                logging.log(f"Received response: {response}")
+                return response
+
         transport = ReasoningEngineExecutionServiceRestTransport(interceptor=MyCustomReasoningEngineExecutionServiceInterceptor())
         client = ReasoningEngineExecutionServiceClient(transport=transport)

@@ -109,6 +118,23 @@ def post_query_reasoning_engine(
         """
         return response

+    def pre_stream_query_reasoning_engine(self, request: reasoning_engine_execution_service.StreamQueryReasoningEngineRequest, metadata: Sequence[Tuple[str, str]]) -> Tuple[reasoning_engine_execution_service.StreamQueryReasoningEngineRequest, Sequence[Tuple[str, str]]]:
+        """Pre-rpc interceptor for stream_query_reasoning_engine
+
+        Override in a subclass to manipulate the request or metadata
+        before they are sent to the ReasoningEngineExecutionService server.
+        """
+        return request, metadata
+
+    def post_stream_query_reasoning_engine(self, response: rest_streaming.ResponseIterator) -> rest_streaming.ResponseIterator:
+        """Post-rpc interceptor for stream_query_reasoning_engine
+
+        Override in a subclass to manipulate the response
+        after it is returned by the ReasoningEngineExecutionService server but before
+        it is returned to user code.
+        """
+        return response
+
     def pre_get_location(
         self,
         request: locations_pb2.GetLocationRequest,
@@ -527,6 +553,129 @@ def __call__(
             resp = self._interceptor.post_query_reasoning_engine(resp)
             return resp

+    class _StreamQueryReasoningEngine(_BaseReasoningEngineExecutionServiceRestTransport._BaseStreamQueryReasoningEngine, ReasoningEngineExecutionServiceRestStub):
+        def __hash__(self):
+            return hash("ReasoningEngineExecutionServiceRestTransport.StreamQueryReasoningEngine")
+
+        @staticmethod
+        def _get_response(
+            host,
+            metadata,
+            query_params,
+            session,
+            timeout,
+            transcoded_request,
+            body=None):
+
+            uri = transcoded_request['uri']
+            method = transcoded_request['method']
+            headers = dict(metadata)
+            headers['Content-Type'] = 'application/json'
+            response = getattr(session, method)(
+                "{host}{uri}".format(host=host, uri=uri),
+                timeout=timeout,
+                headers=headers,
+                params=rest_helpers.flatten_query_params(query_params, strict=True),
+                data=body,
+                stream=True,
+            )
+            return response
+
+        def __call__(self,
+                request: reasoning_engine_execution_service.StreamQueryReasoningEngineRequest, *,
+                retry: OptionalRetry=gapic_v1.method.DEFAULT,
+                timeout: Optional[float]=None,
+                metadata: Sequence[Tuple[str, str]]=(),
+                ) -> rest_streaming.ResponseIterator:
+            r"""Call the stream query reasoning
+            engine method over HTTP.
+
+            Args:
+                request (~.reasoning_engine_execution_service.StreamQueryReasoningEngineRequest):
+                    The request object. Request message for
+                    [ReasoningEngineExecutionService.StreamQuery][].
+                retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                    should be retried.
+                timeout (float): The timeout for this request.
+                metadata (Sequence[Tuple[str, str]]): Strings which should be
+                    sent along with the request as metadata.
+
+            Returns:
+                ~.httpbody_pb2.HttpBody:
+                    Message that represents an arbitrary HTTP body. It
+                    should only be used for payload formats that can't be
+                    represented as JSON, such as raw binary or an HTML page.
+
+                    This message can be used both in streaming and
+                    non-streaming API methods in the request as well as the
+                    response.
+
+                    It can be used as a top-level request field, which is
+                    convenient if one wants to extract parameters from
+                    either the URL or HTTP template into the request fields
+                    and also want access to the raw HTTP body.
+
+                    Example:
+
+                    ::
+
+                        message GetResourceRequest {
+                          // A unique request id.
+                          string request_id = 1;
+
+                          // The raw HTTP body is bound to this field.
+                          google.api.HttpBody http_body = 2;
+                        }
+
+                        service ResourceService {
+                          rpc GetResource(GetResourceRequest)
+                              returns (google.api.HttpBody);
+                          rpc UpdateResource(google.api.HttpBody)
+                              returns (google.protobuf.Empty);
+                        }
+
+                    Example with streaming methods:
+
+                    ::
+
+                        service CaldavService {
+                          rpc GetCalendar(stream google.api.HttpBody)
+                              returns (stream google.api.HttpBody);
+                          rpc UpdateCalendar(stream google.api.HttpBody)
+                              returns (stream google.api.HttpBody);
+                        }
+
+                    Use of this type only changes how the request and
+                    response bodies are handled, all other features will
+                    continue to work unchanged.
+
+            """
+
+            http_options = _BaseReasoningEngineExecutionServiceRestTransport._BaseStreamQueryReasoningEngine._get_http_options()
+            request, metadata = self._interceptor.pre_stream_query_reasoning_engine(request, metadata)
+            transcoded_request = _BaseReasoningEngineExecutionServiceRestTransport._BaseStreamQueryReasoningEngine._get_transcoded_request(http_options, request)
+
+            body = _BaseReasoningEngineExecutionServiceRestTransport._BaseStreamQueryReasoningEngine._get_request_body_json(transcoded_request)
+
+            # Jsonify the query params
+            query_params = _BaseReasoningEngineExecutionServiceRestTransport._BaseStreamQueryReasoningEngine._get_query_params_json(transcoded_request)
+
+            # Send the request
+            response = ReasoningEngineExecutionServiceRestTransport._StreamQueryReasoningEngine._get_response(self._host, metadata, query_params, self._session, timeout, transcoded_request, body)
+
+            # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception
+            # subclass.
+            if response.status_code >= 400:
+                raise core_exceptions.from_http_response(response)
+
+            # Return the response
+            resp = rest_streaming.ResponseIterator(response, httpbody_pb2.HttpBody)
+            resp = self._interceptor.post_stream_query_reasoning_engine(resp)
+            return resp
+
@@ -538,6 +687,14 @@ def query_reasoning_engine(
         self,
         # In C++ this would require a dynamic_cast
         return self._QueryReasoningEngine(self._session, self._host, self._interceptor)  # type: ignore

+    @property
+    def stream_query_reasoning_engine(self) -> Callable[
+            [reasoning_engine_execution_service.StreamQueryReasoningEngineRequest],
+            httpbody_pb2.HttpBody]:
+        # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here.
+        # In C++ this would require a dynamic_cast
+        return self._StreamQueryReasoningEngine(self._session, self._host, self._interceptor)  # type: ignore
+
     @property
     def get_location(self):
         return self._GetLocation(self._session, self._host, self._interceptor)  # type: ignore
diff --git a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest_base.py b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest_base.py
index f265ed3f62..7d80faffed 100644
--- a/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest_base.py
+++ b/google/cloud/aiplatform_v1beta1/services/reasoning_engine_execution_service/transports/rest_base.py
@@ -28,6 +28,7 @@

 from google.cloud.aiplatform_v1beta1.types import reasoning_engine_execution_service
+from google.api import httpbody_pb2  # type: ignore
 from google.longrunning import operations_pb2  # type: ignore

@@ -156,6 +157,53 @@ def _get_query_params_json(transcoded_request):
             query_params["$alt"] = "json;enum-encoding=int"
             return query_params

+    class _BaseStreamQueryReasoningEngine:
+        def __hash__(self):  # pragma: NO COVER
+            return NotImplementedError("__hash__ must be implemented.")
+
+        __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {
+        }
+
+        @classmethod
+        def _get_unset_required_fields(cls, message_dict):
+            return {k: v for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() if k not in message_dict}
+
+        @staticmethod
+        def _get_http_options():
+            http_options: List[Dict[str, str]] = [{
+                'method': 'post',
+                'uri': '/v1beta1/{name=projects/*/locations/*/reasoningEngines/*}:streamQuery',
+                'body': '*',
+            },
+            ]
+            return http_options
+
+        @staticmethod
+        def _get_transcoded_request(http_options, request):
+            pb_request = reasoning_engine_execution_service.StreamQueryReasoningEngineRequest.pb(request)
+            transcoded_request = path_template.transcode(http_options, pb_request)
+            return transcoded_request
+
+        @staticmethod
+        def _get_request_body_json(transcoded_request):
+            # Jsonify the request body
+
+            body = json_format.MessageToJson(
+                transcoded_request['body'],
+                use_integers_for_enums=True
+            )
+            return body
+
+        @staticmethod
+        def _get_query_params_json(transcoded_request):
+            query_params = json.loads(json_format.MessageToJson(
+                transcoded_request['query_params'],
+                use_integers_for_enums=True,
+            ))
+            query_params.update(_BaseReasoningEngineExecutionServiceRestTransport._BaseStreamQueryReasoningEngine._get_unset_required_fields(query_params))
+
+            query_params["$alt"] = "json;enum-encoding=int"
+            return query_params
+
     class _BaseGetLocation:
         def __hash__(self):  # pragma: NO COVER
             return NotImplementedError("__hash__ must be implemented.")
diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py
index e9209d5bc2..819f8e21dd 100644
--- a/google/cloud/aiplatform_v1beta1/types/__init__.py
+++ b/google/cloud/aiplatform_v1beta1/types/__init__.py
@@ -997,6 +997,7 @@
 from .reasoning_engine_execution_service import (
     QueryReasoningEngineRequest,
     QueryReasoningEngineResponse,
+    StreamQueryReasoningEngineRequest,
 )
 from .reasoning_engine_service import (
     CreateReasoningEngineOperationMetadata,
@@ -2025,6 +2026,7 @@
     "ReasoningEngineSpec",
     "QueryReasoningEngineRequest",
     "QueryReasoningEngineResponse",
+    "StreamQueryReasoningEngineRequest",
     "CreateReasoningEngineOperationMetadata",
     "CreateReasoningEngineRequest",
     "DeleteReasoningEngineRequest",
diff --git a/google/cloud/aiplatform_v1beta1/types/reasoning_engine_execution_service.py b/google/cloud/aiplatform_v1beta1/types/reasoning_engine_execution_service.py
index bb5620c39f..ad7367e664 100644
--- a/google/cloud/aiplatform_v1beta1/types/reasoning_engine_execution_service.py
+++ b/google/cloud/aiplatform_v1beta1/types/reasoning_engine_execution_service.py
@@ -27,6 +27,7 @@
     manifest={
         "QueryReasoningEngineRequest",
         "QueryReasoningEngineResponse",
+        "StreamQueryReasoningEngineRequest",
     },
 )

@@ -43,6 +44,10 @@ class QueryReasoningEngineRequest(proto.Message):
             Optional. Input content provided by users in
             JSON object format. Examples include text query,
             function calling parameters, media bytes, etc.
+        class_method (str):
+            Optional. Class method to be used for the
+            query. It is optional and defaults to "query" if
+            unspecified.
     """

     name: str = proto.Field(
@@ -54,6 +59,10 @@
         number=2,
         message=struct_pb2.Struct,
     )
+    class_method: str = proto.Field(
+        proto.STRING,
+        number=3,
+    )


 class QueryReasoningEngineResponse(proto.Message):
@@ -72,4 +81,36 @@
     )


+class StreamQueryReasoningEngineRequest(proto.Message):
+    r"""Request message for [ReasoningEngineExecutionService.StreamQuery][].
+
+    Attributes:
+        name (str):
+            Required. The name of the ReasoningEngine resource to use.
+            Format:
+            ``projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}``
+        input (google.protobuf.struct_pb2.Struct):
+            Optional. Input content provided by users in
+            JSON object format. Examples include text query,
+            function calling parameters, media bytes, etc.
+        class_method (str):
+            Optional. Class method to be used for the stream query. It
+            is optional and defaults to "stream_query" if unspecified.
+    """
+
+    name: str = proto.Field(
+        proto.STRING,
+        number=1,
+    )
+    input: struct_pb2.Struct = proto.Field(
+        proto.MESSAGE,
+        number=2,
+        message=struct_pb2.Struct,
+    )
+    class_method: str = proto.Field(
+        proto.STRING,
+        number=3,
+    )
+
+
 __all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/setup.py b/setup.py
index 8647c06779..6992e82f93 100644
--- a/setup.py
+++ b/setup.py
@@ -152,6 +152,7 @@
 ]

 langchain_extra_require = [
+    "langgraph >= 0.2.45, < 0.3",
     "langchain >= 0.1.16, < 0.4",
     "langchain-core < 0.4",
     "langchain-google-vertexai < 3",
diff --git a/vertexai/preview/reasoning_engines/__init__.py b/vertexai/preview/reasoning_engines/__init__.py
index 2568531aed..8b472b59f6 100644
--- a/vertexai/preview/reasoning_engines/__init__.py
+++ b/vertexai/preview/reasoning_engines/__init__.py
@@ -23,9 +23,13 @@
 from vertexai.preview.reasoning_engines.templates.langchain import (
     LangchainAgent,
 )
+from vertexai.preview.reasoning_engines.templates.langgraph import (
+    LanggraphAgent,
+)

 __all__ = (
     "LangchainAgent",
+    "LanggraphAgent",
     "Queryable",
     "ReasoningEngine",
 )
diff --git a/vertexai/preview/reasoning_engines/templates/langgraph.py b/vertexai/preview/reasoning_engines/templates/langgraph.py
new file mode 100644
index 0000000000..3d62c12e3b
--- /dev/null
+++ b/vertexai/preview/reasoning_engines/templates/langgraph.py
@@ -0,0 +1,505 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Mapping,
+    Optional,
+    Sequence,
+    Union,
+)
+
+if TYPE_CHECKING:
+    try:
+        from langchain_core import runnables
+        from langchain_core import tools as lc_tools
+        from langchain_core.language_models import base as lc_language_models
+
+        BaseTool = lc_tools.BaseTool
+        BaseLanguageModel = lc_language_models.BaseLanguageModel
+        RunnableConfig = runnables.RunnableConfig
+        RunnableSerializable = runnables.RunnableSerializable
+    except ImportError:
+        BaseTool = Any
+        BaseLanguageModel = Any
+        RunnableConfig = Any
+        RunnableSerializable = Any
+
+    try:
+        from langchain_google_vertexai.functions_utils import _ToolsType
+
+        _ToolLike = _ToolsType
+    except ImportError:
+        _ToolLike = Any
+
+    try:
+        from opentelemetry.sdk import trace
+
+        TracerProvider = trace.TracerProvider
+        SpanProcessor = trace.SpanProcessor
+        SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
+    except ImportError:
+        TracerProvider = Any
+        SpanProcessor = Any
+        SynchronousMultiSpanProcessor = Any
+
+
+def _default_model_builder(
+    model_name: str,
+    *,
+    project: str,
+    location: str,
+    model_kwargs: Optional[Mapping[str, Any]] = None,
+) -> "BaseLanguageModel":
+    import vertexai
+    from google.cloud.aiplatform import initializer
+    from langchain_google_vertexai import ChatVertexAI
+
+    model_kwargs = model_kwargs or {}
+    current_project = initializer.global_config.project
+    current_location = initializer.global_config.location
+    vertexai.init(project=project, location=location)
+    model = ChatVertexAI(model_name=model_name, **model_kwargs)
+    vertexai.init(project=current_project, location=current_location)
+    return model
+
+
+def _default_runnable_builder(
+    model: "BaseLanguageModel",
+    *,
+    tools: Optional[Sequence["_ToolLike"]] = None,
+    checkpointer=None,
+    model_tool_kwargs: Optional[Mapping[str, Any]] = None,
+    runnable_kwargs: Optional[Mapping[str, Any]] = None,
+) -> "RunnableSerializable":
+    from langgraph import prebuilt as langgraph_prebuilt
+
+    model_tool_kwargs = model_tool_kwargs or {}
+    runnable_kwargs = runnable_kwargs or {}
+    if tools:
+        model = model.bind_tools(tools=tools, **model_tool_kwargs)
+    else:
+        tools = []
+    if checkpointer:
+        if "checkpointer" in runnable_kwargs:
+            from google.cloud.aiplatform import base
+
+            base.Logger(__name__).warning(
+                "checkpointer is being specified in both checkpointer_builder "
+                "and runnable_kwargs. Please specify it in only one of them. "
+                "Overriding the checkpointer in runnable_kwargs."
+            )
+        runnable_kwargs["checkpointer"] = checkpointer
+    return langgraph_prebuilt.create_react_agent(
+        model,
+        tools=tools,
+        **runnable_kwargs,
+    )
+
+
+def _validate_callable_parameters_are_annotated(callable: Callable):
+    """Validates that the parameters of the callable have type annotations.
+
+    This ensures that they can be used for constructing LangChain tools that
+    are usable with Gemini function calling.
+    """
+    import inspect
+
+    parameters = dict(inspect.signature(callable).parameters)
+    for name, parameter in parameters.items():
+        if parameter.annotation == inspect.Parameter.empty:
+            raise TypeError(
+                f"Callable={callable.__name__} has untyped input_arg={name}. "
+                f"Please specify a type when defining it, e.g. `{name}: str`."
+ ) + + +def _validate_tools(tools: Sequence["_ToolLike"]): + """Validates that the tools are usable for tool calling.""" + for tool in tools: + if isinstance(tool, Callable): + _validate_callable_parameters_are_annotated(tool) + + +def _override_active_span_processor( + tracer_provider: "TracerProvider", + active_span_processor: "SynchronousMultiSpanProcessor", +): + """Overrides the active span processor. + + When working with multiple LangchainAgents in the same environment, + it's crucial to manage trace exports carefully. + Each agent needs its own span processor tied to a unique project ID. + While we add a new span processor for each agent, this can lead to + unexpected behavior. + For instance, with two agents linked to different projects, traces from the + second agent might be sent to both projects. + To prevent this and guarantee traces go to the correct project, we overwrite + the active span processor whenever a new LangchainAgent is created. + + Args: + tracer_provider (TracerProvider): + The tracer provider to use for the project. + active_span_processor (SynchronousMultiSpanProcessor): + The active span processor overrides the tracer provider's + active span processor. + """ + if tracer_provider._active_span_processor: + tracer_provider._active_span_processor.shutdown() + tracer_provider._active_span_processor = active_span_processor + + +class LanggraphAgent: + """A LangGraph Agent. + """ + + def __init__( + self, + model: str, + *, + tools: Optional[Sequence["_ToolLike"]] = None, + model_kwargs: Optional[Mapping[str, Any]] = None, + model_tool_kwargs: Optional[Mapping[str, Any]] = None, + runnable_kwargs: Optional[Mapping[str, Any]] = None, + checkpointer_kwargs: Optional[Mapping[str, Any]] = None, + model_builder: Optional[Callable] = None, + runnable_builder: Optional[Callable] = None, + checkpointer_builder: Optional[Callable] = None, + enable_tracing: bool = False, + ): + """Initializes the LangGraph Agent. + + Under-the-hood, assuming .set_up() is called, this will correspond to + + ``` + model = model_builder(model_name=model, model_kwargs=model_kwargs) + runnable = runnable_builder( + model=model, + tools=tools, + model_tool_kwargs=model_tool_kwargs, + runnable_kwargs=runnable_kwargs, + ) + ``` + + When everything is based on their default values, this corresponds to + ``` + # model_builder + from langchain_google_vertexai import ChatVertexAI + llm = ChatVertexAI(model_name=model, **model_kwargs) + + # runnable_builder + from langchain import agents + from langchain_core.runnables.history import RunnableWithMessageHistory + llm_with_tools = llm.bind_tools(tools=tools, **model_tool_kwargs) + agent_executor = agents.AgentExecutor( + agent=prompt | llm_with_tools | output_parser, + tools=tools, + **agent_executor_kwargs, + ) + runnable = RunnableWithMessageHistory( + runnable=agent_executor, + get_session_history=chat_history, + **runnable_kwargs, + ) + ``` + + Args: + model (str): + Optional. The name of the model (e.g. "gemini-1.0-pro"). + tools (Sequence[langchain_core.tools.BaseTool, Callable]): + Optional. The tools for the agent to be able to use. All input + callables (e.g. function or class method) will be converted + to a langchain.tools.base.StructuredTool. Defaults to None. + model_kwargs (Mapping[str, Any]): + Optional. Additional keyword arguments for the constructor of + chat_models.ChatVertexAI. An example would be + ``` + { + # temperature (float): Sampling temperature, it controls the + # degree of randomness in token selection. 
+ "temperature": 0.28, + # max_output_tokens (int): Token limit determines the + # maximum amount of text output from one prompt. + "max_output_tokens": 1000, + # top_p (float): Tokens are selected from most probable to + # least, until the sum of their probabilities equals the + # top_p value. + "top_p": 0.95, + # top_k (int): How the model selects tokens for output, the + # next token is selected from among the top_k most probable + # tokens. + "top_k": 40, + } + ``` + model_tool_kwargs (Mapping[str, Any]): + Optional. Additional keyword arguments when binding tools to the + model using `model.bind_tools()`. + runnable_kwargs (Mapping[str, Any]): + Optional. Additional keyword arguments for the constructor of + langchain.runnables.history.RunnableWithMessageHistory if + chat_history is specified. If chat_history is None, this will be + ignored. + model_builder (Callable): + Optional. Callable that returns a new language model. Defaults + to a a callable that returns ChatVertexAI based on `model`, + `model_kwargs` and the parameters in `vertexai.init`. + runnable_builder (Callable): + Optional. Callable that returns a new runnable. This can be used + for customizing the orchestration logic of the Agent based on + the model returned by `model_builder` and the rest of the input + arguments. + checkpointer_builder (Callable): + Optional. Callable that returns a checkpointer. This can be used + for defining the checkpointer of the Agent. Defaults to None. + enable_tracing (bool): + Optional. Whether to enable tracing in Cloud Trace. Defaults to + False. + + Raises: + TypeError: If there is an invalid tool (e.g. function with an input + that did not specify its type). + """ + from google.cloud.aiplatform import initializer + + self._project = initializer.global_config.project + self._location = initializer.global_config.location + self._tools = [] + if tools: + # We validate tools at initialization for actionable feedback before + # they are deployed. + _validate_tools(tools) + self._tools = tools + self._model_name = model + self._model_kwargs = model_kwargs + self._model_tool_kwargs = model_tool_kwargs + self._runnable_kwargs = runnable_kwargs + self._checkpointer_kwargs = checkpointer_kwargs + self._model = None + self._model_builder = model_builder + self._runnable = None + self._runnable_builder = runnable_builder + self._checkpointer_builder = checkpointer_builder + self._instrumentor = None + self._enable_tracing = enable_tracing + + def set_up(self): + """Sets up the agent for execution of queries at runtime. + + It initializes the model, binds the model with tools, and connects it + with the prompt template and output parser. + + This method should not be called for an object that being passed to + the ReasoningEngine service for deployment, as it initializes clients + that can not be serialized. 
+ """ + if self._enable_tracing: + from vertexai.reasoning_engines import _utils + + cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn() + cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn() + openinference_langchain = _utils._import_openinference_langchain_or_warn() + opentelemetry = _utils._import_opentelemetry_or_warn() + opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn() + if all( + ( + cloud_trace_exporter, + cloud_trace_v2, + openinference_langchain, + opentelemetry, + opentelemetry_sdk_trace, + ) + ): + import google.auth + + credentials, _ = google.auth.default() + span_exporter = cloud_trace_exporter.CloudTraceSpanExporter( + project_id=self._project, + client=cloud_trace_v2.TraceServiceClient( + credentials=credentials.with_quota_project(self._project), + ), + ) + span_processor: SpanProcessor = ( + opentelemetry_sdk_trace.export.SimpleSpanProcessor( + span_exporter=span_exporter, + ) + ) + tracer_provider: TracerProvider = ( + opentelemetry.trace.get_tracer_provider() + ) + # Get the appropriate tracer provider: + # 1. If _TRACER_PROVIDER is already set, use that. + # 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment + # variable is set, use that. + # 3. As a final fallback, use _PROXY_TRACER_PROVIDER. + # If none of the above is set, we log a warning, and + # create a tracer provider. + if not tracer_provider: + from google.cloud.aiplatform import base + + base.Logger(__name__).warning( + "No tracer provider. By default, " + "we should get one of the following providers: " + "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, " + "or _PROXY_TRACER_PROVIDER." + ) + tracer_provider = opentelemetry_sdk_trace.TracerProvider() + opentelemetry.trace.set_tracer_provider(tracer_provider) + # Avoids AttributeError: + # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no + # attribute 'add_span_processor'. + if _utils.is_noop_or_proxy_tracer_provider(tracer_provider): + tracer_provider = opentelemetry_sdk_trace.TracerProvider() + opentelemetry.trace.set_tracer_provider(tracer_provider) + # Avoids OpenTelemetry client already exists error. + _override_active_span_processor( + tracer_provider, + opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(), + ) + tracer_provider.add_span_processor(span_processor) + # Keep the instrumentation up-to-date. + # When creating multiple LangchainAgents, + # we need to keep the instrumentation up-to-date. + # We deliberately override the instrument each time, + # so that if different agents end up using different + # instrumentations, we guarantee that the user is always + # working with the most recent agent's instrumentation. 
+                self._instrumentor = openinference_langchain.LangChainInstrumentor()
+                if self._instrumentor.is_instrumented_by_opentelemetry:
+                    self._instrumentor.uninstrument()
+                self._instrumentor.instrument()
+            else:
+                from google.cloud.aiplatform import base
+
+                _LOGGER = base.Logger(__name__)
+                _LOGGER.warning(
+                    "enable_tracing=True but proceeding with tracing disabled "
+                    "because not all packages for tracing have been installed"
+                )
+        model_builder = self._model_builder or _default_model_builder
+        self._model = model_builder(
+            model_name=self._model_name,
+            model_kwargs=self._model_kwargs,
+            project=self._project,
+            location=self._location,
+        )
+        self._checkpointer = None
+        if self._checkpointer_builder:
+            checkpointer_kwargs = self._checkpointer_kwargs or {}
+            self._checkpointer = self._checkpointer_builder(
+                **checkpointer_kwargs,
+            )
+        runnable_builder = self._runnable_builder or _default_runnable_builder
+        self._runnable = runnable_builder(
+            model=self._model,
+            tools=self._tools,
+            checkpointer=self._checkpointer,
+            model_tool_kwargs=self._model_tool_kwargs,
+            runnable_kwargs=self._runnable_kwargs,
+        )
+
+    def clone(self) -> "LanggraphAgent":
+        """Returns a clone of the LanggraphAgent."""
+        import copy
+
+        return LanggraphAgent(
+            model=self._model_name,
+            tools=copy.deepcopy(self._tools),
+            model_kwargs=copy.deepcopy(self._model_kwargs),
+            model_tool_kwargs=copy.deepcopy(self._model_tool_kwargs),
+            runnable_kwargs=copy.deepcopy(self._runnable_kwargs),
+            checkpointer_kwargs=copy.deepcopy(self._checkpointer_kwargs),
+            model_builder=self._model_builder,
+            runnable_builder=self._runnable_builder,
+            checkpointer_builder=self._checkpointer_builder,
+            enable_tracing=self._enable_tracing,
+        )
+
+    def query(
+        self,
+        *,
+        input: Union[str, Mapping[str, Any]],
+        config: Optional["RunnableConfig"] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        """Queries the Agent with the given input and config.
+
+        Args:
+            input (Union[str, Mapping[str, Any]]):
+                Required. The input to be passed to the Agent.
+            config (langchain_core.runnables.RunnableConfig):
+                Optional. The config (if any) to be used for invoking the Agent.
+            **kwargs:
+                Optional. Any additional keyword arguments to be passed to the
+                `.invoke()` method of the underlying runnable.
+
+        Returns:
+            The output of querying the Agent with the given input and config.
+        """
+        from langchain.load import dump as langchain_load_dump
+
+        if isinstance(input, str):
+            input = {"input": input}
+        if not self._runnable:
+            self.set_up()
+        return langchain_load_dump.dumpd(
+            self._runnable.invoke(input=input, config=config, **kwargs)
+        )
+
+    def stream_query(
+        self,
+        *,
+        input,
+        config=None,
+        **kwargs,
+    ):
+        """Streams the Agent's output chunks for the given input and config."""
+        from langchain.load import dump as langchain_load_dump
+
+        if not self._runnable:
+            self.set_up()
+        for chunk in self._runnable.stream(input=input, config=config, **kwargs):
+            yield langchain_load_dump.dumpd(chunk)
+
+    def get_state_history(
+        self,
+        config,
+        **kwargs,
+    ):
+        """Yields the state history of the runnable as dictionaries."""
+        if not self._runnable:
+            self.set_up()
+        for state_snapshot in self._runnable.get_state_history(config=config, **kwargs):
+            yield state_snapshot._asdict()
+
+    def get_state(
+        self,
+        config,
+        **kwargs,
+    ):
+        """Returns the current state of the runnable as a dictionary."""
+        if not self._runnable:
+            self.set_up()
+        return self._runnable.get_state(config=config, **kwargs)._asdict()
+
+    def update_state(self, config, **kwargs):
+        """Updates the state of the runnable with the given values."""
+        if not self._runnable:
+            self.set_up()
+        return self._runnable.update_state(config=config, **kwargs)
+
+    def register_operations(self):
+        """Registers the operations of the Agent, keyed by API mode."""
+        return {
+            "": ["query", "get_state", "update_state"],
+            "stream": ["stream_query", "get_state_history"],
+        }
diff --git a/vertexai/reasoning_engines/_reasoning_engines.py b/vertexai/reasoning_engines/_reasoning_engines.py
index d9a46fe563..046ae1b4ae 100644
--- a/vertexai/reasoning_engines/_reasoning_engines.py
+++ b/vertexai/reasoning_engines/_reasoning_engines.py
@@ -16,11 +16,12 @@
 import abc
 import inspect
 import io
+import json
 import os
 import sys
 import tarfile
 import typing
-from typing import Optional, Protocol, Sequence, Union, List
+from typing import Any, Iterable, List, Optional, Protocol, Sequence, Union

 from google.api_core import exceptions
 from google.cloud import storage
@@ -50,6 +51,24 @@ def query(self, **kwargs):
         """Runs the Reasoning Engine to serve the user query."""


+@typing.runtime_checkable
+class StreamQueryable(Protocol):
+    """Protocol for Reasoning Engine applications that can stream responses."""
+
+    @abc.abstractmethod
+    def stream_query(self, **kwargs):
+        """Stream responses to serve the user query."""
+
+
+@typing.runtime_checkable
+class OperationRegistrable(Protocol):
+    """Protocol for applications that have registered operations."""
+
+    @abc.abstractmethod
+    def register_operations(self, **kwargs):
+        """Register the user provided operations (modes and methods)."""
+
+
 @typing.runtime_checkable
 class Cloneable(Protocol):
     """Protocol for Reasoning Engine applications that can be cloned."""
@@ -59,7 +78,97 @@ def clone(self):
         """Return a clone of the object."""


-class ReasoningEngine(base.VertexAiResourceNounWithFutureManager, Queryable):
+def _query_operation(name: str, description: Optional[str] = None):
+    """Returns a method that serves the named class method via QueryReasoningEngine."""
+    def _method(self, **kwargs) -> _utils.JsonDict:
+        response = self.execution_api_client.query_reasoning_engine(
+            request=types.QueryReasoningEngineRequest(
+                name=self.resource_name,
+                input=kwargs,
+                class_method=name,
+            ),
+        )
+        output = _utils.to_dict(response)
+        if "output" in output:
+            return output.get("output")
+        return output
+
+    _method.__name__ = name
+    if description:
+        _method.__doc__ = description
+    return _method
+
+
+def _stream_query_operation(name: str, description: Optional[str] = None):
+    """Returns a generator method that serves the named class method via StreamQueryReasoningEngine."""
+    def _method(self, **kwargs) -> Iterable[Any]:
+        response = self.execution_api_client.stream_query_reasoning_engine(
+            request=types.StreamQueryReasoningEngineRequest(
+                name=self.resource_name,
+                input=kwargs,
+                class_method=name,
+            ),
+        )
+        for chunk in response:
+            result = chunk
+            if "application/json" in chunk.content_type:
+                if not chunk.data:
+                    # To discard the last chunk (which has no data).
+                    continue
+                try:
+                    result = chunk.data.decode("utf-8")
+                except Exception as e:
+                    _LOGGER.warning(f"failed to decode: {chunk.data}. Exception: {e}")
+                    yield result
+                    continue
+                try:
+                    result = json.loads(result)
+                except Exception as e:
+                    if "Extra data: line 2 column 1" in str(e):
+                        # Handle the case where the chunk is a namedtuple that
+                        # contains multiple dictionaries delimited by newlines.
+                        for line in result.split("\n"):
+                            if line:
+                                try:
+                                    line = json.loads(line)
+                                except Exception as e:
+                                    _LOGGER.warning(
+                                        f"failed to parse json: {result}. Exception: {e}"
+                                    )
+                                yield line
+                        # Continue since we have yielded the content.
+                        continue
+                    else:
+                        _LOGGER.warning(
+                            f"failed to parse json: {result}. Exception: {e}"
+                        )
+            yield result
+
+    _method.__name__ = name
+    if description:
+        _method.__doc__ = description
+    return _method
+
+
+def _generate_methods(obj: "ReasoningEngine"):
+    """Generates query and stream-query methods on the object from its operation schemas."""
+    spec = _utils.to_dict(obj.gca_resource.spec)
+    class_methods = spec.get("classMethods") or spec.get("class_methods", [])
+    for class_method in class_methods:
+        api_mode = class_method.get("api_mode")
+        method_name = class_method.get("name")
+        method_description = class_method.get("description")
+        if api_mode == "":
+            query_method = _query_operation(
+                name=method_name,
+                description=method_description,
+            )
+            setattr(obj, method_name, query_method.__get__(obj))
+        elif api_mode == "stream":
+            stream_query_method = _stream_query_operation(
+                name=method_name,
+                description=method_description,
+            )
+            setattr(obj, method_name, stream_query_method.__get__(obj))
+
+
+class ReasoningEngine(base.VertexAiResourceNounWithFutureManager):
     """Represents a Vertex AI Reasoning Engine resource."""

     client_class = aip_utils.ReasoningEngineClientWithOverride
@@ -84,6 +193,10 @@ def __init__(self, reasoning_engine_name: str):
             client_class=aip_utils.ReasoningEngineExecutionClientWithOverride,
         )
         self._gca_resource = self._get_gca_resource(resource_name=reasoning_engine_name)
+        try:
+            _generate_methods(self)
+        except Exception as e:
+            _LOGGER.warning(f"failed to generate methods: {e}")
         self._operation_schemas = None

     @property
@@ -242,14 +355,32 @@ def create(
         reasoning_engine_spec = types.ReasoningEngineSpec(
             package_spec=package_spec,
         )
-        try:
-            schema_dict = _utils.generate_schema(
-                reasoning_engine.query,
-                schema_name=f"{type(reasoning_engine).__name__}_query",
+        if not isinstance(reasoning_engine, OperationRegistrable):
+            def register_operations(self):
+                operations = {}
+                if isinstance(reasoning_engine, Queryable):
+                    operations[""] = ["query"]
+                if isinstance(reasoning_engine, StreamQueryable):
+                    operations["stream"] = ["stream_query"]
+                return operations
+
+            reasoning_engine.register_operations = (
+                register_operations.__get__(reasoning_engine)
             )
-            reasoning_engine_spec.class_methods.append(_utils.to_proto(schema_dict))
-        except Exception as e:
-            _LOGGER.warning(f"failed to generate schema: {e}")
+        if not reasoning_engine.register_operations():
+            raise ValueError("ReasoningEngine does not support any operations.")
+        for mode, method_names in reasoning_engine.register_operations().items():
+            for method_name in method_names:
+                try:
+                    schema_dict = _utils.generate_schema(
+                        getattr(reasoning_engine, method_name),
+                        schema_name=method_name,
+                    )
+                except Exception as e:
+                    schema_dict = {}
+                    _LOGGER.warning(f"failed to generate schema: {e}")
+                class_method = _utils.to_proto(schema_dict)
+                class_method["api_mode"] = mode
+                reasoning_engine_spec.class_methods.append(class_method)
         operation_future = sdk_resource.api_client.create_reasoning_engine(
             parent=initializer.global_config.common_location_path(
                 project=sdk_resource.project, location=sdk_resource.location
@@ -279,6 +410,10 @@
             credentials=sdk_resource.credentials,
             location_override=sdk_resource.location,
         )
+        try:
+            _generate_methods(sdk_resource)
+        except Exception as e:
+            _LOGGER.warning(f"failed to generate methods: {e}")
         sdk_resource._operation_schemas = None
         return sdk_resource

@@ -415,30 +550,6 @@ def operation_schemas(self) -> Sequence[_utils.JsonDict]:
         self._operation_schemas = spec.get("class_methods", [])
         return self._operation_schemas

-    def query(self, **kwargs) -> _utils.JsonDict:
-        """Runs the Reasoning Engine to serve the user query.
-
-        This will be based on the `.query(...)` method of the python object that
-        was passed in when creating the Reasoning Engine.
-
-        Args:
-            **kwargs:
-                Optional. The arguments of the `.query(...)` method.
-
-        Returns:
-            dict[str, Any]: The response from serving the user query.
-        """
-        response = self.execution_api_client.query_reasoning_engine(
-            request=types.QueryReasoningEngineRequest(
-                name=self.resource_name,
-                input=kwargs,
-            ),
-        )
-        output = _utils.to_dict(response)
-        if "output" in output:
-            return output.get("output")
-        return output
-

 def _validate_sys_version_or_raise(sys_version: str) -> None:
     """Tries to validate the python system version."""
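
Reviewer note (not part of the patch): the sketch below shows how the new server-streaming RPC can be exercised directly through the GAPIC client, following the generated sample in `client.py` above. The resource name is a placeholder; a plain dict passed as `input` is coerced to a `google.protobuf.Struct` by proto-plus, `class_method` defaults to `"stream_query"` when unset, and each chunk arrives as a `google.api.httpbody_pb2.HttpBody` whose `data` bytes the caller decodes (the SDK-side `_stream_query_operation` above does exactly this, plus JSON parsing).

```python
from google.cloud import aiplatform_v1beta1

client = aiplatform_v1beta1.ReasoningEngineExecutionServiceClient()

request = aiplatform_v1beta1.StreamQueryReasoningEngineRequest(
    # Placeholder resource name.
    name="projects/my-project/locations/us-central1/reasoningEngines/123",
    input={"input": "What is the exchange rate from US dollars to SEK?"},
    class_method="stream_query",  # optional; this is the default
)

# The RPC is server-streaming: chunks are yielded as the engine produces them.
for chunk in client.stream_query_reasoning_engine(request=request):
    print(chunk.data.decode("utf-8"))
```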
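A minimal local-usage sketch of the new `LanggraphAgent` template, assuming the `langchain` extra (which now pulls in `langgraph`, per the `setup.py` change) is installed. `get_exchange_rate` is a made-up tool; its parameters are annotated, as required by `_validate_tools`. Since `_default_runnable_builder` delegates to `langgraph.prebuilt.create_react_agent`, the input follows LangGraph's `{"messages": [...]}` convention:

```python
from vertexai.preview import reasoning_engines

def get_exchange_rate(currency_from: str, currency_to: str) -> str:
    """Returns the exchange rate between two currencies."""
    return f"1 {currency_from} = 0.95 {currency_to}"  # stub result for illustration

agent = reasoning_engines.LanggraphAgent(
    model="gemini-1.0-pro",
    tools=[get_exchange_rate],
)
agent.set_up()  # local testing only; do not call before deploying

# Blocking query: returns the serialized final state of the graph.
response = agent.query(input={"messages": [("user", "USD to SEK?")]})

# Streaming query: yields serialized chunks as the graph executes,
# matching the "stream" API mode declared in register_operations().
for chunk in agent.stream_query(input={"messages": [("user", "USD to SEK?")]}):
    print(chunk)
```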
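Finally, a hedged sketch of the checkpointing surface (`checkpointer_builder`, `get_state`, `get_state_history`, `update_state`). `MemorySaver` is LangGraph's in-memory checkpointer (a persistent saver would be needed for a deployed engine), and the `thread_id` is an arbitrary placeholder; per the code above, `checkpointer_kwargs` are forwarded to the builder and the returned checkpointer is handed to `create_react_agent`:

```python
from vertexai.preview import reasoning_engines

def checkpointer_builder(**kwargs):
    # In-memory only; illustrative. kwargs come from checkpointer_kwargs.
    from langgraph.checkpoint.memory import MemorySaver
    return MemorySaver()

agent = reasoning_engines.LanggraphAgent(
    model="gemini-1.0-pro",
    checkpointer_builder=checkpointer_builder,
)
agent.set_up()

config = {"configurable": {"thread_id": "demo-thread"}}  # placeholder thread id
agent.query(input={"messages": [("user", "Hi, I'm Alice.")]}, config=config)

# StateSnapshot namedtuples are surfaced as plain dicts via _asdict(),
# which keeps them JSON-serializable for the "stream" and "" API modes.
print(agent.get_state(config=config))
for snapshot in agent.get_state_history(config=config):
    print(snapshot)
```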