diff --git a/src/gluonts/dataset/repository/_lstnet.py b/src/gluonts/dataset/repository/_lstnet.py index ba46cf03ae..ed38aa90b5 100644 --- a/src/gluonts/dataset/repository/_lstnet.py +++ b/src/gluonts/dataset/repository/_lstnet.py @@ -141,10 +141,7 @@ def generate_lstnet_dataset( pd.read_csv(ds_info.url, header=None), # type: ignore ) - assert df.shape == ( - ds_info.num_time_steps, - ds_info.num_series, - ), ( + assert df.shape == (ds_info.num_time_steps, ds_info.num_series,), ( "expected num_time_steps/num_series" f" {(ds_info.num_time_steps, ds_info.num_series)} but got {df.shape}" ) diff --git a/src/gluonts/ext/rotbaum/_preprocess.py b/src/gluonts/ext/rotbaum/_preprocess.py index 0f39095976..06b219bc80 100644 --- a/src/gluonts/ext/rotbaum/_preprocess.py +++ b/src/gluonts/ext/rotbaum/_preprocess.py @@ -452,9 +452,13 @@ def make_features(self, time_series: Dict, starting_index: int) -> List: end_index = starting_index + self.context_window_size if starting_index < 0: prefix = [None] * abs(starting_index) + time_series_window = time_series["target"] else: prefix = [] - time_series_window = time_series["target"][starting_index:end_index] + time_series_window = time_series["target"][ + starting_index:end_index + ] + only_lag_features, transform_dict = self._pre_transform( time_series_window, self.subtract_mean, self.count_nans ) @@ -477,10 +481,10 @@ def make_features(self, time_series: Dict, starting_index: int) -> List: list( chain( *[ - list(ent[0]) + list(ent[1].values()) + prefix + list(ent[0]) + list(ent[1].values()) for ent in [ self._pre_transform( - ts[starting_index:end_index], + ts if prefix else ts[starting_index:end_index], self.subtract_mean, self.count_nans, ) diff --git a/src/gluonts/nursery/spliced_binned_pareto/run_model_example.ipynb b/src/gluonts/nursery/spliced_binned_pareto/run_model_example.ipynb index 7f6b14c0b6..a48f632303 100644 --- a/src/gluonts/nursery/spliced_binned_pareto/run_model_example.ipynb +++ b/src/gluonts/nursery/spliced_binned_pareto/run_model_example.ipynb @@ -864,4 +864,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/test/ext/rotbaum/test_rotbaum_smoke.py b/test/ext/rotbaum/test_rotbaum_smoke.py index 2634644660..63988db9f6 100644 --- a/test/ext/rotbaum/test_rotbaum_smoke.py +++ b/test/ext/rotbaum/test_rotbaum_smoke.py @@ -12,10 +12,12 @@ # permissions and limitations under the License. import pytest +import numpy as np -from gluonts.ext.rotbaum import TreeEstimator +from gluonts.ext.rotbaum import TreeEstimator, TreePredictor from gluonts.testutil.dummy_datasets import make_dummy_datasets_with_features +from gluonts.dataset.common import ListDataset # TODO: Add support for categorical and dynamic features. @@ -59,3 +61,69 @@ def test_rotbaum_smoke(datasets): predictor = estimator.train(dataset_train) forecasts = list(predictor.predict(dataset_test)) assert len(forecasts) == len(dataset_test) + + +def test_short_history_item_pred(): + + prediction_length = 7 + freq = "D" + + dataset = ListDataset( + data_iter=[ + { + "start": "2017-10-11", + "item_id": "item_1", + "target": np.array( + [ + 1.0, + 9.0, + 2.0, + 0.0, + 0.0, + 1.0, + 5.0, + 3.0, + 4.0, + 2.0, + 0.0, + 0.0, + 1.0, + 6.0, + ] + ), + "feat_static_cat": np.array([0.0, 0.0], dtype=float), + "past_feat_dynamic_real": np.array( + [ + [1.0222e06 for i in range(14)], + [750.0 for i in range(14)], + ] + ), + }, + { + "start": "2017-10-11", + "item_id": "item_2", + "target": np.array([7.0, 0.0, 0.0, 23.0, 13.0]), + "feat_static_cat": np.array([0.0, 1.0], dtype=float), + "past_feat_dynamic_real": np.array( + [[0 for i in range(5)], [750.0 for i in range(5)]] + ), + }, + ], + freq=freq, + ) + + predictor = TreePredictor( + freq=freq, + prediction_length=prediction_length, + quantiles=[0.1, 0.5, 0.9], + max_n_datapts=50000, + method="QuantileRegression", + use_past_feat_dynamic_real=True, + use_feat_dynamic_real=False, + use_feat_dynamic_cat=False, + use_feat_static_real=False, + cardinality="auto", + ) + predictor = predictor.train(dataset) + forecasts = list(predictor.predict(dataset)) + assert forecasts[1].quantile(0.5).shape[0] == prediction_length