@@ -24,6 +24,7 @@ def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
2424 self .local_explanation = {}
2525 self .formatted_global_explanation = None
2626 self .formatted_local_explanation = None
27+ self .date_col = config .spec .datetime_column .name
2728
2829 def set_kwargs (self ):
2930 """
@@ -77,8 +78,8 @@ def _train_model(self, data_train, data_test, model_kwargs):
7778 alpha = model_kwargs ["lower_quantile" ],
7879 ),
7980 },
80- freq = pd .infer_freq (data_train ["Date" ].drop_duplicates ())
81- or pd .infer_freq (data_train ["Date" ].drop_duplicates ()[- 5 :]),
81+ freq = pd .infer_freq (data_train [self . date_col ].drop_duplicates ())
82+ or pd .infer_freq (data_train [self . date_col ].drop_duplicates ()[- 5 :]),
8283 target_transforms = [Differences ([12 ])],
8384 lags = model_kwargs .get (
8485 "lags" ,
@@ -108,7 +109,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
108109 data_train [self .model_columns ],
109110 static_features = model_kwargs .get ("static_features" , []),
110111 id_col = ForecastOutputColumns .SERIES ,
111- time_col = self .spec . datetime_column . name ,
112+ time_col = self .date_col ,
112113 target_col = self .spec .target_column ,
113114 fitted = True ,
114115 max_horizon = None if num_models is False else self .spec .horizon ,
@@ -173,7 +174,7 @@ def _build_model(self) -> pd.DataFrame:
173174 confidence_interval_width = self .spec .confidence_interval_width ,
174175 horizon = self .spec .horizon ,
175176 target_column = self .original_target_column ,
176- dt_column = self .spec . datetime_column . name ,
177+ dt_column = self .date_col ,
177178 )
178179 self ._train_model (data_train , data_test , model_kwargs )
179180 return self .forecast_output .get_forecast_long ()
0 commit comments