Skip to content

Commit

Permalink
Resolve google vertex ai deprecations in tests (#40628)
Browse files Browse the repository at this point in the history
  • Loading branch information
dirrao authored Jul 6, 2024
1 parent 2eda737 commit 9255167
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 57 deletions.
11 changes: 0 additions & 11 deletions tests/deprecations_ignore.yml
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,6 @@
- tests/providers/google/cloud/operators/test_kubernetes_engine.py::TestGoogleCloudPlatformContainerOperator::test_create_execute_error_body
- tests/providers/google/cloud/operators/test_life_sciences.py::TestLifeSciencesRunPipelineOperator::test_executes
- tests/providers/google/cloud/operators/test_life_sciences.py::TestLifeSciencesRunPipelineOperator::test_executes_without_project_id
- tests/providers/google/cloud/operators/test_vertex_ai.py::TestVertexAICreateBatchPredictionJobOperator::test_execute
- tests/providers/google/cloud/operators/test_vertex_ai.py::TestVertexAICreateBatchPredictionJobOperator::test_execute_deferrable
- tests/providers/google/cloud/operators/test_vertex_ai.py::TestVertexAICreateHyperparameterTuningJobOperator::test_deferrable
- tests/providers/google/cloud/operators/test_vertex_ai.py::TestVertexAICreateHyperparameterTuningJobOperator::test_deferrable_sync_error
- tests/providers/google/cloud/operators/test_vertex_ai.py::TestVertexAICreateHyperparameterTuningJobOperator::test_execute
- tests/providers/google/cloud/operators/test_vertex_ai.py::TestVertexAIDeleteAutoMLTrainingJobOperator::test_execute
- tests/providers/google/cloud/operators/test_vertex_ai.py::TestVertexAIDeleteCustomTrainingJobOperator::test_execute
- tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py::TestVertexAIPromptLanguageModelOperator::test_execute
- tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py::TestVertexAIGenerateTextEmbeddingsOperator::test_execute
- tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py::TestVertexAIPromptMultimodalModelOperator::test_execute
- tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py::TestVertexAIPromptMultimodalModelWithMediaOperator::test_execute
- tests/providers/google/cloud/secrets/test_secret_manager.py::TestCloudSecretManagerBackend::test_connections_prefix_none_value
- tests/providers/google/cloud/secrets/test_secret_manager.py::TestCloudSecretManagerBackend::test_get_conn_uri
- tests/providers/google/cloud/secrets/test_secret_manager.py::TestCloudSecretManagerBackend::test_get_conn_uri_non_existent_key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,23 @@ def test_deprecation_warning(self):

@mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
def test_execute(self, mock_hook):
op = PromptLanguageModelOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=self.prompt,
pretrained_model=self.pretrained_model,
temperature=self.temperature,
max_output_tokens=self.max_output_tokens,
top_p=self.top_p,
top_k=self.top_k,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
with pytest.warns(
AirflowProviderDeprecationWarning,
match=r"Call to deprecated class PromptLanguageModelOperator. \(This operator is deprecated and will be removed after 01.01.2025, please use `TextGenerationModelPredictOperator`.\)",
):
op = PromptLanguageModelOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=self.prompt,
pretrained_model=self.pretrained_model,
temperature=self.temperature,
max_output_tokens=self.max_output_tokens,
top_p=self.top_p,
top_k=self.top_k,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
Expand Down Expand Up @@ -128,15 +132,19 @@ def test_deprecation_warning(self):

@mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
def test_execute(self, mock_hook):
op = GenerateTextEmbeddingsOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=self.prompt,
pretrained_model=self.pretrained_model,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
with pytest.warns(
AirflowProviderDeprecationWarning,
match=r"Call to deprecated class GenerateTextEmbeddingsOperator. \(This operator is deprecated and will be removed after 01.01.2025, please use `TextEmbeddingModelGetEmbeddingsOperator`.\)",
):
op = GenerateTextEmbeddingsOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=self.prompt,
pretrained_model=self.pretrained_model,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
Expand Down Expand Up @@ -178,17 +186,21 @@ def test_deprecation_warning(self):

@mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
def test_execute(self, mock_hook):
op = PromptMultimodalModelOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=self.prompt,
generation_config=self.generation_config,
safety_settings=self.safety_settings,
pretrained_model=self.pretrained_model,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
with pytest.warns(
AirflowProviderDeprecationWarning,
match=r"Call to deprecated class PromptMultimodalModelOperator. \(This operator is deprecated and will be removed after 01.01.2025, please use `GenerativeModelGenerateContentOperator`.\)",
):
op = PromptMultimodalModelOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=self.prompt,
generation_config=self.generation_config,
safety_settings=self.safety_settings,
pretrained_model=self.pretrained_model,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
Expand Down Expand Up @@ -236,19 +248,23 @@ def test_deprecation_warning(self):

@mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
def test_execute(self, mock_hook):
op = PromptMultimodalModelWithMediaOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=self.vision_prompt,
generation_config=self.generation_config,
safety_settings=self.safety_settings,
pretrained_model=self.pretrained_model,
media_gcs_path=self.media_gcs_path,
mime_type=self.mime_type,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
with pytest.warns(
AirflowProviderDeprecationWarning,
match=r"Call to deprecated class PromptMultimodalModelWithMediaOperator. \(This operator is deprecated and will be removed after 01.01.2025, please use `GenerativeModelGenerateContentOperator`.\)",
):
op = PromptMultimodalModelWithMediaOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=self.vision_prompt,
generation_config=self.generation_config,
safety_settings=self.safety_settings,
pretrained_model=self.pretrained_model,
media_gcs_path=self.media_gcs_path,
mime_type=self.mime_type,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
Expand Down

0 comments on commit 9255167

Please sign in to comment.