@@ -73,20 +73,35 @@ def _train_model(self, data_train, data_test, model_kwargs):
7373 alpha = model_kwargs ["lower_quantile" ],
7474 ),
7575 },
76- freq = pd .infer_freq (data_train .Date .drop_duplicates ()),
76+ freq = pd .infer_freq (data_train ["Date" ].drop_duplicates ())
77+ or pd .infer_freq (data_train ["Date" ].drop_duplicates ()[- 5 :]),
7778 target_transforms = [Differences ([12 ])],
78- lags = model_kwargs .get ("lags" , [1 , 6 , 12 ]),
79- lag_transforms = {
80- 1 : [ExpandingMean ()],
81- 12 : [RollingMean (window_size = 24 )],
82- },
79+ lags = model_kwargs .get (
80+ "lags" ,
81+ (
82+ [1 , 6 , 12 ]
83+ if len (self .datasets .get_additional_data_column_names ()) > 0
84+ else []
85+ ),
86+ ),
87+ lag_transforms = (
88+ {
89+ 1 : [ExpandingMean ()],
90+ 12 : [RollingMean (window_size = 24 )],
91+ }
92+ if len (self .datasets .get_additional_data_column_names ()) > 0
93+ else {}
94+ ),
8395 # date_features=[hour_index],
8496 )
8597
8698 num_models = model_kwargs .get ("recursive_models" , False )
8799
100+ self .model_columns = [
101+ ForecastOutputColumns .SERIES
102+ ] + data_train .select_dtypes (exclude = ["object" ]).columns .to_list ()
88103 fcst .fit (
89- data_train ,
104+ data_train [ self . model_columns ] ,
90105 static_features = model_kwargs .get ("static_features" , []),
91106 id_col = ForecastOutputColumns .SERIES ,
92107 time_col = self .spec .datetime_column .name ,
@@ -99,8 +114,10 @@ def _train_model(self, data_train, data_test, model_kwargs):
99114 h = self .spec .horizon ,
100115 X_df = pd .concat (
101116 [
102- data_test ,
103- fcst .get_missing_future (h = self .spec .horizon , X_df = data_test ),
117+ data_test [self .model_columns ],
118+ fcst .get_missing_future (
119+ h = self .spec .horizon , X_df = data_test [self .model_columns ]
120+ ),
104121 ],
105122 axis = 0 ,
106123 ignore_index = True ,
@@ -166,12 +183,16 @@ def _generate_report(self):
166183 # Section 1: Forecast Overview
167184 sec1_text = rc .Block (
168185 rc .Heading ("Forecast Overview" , level = 2 ),
169- rc .Text ("These plots show your forecast in the context of historical data." )
186+ rc .Text (
187+ "These plots show your forecast in the context of historical data."
188+ ),
170189 )
171190 sec_1 = _select_plot_list (
172191 lambda s_id : plot_series (
173192 self .datasets .get_all_data_long (include_horizon = False ),
174- pd .concat ([self .fitted_values ,self .outputs ], axis = 0 , ignore_index = True ),
193+ pd .concat (
194+ [self .fitted_values , self .outputs ], axis = 0 , ignore_index = True
195+ ),
175196 id_col = ForecastOutputColumns .SERIES ,
176197 time_col = self .spec .datetime_column .name ,
177198 target_col = self .original_target_column ,
@@ -184,7 +205,7 @@ def _generate_report(self):
184205 # Section 2: MlForecast Model Parameters
185206 sec2_text = rc .Block (
186207 rc .Heading ("MlForecast Model Parameters" , level = 2 ),
187- rc .Text ("These are the parameters used for the MlForecast model." )
208+ rc .Text ("These are the parameters used for the MlForecast model." ),
188209 )
189210
190211 blocks = [
@@ -197,9 +218,11 @@ def _generate_report(self):
197218 sec_2 = rc .Select (blocks = blocks )
198219
199220 all_sections = [sec1_text , sec_1 , sec2_text , sec_2 ]
200- model_description = rc .Text ("mlforecast is a framework to perform time series forecasting using machine learning models"
201- "with the option to scale to massive amounts of data using remote clusters."
202- "Fastest implementations of feature engineering for time series forecasting in Python."
203- "Support for exogenous variables and static covariates." )
221+ model_description = rc .Text (
222+ "mlforecast is a framework to perform time series forecasting using machine learning models"
223+ "with the option to scale to massive amounts of data using remote clusters."
224+ "Fastest implementations of feature engineering for time series forecasting in Python."
225+ "Support for exogenous variables and static covariates."
226+ )
204227
205- return model_description , all_sections
228+ return model_description , all_sections
0 commit comments