From 700553e628b8d153923360a20c2c9babe32a53a7 Mon Sep 17 00:00:00 2001 From: Siddesh M G <108360375+vinay2242g@users.noreply.github.com> Date: Fri, 10 Jan 2025 13:01:06 +0530 Subject: [PATCH] Adding holiday_region parameter to create_auto_ml_forecasting_training_job in AutoMl hook (#45465) --- .../providers/google/cloud/hooks/vertex_ai/auto_ml.py | 6 ++++++ .../providers/google/cloud/operators/vertex_ai/auto_ml.py | 3 +++ providers/tests/google/cloud/operators/test_vertex_ai.py | 5 +++++ 3 files changed, 14 insertions(+) diff --git a/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py b/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py index 978ab699edacd4..a61fca8ac68e4f 100644 --- a/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py +++ b/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/auto_ml.py @@ -561,6 +561,7 @@ def create_auto_ml_forecasting_training_job( model_version_description: str | None = None, window_stride_length: int | None = None, window_max_count: int | None = None, + holiday_regions: list[str] | None = None, ) -> tuple[models.Model | None, str]: """ Create an AutoML Forecasting Training Job. @@ -717,6 +718,10 @@ def create_auto_ml_forecasting_training_job( ``window_stride_length`` rows will be used to generate a sliding window. :param window_max_count: Optional. Number of rows that should be used to generate input examples. If the total row count is larger than this number, the input data will be randomly sampled to hit the count. + :param holiday_regions: Optional. You can select one or more geographical + regions to enable holiday effect modeling. During training, Vertex AI + creates holiday categorical features within the model based on the date + from TIME_COLUMN and the specified geographical regions. """ if column_transformations: warnings.warn( @@ -774,6 +779,7 @@ def create_auto_ml_forecasting_training_job( model_version_description=model_version_description, window_stride_length=window_stride_length, window_max_count=window_max_count, + holiday_regions=holiday_regions, ) training_id = self.extract_training_id(self._job.resource_name) if model: diff --git a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py index 1ca440170698bf..d92b2cef5858d0 100644 --- a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py +++ b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/auto_ml.py @@ -145,6 +145,7 @@ def __init__( parent_model: str | None = None, window_stride_length: int | None = None, window_max_count: int | None = None, + holiday_regions: list[str] | None = None, **kwargs, ) -> None: super().__init__( @@ -184,6 +185,7 @@ def __init__( self.budget_milli_node_hours = budget_milli_node_hours self.window_stride_length = window_stride_length self.window_max_count = window_max_count + self.holiday_regions = holiday_regions def execute(self, context: Context): self.hook = AutoMLHook( @@ -236,6 +238,7 @@ def execute(self, context: Context): sync=self.sync, window_stride_length=self.window_stride_length, window_max_count=self.window_max_count, + holiday_regions=self.holiday_regions, ) if model: diff --git a/providers/tests/google/cloud/operators/test_vertex_ai.py b/providers/tests/google/cloud/operators/test_vertex_ai.py index 7b5e3ffd5acee6..6410249f10a486 100644 --- a/providers/tests/google/cloud/operators/test_vertex_ai.py +++ b/providers/tests/google/cloud/operators/test_vertex_ai.py @@ -164,6 +164,7 @@ TEST_TRAINING_FORECAST_HORIZON = 10 TEST_TRAINING_DATA_GRANULARITY_UNIT = "day" TEST_TRAINING_DATA_GRANULARITY_COUNT = 1 +TEST_TRAINING_DATA_HOLIDAY_REGIONS = ["US"] TEST_MODEL_ID = "test_model_id" TEST_MODEL_NAME = f"projects/{GCP_PROJECT}/locations/{GCP_LOCATION}/models/test_model_id" @@ -1461,6 +1462,7 @@ def test_execute(self, mock_hook, mock_dataset): region=GCP_LOCATION, project_id=GCP_PROJECT, parent_model=TEST_PARENT_MODEL, + holiday_regions=TEST_TRAINING_DATA_HOLIDAY_REGIONS, ) op.execute(context={"ti": mock.MagicMock()}) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) @@ -1506,6 +1508,7 @@ def test_execute(self, mock_hook, mock_dataset): model_version_description=None, window_stride_length=None, window_max_count=None, + holiday_regions=TEST_TRAINING_DATA_HOLIDAY_REGIONS, ) @mock.patch("google.cloud.aiplatform.datasets.TimeSeriesDataset") @@ -1530,6 +1533,7 @@ def test_execute__parent_model_version_index_is_removed(self, mock_hook, mock_da region=GCP_LOCATION, project_id=GCP_PROJECT, parent_model=VERSIONED_TEST_PARENT_MODEL, + holiday_regions=TEST_TRAINING_DATA_HOLIDAY_REGIONS, ) op.execute(context={"ti": mock.MagicMock()}) mock_hook.return_value.create_auto_ml_forecasting_training_job.assert_called_once_with( @@ -1573,6 +1577,7 @@ def test_execute__parent_model_version_index_is_removed(self, mock_hook, mock_da model_version_description=None, window_stride_length=None, window_max_count=None, + holiday_regions=TEST_TRAINING_DATA_HOLIDAY_REGIONS, )