Skip to content

Commit

Permalink
Refactor ARIMA model params.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxjohn committed Apr 24, 2022
1 parent 210645f commit 5171554
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
2 changes: 1 addition & 1 deletion hyperts/framework/search_space/macro_search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def default_arima_init_kwargs(self):
'seasonal_order': Choice([(1, 0, 0), (1, 0, 1), (1, 0, 1),
(1, 0, 2), (0, 2, 1), (0, 1, 2),
(2, 0, 1), (2, 0, 2), (0, 1, 1)]),
'period_offset': Choice([0, 0, 0, 0, 0, 0, 1, -1, 2, -2]),
# 'period_offset': Choice([0, 0, 0, 0, 0, 0, 1, -1, 2, -2]),
'y_scale': Choice(['min_max']*8 + ['max_abs']*1 + ['z_scale']*1)
}

Expand Down
18 changes: 10 additions & 8 deletions hyperts/framework/wrappers/stats_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,23 @@ def fit(self, X, y=None, **kwargs):
q = self.init_kwargs.pop('q', 1)
trend = self.init_kwargs.pop('trend', 'c')
seasonal_order = self.init_kwargs.pop('seasonal_order', (1, 1, 1))
period_offset = self.init_kwargs.pop('period_offset', 0)
if period != 2 and period+period_offset > 1:
period = min(period+period_offset, 30)
else:
period = 3
# period_offset = self.init_kwargs.pop('period_offset', 0)
# if period != 2 and period+period_offset > 1:
# period = min(period+period_offset, 30)
# else:
# period = 3
period = kwargs.get('period', period)
seasonal_order = seasonal_order + (period,)
if period > 2 and period <= 12:
seasonal_order = seasonal_order + (period,)
else:
seasonal_order = (0, 0, 0, 0)

try:
model = ARIMA(endog=y, order=(p, d, q), trend=trend, freq=freq,
seasonal_order=seasonal_order, dates=X[self.timestamp])
self.model = model.fit(**self.init_kwargs)
except:
model = ARIMA(endog=y, order=(p, d, q), trend=trend, freq=freq, dates=X[self.timestamp])
self.model = model.fit(**self.init_kwargs)
self.model = model.fit(**self.init_kwargs)

def predict(self, X, **kwargs):
last_date = X[self.timestamp].tail(1).to_list()[0].to_pydatetime()
Expand Down
2 changes: 1 addition & 1 deletion hyperts/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def metric_to_scorer(metric, task, pos_label=None, **options):
(callable(metric) and metric.__name__ in const.POSLABEL_REQUIRED):
average = _task_to_average(task)
scorer._kwargs['average'] = average
if average is 'binary':
if average == 'binary':
scorer._kwargs['pos_label'] = pos_label
logger.info(f"pos_label is {pos_label}.")

Expand Down

0 comments on commit 5171554

Please sign in to comment.